mirror of
https://github.com/Paillat-dev/viralfactory.git
synced 2026-01-02 09:16:19 +00:00
🧵 Add multithreading when getting image assets
This commit is contained in:
@@ -32,7 +32,7 @@ class A1111AIImageEngine(BaseAIImageEngine):
|
|||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def generate(self, prompt: str, start: float, end: float) -> mp.ImageClip:
|
def generate(self, prompt: str, start: float, end: float, i= "") -> mp.ImageClip:
|
||||||
max_width = self.ctx.width / 3 * 2
|
max_width = self.ctx.width / 3 * 2
|
||||||
try:
|
try:
|
||||||
url = self.base_url + "/sdapi/v1/txt2img"
|
url = self.base_url + "/sdapi/v1/txt2img"
|
||||||
@@ -43,8 +43,8 @@ class A1111AIImageEngine(BaseAIImageEngine):
|
|||||||
}
|
}
|
||||||
response = requests.post(url, json=payload)
|
response = requests.post(url, json=payload)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
fname = f"temp{i}.png"
|
||||||
with open("temp.png", "wb") as f:
|
with open(fname, "wb") as f:
|
||||||
f.write(base64.b64decode(response.json()["images"][0]))
|
f.write(base64.b64decode(response.json()["images"][0]))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
gr.Warning(f"Failed to get image: {e}")
|
gr.Warning(f"Failed to get image: {e}")
|
||||||
@@ -53,8 +53,8 @@ class A1111AIImageEngine(BaseAIImageEngine):
|
|||||||
.with_duration(end - start)
|
.with_duration(end - start)
|
||||||
.with_start(start)
|
.with_start(start)
|
||||||
)
|
)
|
||||||
img = mp.ImageClip("temp.png")
|
img = mp.ImageClip(fname)
|
||||||
os.remove("temp.png")
|
os.remove(fname)
|
||||||
|
|
||||||
position = ("center", "center")
|
position = ("center", "center")
|
||||||
img = (
|
img = (
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ class BaseAIImageEngine(BaseEngine):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def generate(self, prompt: str, start: float, end: float) -> mp.ImageClip:
|
def generate(self, prompt: str, start: float, end: float, i = "") -> mp.ImageClip:
|
||||||
"""
|
"""
|
||||||
Ge
|
Ge
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ class DallEAIImageEngine(BaseAIImageEngine):
|
|||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def generate(self, prompt: str, start: float, end: float) -> mp.ImageClip:
|
def generate(self, prompt: str, start: float, end: float, i = "") -> mp.ImageClip:
|
||||||
max_width = self.ctx.width / 3 * 2
|
max_width = self.ctx.width / 3 * 2
|
||||||
size: Literal["1024x1024", "1024x1792", "1792x1024"] = (
|
size: Literal["1024x1024", "1024x1792", "1792x1024"] = (
|
||||||
"1024x1024"
|
"1024x1024"
|
||||||
@@ -61,10 +61,11 @@ class DallEAIImageEngine(BaseAIImageEngine):
|
|||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
img_bytes = requests.get(response.data[0].url)
|
img_bytes = requests.get(response.data[0].url)
|
||||||
with open("temp.png", "wb") as f:
|
fname = f"temp{i}.png"
|
||||||
|
with open(fname, "wb") as f:
|
||||||
f.write(img_bytes.content)
|
f.write(img_bytes.content)
|
||||||
img = mp.ImageClip("temp.png")
|
img = mp.ImageClip(fname)
|
||||||
os.remove("temp.png")
|
os.remove(fname)
|
||||||
|
|
||||||
position = ("center", "center")
|
position = ("center", "center")
|
||||||
img = (
|
img = (
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ import os
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
import moviepy as mp
|
import moviepy as mp
|
||||||
|
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
|
||||||
from . import BasePipeline
|
from . import BasePipeline
|
||||||
from ... import engines
|
from ... import engines
|
||||||
from ...chore import GenerationContext
|
from ...chore import GenerationContext
|
||||||
@@ -24,6 +26,27 @@ class ScriptedVideoPipeline(BasePipeline):
|
|||||||
self.height = options[4]
|
self.height = options[4]
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
def get_asset(self, asset: dict[str, str | float], i) -> mp.VideoClip:
|
||||||
|
if asset["type"] == "stock":
|
||||||
|
return self.ctx.stockimageengine.get(
|
||||||
|
asset["query"], asset["start"], asset["end"], i
|
||||||
|
)
|
||||||
|
elif asset["type"] == "ai":
|
||||||
|
return self.ctx.aiimageengine.generate(
|
||||||
|
asset["prompt"], asset["start"], asset["end"], i
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_assets_concurrent(self, assets: list[dict[str, str]]) -> list[mp.VideoClip]:
|
||||||
|
results = []
|
||||||
|
with ThreadPoolExecutor() as executor:
|
||||||
|
futures = [executor.submit(self.get_asset, asset, i) for i, asset in enumerate(assets)]
|
||||||
|
for future in as_completed(futures):
|
||||||
|
try:
|
||||||
|
results.append(future.result())
|
||||||
|
except Exception as e:
|
||||||
|
gr.Warning(f"Failed to generate an asset: {e}")
|
||||||
|
return results
|
||||||
|
|
||||||
def launch(self, ctx: GenerationContext) -> None:
|
def launch(self, ctx: GenerationContext) -> None:
|
||||||
|
|
||||||
ctx.progress(0.1, "Loading settings...")
|
ctx.progress(0.1, "Loading settings...")
|
||||||
@@ -119,25 +142,12 @@ class ScriptedVideoPipeline(BasePipeline):
|
|||||||
max_tokens=4096,
|
max_tokens=4096,
|
||||||
json_mode=True,
|
json_mode=True,
|
||||||
)["assets"]
|
)["assets"]
|
||||||
for i, asset in enumerate(assets):
|
for asset in assets:
|
||||||
if asset["type"] == "stock":
|
asset["start"] += ctx.duration
|
||||||
ctx.progress(0.5, f"Getting stock image {i + 1}...")
|
asset["end"] += ctx.duration
|
||||||
ctx.index_4.append(
|
ctx.progress(0.2, f"Generating assets for chapter: {chapter['title']}...")
|
||||||
ctx.stockimageengine.get(
|
clips = self.get_assets_concurrent(assets)
|
||||||
asset["query"],
|
ctx.index_5.extend(clips)
|
||||||
asset["start"] + ctx.duration,
|
|
||||||
asset["end"] + ctx.duration,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif asset["type"] == "ai":
|
|
||||||
ctx.progress(0.5, f"Generating AI image {i + 1}...")
|
|
||||||
ctx.index_5.append(
|
|
||||||
ctx.aiimageengine.generate(
|
|
||||||
asset["prompt"],
|
|
||||||
asset["start"] + ctx.duration,
|
|
||||||
asset["end"] + ctx.duration,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx.duration += duration + 0.5
|
ctx.duration += duration + 0.5
|
||||||
ctx.audio.extend(text_audio)
|
ctx.audio.extend(text_audio)
|
||||||
@@ -169,7 +179,7 @@ class ScriptedVideoPipeline(BasePipeline):
|
|||||||
.with_audio(audio)
|
.with_audio(audio)
|
||||||
)
|
)
|
||||||
clip.write_videofile(
|
clip.write_videofile(
|
||||||
ctx.get_file_path("final.mp4"), fps=60, threads=4, codec="h264_nvenc"
|
ctx.get_file_path("final.mp4"), fps=60, threads=16, codec="av1_nvenc"
|
||||||
)
|
)
|
||||||
system = prompts["description"]["system"]
|
system = prompts["description"]["system"]
|
||||||
chat = prompts["description"]["chat"]
|
chat = prompts["description"]["chat"]
|
||||||
@@ -188,7 +198,6 @@ class ScriptedVideoPipeline(BasePipeline):
|
|||||||
ctx.title, ctx.description, ctx.get_file_path("final.mp4")
|
ctx.title, ctx.description, ctx.get_file_path("final.mp4")
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
|
||||||
gr.Warning(f"{engine.name} failed to upload the video.")
|
gr.Warning(f"{engine.name} failed to upload the video.")
|
||||||
|
|
||||||
ctx.progress(0.99, "Storing in database...")
|
ctx.progress(0.99, "Storing in database...")
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
from .SettingsEngine import SettingsEngine
|
|
||||||
@@ -11,7 +11,7 @@ class BaseStockImageEngine(BaseEngine):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get(self, query: str, start: float, end: float) -> mp.ImageClip:
|
def get(self, query: str, start: float, end: float, i = "") -> mp.ImageClip:
|
||||||
"""
|
"""
|
||||||
Get a stock image based on a query.
|
Get a stock image based on a query.
|
||||||
|
|
||||||
@@ -19,6 +19,7 @@ class BaseStockImageEngine(BaseEngine):
|
|||||||
query (str): The query to search for.
|
query (str): The query to search for.
|
||||||
start (float): The starting time of the video clip.
|
start (float): The starting time of the video clip.
|
||||||
end (float): The ending time of the video clip.
|
end (float): The ending time of the video clip.
|
||||||
|
i (str): Unique identifier for the image, mandatory if running concurrently.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: The path to the saved image.
|
str: The path to the saved image.
|
||||||
|
|||||||
@@ -28,22 +28,22 @@ class GoogleStockImageEngine(BaseStockImageEngine):
|
|||||||
self.google = GoogleImagesSearch(api_key, project_cx)
|
self.google = GoogleImagesSearch(api_key, project_cx)
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def get(self, query: str, start: float, end: float) -> mp.ImageClip:
|
def get(self, query: str, start: float, end: float, i="") -> mp.ImageClip:
|
||||||
max_width = int(self.ctx.width / 3 * 2)
|
max_width = int(self.ctx.width / 3 * 2)
|
||||||
_search_params = {
|
_search_params = {
|
||||||
"q": query,
|
"q": query,
|
||||||
"num": 1,
|
"num": 1,
|
||||||
}
|
}
|
||||||
os.makedirs("temp", exist_ok=True)
|
os.makedirs(f"temp{i}", exist_ok=True)
|
||||||
try:
|
try:
|
||||||
self.google.search(
|
self.google.search(
|
||||||
search_params=_search_params,
|
search_params=_search_params,
|
||||||
path_to_dir="./temp/",
|
path_to_dir=f"./temp{i}/",
|
||||||
custom_image_name="temp",
|
custom_image_name="temp",
|
||||||
)
|
)
|
||||||
# we find the file called temp. extension
|
# we find the file called temp. extension
|
||||||
filename = [f for f in os.listdir("./temp/") if f.startswith("temp.")][0]
|
filename = [f for f in os.listdir(f"./temp{i}/") if f.startswith("temp.")][0]
|
||||||
img = mp.ImageClip(f"./temp/{filename}")
|
img = mp.ImageClip(f"./temp{i}/{filename}")
|
||||||
# delete the temp folder
|
# delete the temp folder
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
gr.Warning(f"Failed to get image: {e}")
|
gr.Warning(f"Failed to get image: {e}")
|
||||||
@@ -53,7 +53,7 @@ class GoogleStockImageEngine(BaseStockImageEngine):
|
|||||||
.with_start(start)
|
.with_start(start)
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
shutil.rmtree("temp")
|
shutil.rmtree(f"temp{i}")
|
||||||
|
|
||||||
img = (
|
img = (
|
||||||
img.with_duration(end - start)
|
img.with_duration(end - start)
|
||||||
|
|||||||
Reference in New Issue
Block a user