🧵 Add multithreading when getting image assets

This commit is contained in:
2024-05-17 11:18:28 +02:00
parent 08f3e09b4b
commit e58e41a911
7 changed files with 49 additions and 39 deletions

View File

@@ -32,7 +32,7 @@ class A1111AIImageEngine(BaseAIImageEngine):
super().__init__() super().__init__()
def generate(self, prompt: str, start: float, end: float) -> mp.ImageClip: def generate(self, prompt: str, start: float, end: float, i= "") -> mp.ImageClip:
max_width = self.ctx.width / 3 * 2 max_width = self.ctx.width / 3 * 2
try: try:
url = self.base_url + "/sdapi/v1/txt2img" url = self.base_url + "/sdapi/v1/txt2img"
@@ -43,8 +43,8 @@ class A1111AIImageEngine(BaseAIImageEngine):
} }
response = requests.post(url, json=payload) response = requests.post(url, json=payload)
response.raise_for_status() response.raise_for_status()
fname = f"temp{i}.png"
with open("temp.png", "wb") as f: with open(fname, "wb") as f:
f.write(base64.b64decode(response.json()["images"][0])) f.write(base64.b64decode(response.json()["images"][0]))
except Exception as e: except Exception as e:
gr.Warning(f"Failed to get image: {e}") gr.Warning(f"Failed to get image: {e}")
@@ -53,8 +53,8 @@ class A1111AIImageEngine(BaseAIImageEngine):
.with_duration(end - start) .with_duration(end - start)
.with_start(start) .with_start(start)
) )
img = mp.ImageClip("temp.png") img = mp.ImageClip(fname)
os.remove("temp.png") os.remove(fname)
position = ("center", "center") position = ("center", "center")
img = ( img = (

View File

@@ -11,7 +11,7 @@ class BaseAIImageEngine(BaseEngine):
""" """
@abstractmethod @abstractmethod
def generate(self, prompt: str, start: float, end: float) -> mp.ImageClip: def generate(self, prompt: str, start: float, end: float, i = "") -> mp.ImageClip:
""" """
Ge Ge
""" """

View File

@@ -33,7 +33,7 @@ class DallEAIImageEngine(BaseAIImageEngine):
super().__init__() super().__init__()
def generate(self, prompt: str, start: float, end: float) -> mp.ImageClip: def generate(self, prompt: str, start: float, end: float, i = "") -> mp.ImageClip:
max_width = self.ctx.width / 3 * 2 max_width = self.ctx.width / 3 * 2
size: Literal["1024x1024", "1024x1792", "1792x1024"] = ( size: Literal["1024x1024", "1024x1792", "1792x1024"] = (
"1024x1024" "1024x1024"
@@ -61,10 +61,11 @@ class DallEAIImageEngine(BaseAIImageEngine):
else: else:
raise raise
img_bytes = requests.get(response.data[0].url) img_bytes = requests.get(response.data[0].url)
with open("temp.png", "wb") as f: fname = f"temp{i}.png"
with open(fname, "wb") as f:
f.write(img_bytes.content) f.write(img_bytes.content)
img = mp.ImageClip("temp.png") img = mp.ImageClip(fname)
os.remove("temp.png") os.remove(fname)
position = ("center", "center") position = ("center", "center")
img = ( img = (

View File

@@ -3,6 +3,8 @@ import os
import gradio as gr import gradio as gr
import moviepy as mp import moviepy as mp
from concurrent.futures import ThreadPoolExecutor, as_completed
from . import BasePipeline from . import BasePipeline
from ... import engines from ... import engines
from ...chore import GenerationContext from ...chore import GenerationContext
@@ -24,6 +26,27 @@ class ScriptedVideoPipeline(BasePipeline):
self.height = options[4] self.height = options[4]
super().__init__() super().__init__()
def get_asset(self, asset: dict[str, str | float], i) -> mp.VideoClip:
if asset["type"] == "stock":
return self.ctx.stockimageengine.get(
asset["query"], asset["start"], asset["end"], i
)
elif asset["type"] == "ai":
return self.ctx.aiimageengine.generate(
asset["prompt"], asset["start"], asset["end"], i
)
def get_assets_concurrent(self, assets: list[dict[str, str]]) -> list[mp.VideoClip]:
results = []
with ThreadPoolExecutor() as executor:
futures = [executor.submit(self.get_asset, asset, i) for i, asset in enumerate(assets)]
for future in as_completed(futures):
try:
results.append(future.result())
except Exception as e:
gr.Warning(f"Failed to generate an asset: {e}")
return results
def launch(self, ctx: GenerationContext) -> None: def launch(self, ctx: GenerationContext) -> None:
ctx.progress(0.1, "Loading settings...") ctx.progress(0.1, "Loading settings...")
@@ -119,25 +142,12 @@ class ScriptedVideoPipeline(BasePipeline):
max_tokens=4096, max_tokens=4096,
json_mode=True, json_mode=True,
)["assets"] )["assets"]
for i, asset in enumerate(assets): for asset in assets:
if asset["type"] == "stock": asset["start"] += ctx.duration
ctx.progress(0.5, f"Getting stock image {i + 1}...") asset["end"] += ctx.duration
ctx.index_4.append( ctx.progress(0.2, f"Generating assets for chapter: {chapter['title']}...")
ctx.stockimageengine.get( clips = self.get_assets_concurrent(assets)
asset["query"], ctx.index_5.extend(clips)
asset["start"] + ctx.duration,
asset["end"] + ctx.duration,
)
)
elif asset["type"] == "ai":
ctx.progress(0.5, f"Generating AI image {i + 1}...")
ctx.index_5.append(
ctx.aiimageengine.generate(
asset["prompt"],
asset["start"] + ctx.duration,
asset["end"] + ctx.duration,
)
)
ctx.duration += duration + 0.5 ctx.duration += duration + 0.5
ctx.audio.extend(text_audio) ctx.audio.extend(text_audio)
@@ -169,7 +179,7 @@ class ScriptedVideoPipeline(BasePipeline):
.with_audio(audio) .with_audio(audio)
) )
clip.write_videofile( clip.write_videofile(
ctx.get_file_path("final.mp4"), fps=60, threads=4, codec="h264_nvenc" ctx.get_file_path("final.mp4"), fps=60, threads=16, codec="av1_nvenc"
) )
system = prompts["description"]["system"] system = prompts["description"]["system"]
chat = prompts["description"]["chat"] chat = prompts["description"]["chat"]
@@ -188,7 +198,6 @@ class ScriptedVideoPipeline(BasePipeline):
ctx.title, ctx.description, ctx.get_file_path("final.mp4") ctx.title, ctx.description, ctx.get_file_path("final.mp4")
) )
except Exception as e: except Exception as e:
print(e)
gr.Warning(f"{engine.name} failed to upload the video.") gr.Warning(f"{engine.name} failed to upload the video.")
ctx.progress(0.99, "Storing in database...") ctx.progress(0.99, "Storing in database...")

View File

@@ -1 +0,0 @@
from .SettingsEngine import SettingsEngine

View File

@@ -11,7 +11,7 @@ class BaseStockImageEngine(BaseEngine):
""" """
@abstractmethod @abstractmethod
def get(self, query: str, start: float, end: float) -> mp.ImageClip: def get(self, query: str, start: float, end: float, i = "") -> mp.ImageClip:
""" """
Get a stock image based on a query. Get a stock image based on a query.
@@ -19,6 +19,7 @@ class BaseStockImageEngine(BaseEngine):
query (str): The query to search for. query (str): The query to search for.
start (float): The starting time of the video clip. start (float): The starting time of the video clip.
end (float): The ending time of the video clip. end (float): The ending time of the video clip.
i (str): Unique identifier for the image, mandatory if running concurrently.
Returns: Returns:
str: The path to the saved image. str: The path to the saved image.

View File

@@ -28,22 +28,22 @@ class GoogleStockImageEngine(BaseStockImageEngine):
self.google = GoogleImagesSearch(api_key, project_cx) self.google = GoogleImagesSearch(api_key, project_cx)
super().__init__() super().__init__()
def get(self, query: str, start: float, end: float) -> mp.ImageClip: def get(self, query: str, start: float, end: float, i="") -> mp.ImageClip:
max_width = int(self.ctx.width / 3 * 2) max_width = int(self.ctx.width / 3 * 2)
_search_params = { _search_params = {
"q": query, "q": query,
"num": 1, "num": 1,
} }
os.makedirs("temp", exist_ok=True) os.makedirs(f"temp{i}", exist_ok=True)
try: try:
self.google.search( self.google.search(
search_params=_search_params, search_params=_search_params,
path_to_dir="./temp/", path_to_dir=f"./temp{i}/",
custom_image_name="temp", custom_image_name="temp",
) )
# we find the file called temp. extension # we find the file called temp. extension
filename = [f for f in os.listdir("./temp/") if f.startswith("temp.")][0] filename = [f for f in os.listdir(f"./temp{i}/") if f.startswith("temp.")][0]
img = mp.ImageClip(f"./temp/{filename}") img = mp.ImageClip(f"./temp{i}/{filename}")
# delete the temp folder # delete the temp folder
except Exception as e: except Exception as e:
gr.Warning(f"Failed to get image: {e}") gr.Warning(f"Failed to get image: {e}")
@@ -53,7 +53,7 @@ class GoogleStockImageEngine(BaseStockImageEngine):
.with_start(start) .with_start(start)
) )
finally: finally:
shutil.rmtree("temp") shutil.rmtree(f"temp{i}")
img = ( img = (
img.with_duration(end - start) img.with_duration(end - start)