🧵 Add multithreading when getting image assets

This commit is contained in:
2024-05-17 11:18:28 +02:00
parent 08f3e09b4b
commit e58e41a911
7 changed files with 49 additions and 39 deletions

View File

@@ -3,6 +3,8 @@ import os
import gradio as gr
import moviepy as mp
from concurrent.futures import ThreadPoolExecutor, as_completed
from . import BasePipeline
from ... import engines
from ...chore import GenerationContext
@@ -24,6 +26,27 @@ class ScriptedVideoPipeline(BasePipeline):
self.height = options[4]
super().__init__()
def get_asset(self, asset: dict[str, str | float], i) -> mp.VideoClip:
if asset["type"] == "stock":
return self.ctx.stockimageengine.get(
asset["query"], asset["start"], asset["end"], i
)
elif asset["type"] == "ai":
return self.ctx.aiimageengine.generate(
asset["prompt"], asset["start"], asset["end"], i
)
def get_assets_concurrent(self, assets: list[dict[str, str]]) -> list[mp.VideoClip]:
results = []
with ThreadPoolExecutor() as executor:
futures = [executor.submit(self.get_asset, asset, i) for i, asset in enumerate(assets)]
for future in as_completed(futures):
try:
results.append(future.result())
except Exception as e:
gr.Warning(f"Failed to generate an asset: {e}")
return results
def launch(self, ctx: GenerationContext) -> None:
ctx.progress(0.1, "Loading settings...")
@@ -119,25 +142,12 @@ class ScriptedVideoPipeline(BasePipeline):
max_tokens=4096,
json_mode=True,
)["assets"]
for i, asset in enumerate(assets):
if asset["type"] == "stock":
ctx.progress(0.5, f"Getting stock image {i + 1}...")
ctx.index_4.append(
ctx.stockimageengine.get(
asset["query"],
asset["start"] + ctx.duration,
asset["end"] + ctx.duration,
)
)
elif asset["type"] == "ai":
ctx.progress(0.5, f"Generating AI image {i + 1}...")
ctx.index_5.append(
ctx.aiimageengine.generate(
asset["prompt"],
asset["start"] + ctx.duration,
asset["end"] + ctx.duration,
)
)
for asset in assets:
asset["start"] += ctx.duration
asset["end"] += ctx.duration
ctx.progress(0.2, f"Generating assets for chapter: {chapter['title']}...")
clips = self.get_assets_concurrent(assets)
ctx.index_5.extend(clips)
ctx.duration += duration + 0.5
ctx.audio.extend(text_audio)
@@ -169,7 +179,7 @@ class ScriptedVideoPipeline(BasePipeline):
.with_audio(audio)
)
clip.write_videofile(
ctx.get_file_path("final.mp4"), fps=60, threads=4, codec="h264_nvenc"
ctx.get_file_path("final.mp4"), fps=60, threads=16, codec="av1_nvenc"
)
system = prompts["description"]["system"]
chat = prompts["description"]["chat"]
@@ -188,7 +198,6 @@ class ScriptedVideoPipeline(BasePipeline):
ctx.title, ctx.description, ctx.get_file_path("final.mp4")
)
except Exception as e:
print(e)
gr.Warning(f"{engine.name} failed to upload the video.")
ctx.progress(0.99, "Storing in database...")