From e58e41a9118034bafa4a49c630a2b08f9a7ba74d Mon Sep 17 00:00:00 2001 From: Paillat Date: Fri, 17 May 2024 11:18:28 +0200 Subject: [PATCH] :thread: Add multithreading when getting image assets --- .../AIImageEngine/A1111AIImageEngine.py | 10 ++-- .../AIImageEngine/BaseAIImageEngine.py | 2 +- .../AIImageEngine/DallEAIImageEngine.py | 9 ++-- .../Pipelines/ScriptedVideoPipeline.py | 51 +++++++++++-------- src/engines/SettingsEngine/__init__.py | 1 - .../StockImageEngine/BaseStockImageEngine.py | 3 +- .../GoogleStockImageEngine.py | 12 ++--- 7 files changed, 49 insertions(+), 39 deletions(-) delete mode 100644 src/engines/SettingsEngine/__init__.py diff --git a/src/engines/AIImageEngine/A1111AIImageEngine.py b/src/engines/AIImageEngine/A1111AIImageEngine.py index d8b1c0e..8a3006e 100644 --- a/src/engines/AIImageEngine/A1111AIImageEngine.py +++ b/src/engines/AIImageEngine/A1111AIImageEngine.py @@ -32,7 +32,7 @@ class A1111AIImageEngine(BaseAIImageEngine): super().__init__() - def generate(self, prompt: str, start: float, end: float) -> mp.ImageClip: + def generate(self, prompt: str, start: float, end: float, i= "") -> mp.ImageClip: max_width = self.ctx.width / 3 * 2 try: url = self.base_url + "/sdapi/v1/txt2img" @@ -43,8 +43,8 @@ class A1111AIImageEngine(BaseAIImageEngine): } response = requests.post(url, json=payload) response.raise_for_status() - - with open("temp.png", "wb") as f: + fname = f"temp{i}.png" + with open(fname, "wb") as f: f.write(base64.b64decode(response.json()["images"][0])) except Exception as e: gr.Warning(f"Failed to get image: {e}") @@ -53,8 +53,8 @@ class A1111AIImageEngine(BaseAIImageEngine): .with_duration(end - start) .with_start(start) ) - img = mp.ImageClip("temp.png") - os.remove("temp.png") + img = mp.ImageClip(fname) + os.remove(fname) position = ("center", "center") img = ( diff --git a/src/engines/AIImageEngine/BaseAIImageEngine.py b/src/engines/AIImageEngine/BaseAIImageEngine.py index 5210698..a1adf90 100644 --- a/src/engines/AIImageEngine/BaseAIImageEngine.py +++ b/src/engines/AIImageEngine/BaseAIImageEngine.py @@ -11,7 +11,7 @@ class BaseAIImageEngine(BaseEngine): """ @abstractmethod - def generate(self, prompt: str, start: float, end: float) -> mp.ImageClip: + def generate(self, prompt: str, start: float, end: float, i = "") -> mp.ImageClip: """ Ge """ diff --git a/src/engines/AIImageEngine/DallEAIImageEngine.py b/src/engines/AIImageEngine/DallEAIImageEngine.py index a44434f..74705fa 100644 --- a/src/engines/AIImageEngine/DallEAIImageEngine.py +++ b/src/engines/AIImageEngine/DallEAIImageEngine.py @@ -33,7 +33,7 @@ class DallEAIImageEngine(BaseAIImageEngine): super().__init__() - def generate(self, prompt: str, start: float, end: float) -> mp.ImageClip: + def generate(self, prompt: str, start: float, end: float, i = "") -> mp.ImageClip: max_width = self.ctx.width / 3 * 2 size: Literal["1024x1024", "1024x1792", "1792x1024"] = ( "1024x1024" @@ -61,10 +61,11 @@ class DallEAIImageEngine(BaseAIImageEngine): else: raise img_bytes = requests.get(response.data[0].url) - with open("temp.png", "wb") as f: + fname = f"temp{i}.png" + with open(fname, "wb") as f: f.write(img_bytes.content) - img = mp.ImageClip("temp.png") - os.remove("temp.png") + img = mp.ImageClip(fname) + os.remove(fname) position = ("center", "center") img = ( diff --git a/src/engines/Pipelines/ScriptedVideoPipeline.py b/src/engines/Pipelines/ScriptedVideoPipeline.py index 0d79b6e..d62f293 100644 --- a/src/engines/Pipelines/ScriptedVideoPipeline.py +++ b/src/engines/Pipelines/ScriptedVideoPipeline.py @@ -3,6 +3,8 @@ import os import gradio as gr import moviepy as mp +from concurrent.futures import ThreadPoolExecutor, as_completed + from . import BasePipeline from ... import engines from ...chore import GenerationContext @@ -24,6 +26,27 @@ class ScriptedVideoPipeline(BasePipeline): self.height = options[4] super().__init__() + def get_asset(self, asset: dict[str, str | float], i) -> mp.VideoClip: + if asset["type"] == "stock": + return self.ctx.stockimageengine.get( + asset["query"], asset["start"], asset["end"], i + ) + elif asset["type"] == "ai": + return self.ctx.aiimageengine.generate( + asset["prompt"], asset["start"], asset["end"], i + ) + + def get_assets_concurrent(self, assets: list[dict[str, str]]) -> list[mp.VideoClip]: + results = [] + with ThreadPoolExecutor() as executor: + futures = [executor.submit(self.get_asset, asset, i) for i, asset in enumerate(assets)] + for future in as_completed(futures): + try: + results.append(future.result()) + except Exception as e: + gr.Warning(f"Failed to generate an asset: {e}") + return results + def launch(self, ctx: GenerationContext) -> None: ctx.progress(0.1, "Loading settings...") @@ -119,25 +142,12 @@ class ScriptedVideoPipeline(BasePipeline): max_tokens=4096, json_mode=True, )["assets"] - for i, asset in enumerate(assets): - if asset["type"] == "stock": - ctx.progress(0.5, f"Getting stock image {i + 1}...") - ctx.index_4.append( - ctx.stockimageengine.get( - asset["query"], - asset["start"] + ctx.duration, - asset["end"] + ctx.duration, - ) - ) - elif asset["type"] == "ai": - ctx.progress(0.5, f"Generating AI image {i + 1}...") - ctx.index_5.append( - ctx.aiimageengine.generate( - asset["prompt"], - asset["start"] + ctx.duration, - asset["end"] + ctx.duration, - ) - ) + for asset in assets: + asset["start"] += ctx.duration + asset["end"] += ctx.duration + ctx.progress(0.2, f"Generating assets for chapter: {chapter['title']}...") + clips = self.get_assets_concurrent(assets) + ctx.index_5.extend(clips) ctx.duration += duration + 0.5 ctx.audio.extend(text_audio) @@ -169,7 +179,7 @@ class ScriptedVideoPipeline(BasePipeline): .with_audio(audio) ) clip.write_videofile( - ctx.get_file_path("final.mp4"), fps=60, threads=4, codec="h264_nvenc" + ctx.get_file_path("final.mp4"), fps=60, threads=16, codec="av1_nvenc" ) system = prompts["description"]["system"] chat = prompts["description"]["chat"] @@ -188,7 +198,6 @@ class ScriptedVideoPipeline(BasePipeline): ctx.title, ctx.description, ctx.get_file_path("final.mp4") ) except Exception as e: - print(e) gr.Warning(f"{engine.name} failed to upload the video.") ctx.progress(0.99, "Storing in database...") diff --git a/src/engines/SettingsEngine/__init__.py b/src/engines/SettingsEngine/__init__.py deleted file mode 100644 index e9f74ba..0000000 --- a/src/engines/SettingsEngine/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .SettingsEngine import SettingsEngine diff --git a/src/engines/StockImageEngine/BaseStockImageEngine.py b/src/engines/StockImageEngine/BaseStockImageEngine.py index f9a1b8c..4aaae5d 100644 --- a/src/engines/StockImageEngine/BaseStockImageEngine.py +++ b/src/engines/StockImageEngine/BaseStockImageEngine.py @@ -11,7 +11,7 @@ class BaseStockImageEngine(BaseEngine): """ @abstractmethod - def get(self, query: str, start: float, end: float) -> mp.ImageClip: + def get(self, query: str, start: float, end: float, i = "") -> mp.ImageClip: """ Get a stock image based on a query. @@ -19,6 +19,7 @@ class BaseStockImageEngine(BaseEngine): query (str): The query to search for. start (float): The starting time of the video clip. end (float): The ending time of the video clip. + i (str): Unique identifier for the image, mandatory if running concurrently. Returns: str: The path to the saved image. diff --git a/src/engines/StockImageEngine/GoogleStockImageEngine.py b/src/engines/StockImageEngine/GoogleStockImageEngine.py index 36ea748..f2b16b7 100644 --- a/src/engines/StockImageEngine/GoogleStockImageEngine.py +++ b/src/engines/StockImageEngine/GoogleStockImageEngine.py @@ -28,22 +28,22 @@ class GoogleStockImageEngine(BaseStockImageEngine): self.google = GoogleImagesSearch(api_key, project_cx) super().__init__() - def get(self, query: str, start: float, end: float) -> mp.ImageClip: + def get(self, query: str, start: float, end: float, i="") -> mp.ImageClip: max_width = int(self.ctx.width / 3 * 2) _search_params = { "q": query, "num": 1, } - os.makedirs("temp", exist_ok=True) + os.makedirs(f"temp{i}", exist_ok=True) try: self.google.search( search_params=_search_params, - path_to_dir="./temp/", + path_to_dir=f"./temp{i}/", custom_image_name="temp", ) # we find the file called temp. extension - filename = [f for f in os.listdir("./temp/") if f.startswith("temp.")][0] - img = mp.ImageClip(f"./temp/{filename}") + filename = [f for f in os.listdir(f"./temp{i}/") if f.startswith("temp.")][0] + img = mp.ImageClip(f"./temp{i}/{filename}") # delete the temp folder except Exception as e: gr.Warning(f"Failed to get image: {e}") @@ -53,7 +53,7 @@ class GoogleStockImageEngine(BaseStockImageEngine): .with_start(start) ) finally: - shutil.rmtree("temp") + shutil.rmtree(f"temp{i}") img = ( img.with_duration(end - start)