mirror of
https://github.com/Paillat-dev/viralfactory.git
synced 2026-01-02 01:06:19 +00:00
🧵 Add multithreading when getting image assets
This commit is contained in:
@@ -3,6 +3,8 @@ import os
|
||||
import gradio as gr
|
||||
import moviepy as mp
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from . import BasePipeline
|
||||
from ... import engines
|
||||
from ...chore import GenerationContext
|
||||
@@ -24,6 +26,27 @@ class ScriptedVideoPipeline(BasePipeline):
|
||||
self.height = options[4]
|
||||
super().__init__()
|
||||
|
||||
def get_asset(self, asset: dict[str, str | float], i) -> mp.VideoClip:
|
||||
if asset["type"] == "stock":
|
||||
return self.ctx.stockimageengine.get(
|
||||
asset["query"], asset["start"], asset["end"], i
|
||||
)
|
||||
elif asset["type"] == "ai":
|
||||
return self.ctx.aiimageengine.generate(
|
||||
asset["prompt"], asset["start"], asset["end"], i
|
||||
)
|
||||
|
||||
def get_assets_concurrent(self, assets: list[dict[str, str]]) -> list[mp.VideoClip]:
|
||||
results = []
|
||||
with ThreadPoolExecutor() as executor:
|
||||
futures = [executor.submit(self.get_asset, asset, i) for i, asset in enumerate(assets)]
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
results.append(future.result())
|
||||
except Exception as e:
|
||||
gr.Warning(f"Failed to generate an asset: {e}")
|
||||
return results
|
||||
|
||||
def launch(self, ctx: GenerationContext) -> None:
|
||||
|
||||
ctx.progress(0.1, "Loading settings...")
|
||||
@@ -119,25 +142,12 @@ class ScriptedVideoPipeline(BasePipeline):
|
||||
max_tokens=4096,
|
||||
json_mode=True,
|
||||
)["assets"]
|
||||
for i, asset in enumerate(assets):
|
||||
if asset["type"] == "stock":
|
||||
ctx.progress(0.5, f"Getting stock image {i + 1}...")
|
||||
ctx.index_4.append(
|
||||
ctx.stockimageengine.get(
|
||||
asset["query"],
|
||||
asset["start"] + ctx.duration,
|
||||
asset["end"] + ctx.duration,
|
||||
)
|
||||
)
|
||||
elif asset["type"] == "ai":
|
||||
ctx.progress(0.5, f"Generating AI image {i + 1}...")
|
||||
ctx.index_5.append(
|
||||
ctx.aiimageengine.generate(
|
||||
asset["prompt"],
|
||||
asset["start"] + ctx.duration,
|
||||
asset["end"] + ctx.duration,
|
||||
)
|
||||
)
|
||||
for asset in assets:
|
||||
asset["start"] += ctx.duration
|
||||
asset["end"] += ctx.duration
|
||||
ctx.progress(0.2, f"Generating assets for chapter: {chapter['title']}...")
|
||||
clips = self.get_assets_concurrent(assets)
|
||||
ctx.index_5.extend(clips)
|
||||
|
||||
ctx.duration += duration + 0.5
|
||||
ctx.audio.extend(text_audio)
|
||||
@@ -169,7 +179,7 @@ class ScriptedVideoPipeline(BasePipeline):
|
||||
.with_audio(audio)
|
||||
)
|
||||
clip.write_videofile(
|
||||
ctx.get_file_path("final.mp4"), fps=60, threads=4, codec="h264_nvenc"
|
||||
ctx.get_file_path("final.mp4"), fps=60, threads=16, codec="av1_nvenc"
|
||||
)
|
||||
system = prompts["description"]["system"]
|
||||
chat = prompts["description"]["chat"]
|
||||
@@ -188,7 +198,6 @@ class ScriptedVideoPipeline(BasePipeline):
|
||||
ctx.title, ctx.description, ctx.get_file_path("final.mp4")
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
gr.Warning(f"{engine.name} failed to upload the video.")
|
||||
|
||||
ctx.progress(0.99, "Storing in database...")
|
||||
|
||||
Reference in New Issue
Block a user