mirror of
https://github.com/Paillat-dev/viralfactory.git
synced 2026-01-02 09:16:19 +00:00
Add AssetsEngine and related files
This commit is contained in:
@@ -15,6 +15,7 @@ class GenerationContext:
|
|||||||
scriptengine,
|
scriptengine,
|
||||||
ttsengine,
|
ttsengine,
|
||||||
captioningengine,
|
captioningengine,
|
||||||
|
assetsengine,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.powerfulllmengine: engines.LLMEngine.BaseLLMEngine = powerfulllmengine[0]
|
self.powerfulllmengine: engines.LLMEngine.BaseLLMEngine = powerfulllmengine[0]
|
||||||
self.powerfulllmengine.ctx = self
|
self.powerfulllmengine.ctx = self
|
||||||
@@ -33,6 +34,12 @@ class GenerationContext:
|
|||||||
)
|
)
|
||||||
self.captioningengine.ctx = self
|
self.captioningengine.ctx = self
|
||||||
|
|
||||||
|
self.assetsengine: list[engines.AssetsEngine.BaseAssetsEngine] = assetsengine
|
||||||
|
for eng in self.assetsengine:
|
||||||
|
eng.ctx = self
|
||||||
|
self.assetsengineselector = engines.AssetsEngine.AssetsEngineSelector()
|
||||||
|
self.assetsengineselector.ctx = self
|
||||||
|
|
||||||
def setup_dir(self):
|
def setup_dir(self):
|
||||||
self.dir = f"output/{time.time()}"
|
self.dir = f"output/{time.time()}"
|
||||||
os.makedirs(self.dir)
|
os.makedirs(self.dir)
|
||||||
@@ -56,6 +63,15 @@ class GenerationContext:
|
|||||||
self.script, self.get_file_path("tts.wav")
|
self.script, self.get_file_path("tts.wav")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.assetsengine = [
|
||||||
|
engine for engine in self.assetsengine if not isinstance(engine, engines.NoneEngine)
|
||||||
|
]
|
||||||
|
if len(self.assetsengine) > 0:
|
||||||
|
self.assets = self.assetsengineselector.get_assets()
|
||||||
|
else:
|
||||||
|
self.assets = []
|
||||||
|
|
||||||
|
|
||||||
if not isinstance(self.captioningengine, engines.NoneEngine):
|
if not isinstance(self.captioningengine, engines.NoneEngine):
|
||||||
self.captions = self.captioningengine.get_captions()
|
self.captions = self.captioningengine.get_captions()
|
||||||
else:
|
else:
|
||||||
@@ -65,8 +81,9 @@ class GenerationContext:
|
|||||||
|
|
||||||
# we render to a file called final.mp4
|
# we render to a file called final.mp4
|
||||||
# using moviepy CompositeVideoClip
|
# using moviepy CompositeVideoClip
|
||||||
|
clips = [*self.assets, *self.captions]
|
||||||
clip = mp.CompositeVideoClip(self.captions, size=(self.width, self.height))
|
clip = mp.CompositeVideoClip(clips, size=(self.width, self.height))
|
||||||
audio = mp.AudioFileClip(self.get_file_path("tts.wav"))
|
audio = mp.AudioFileClip(self.get_file_path("tts.wav"))
|
||||||
clip = clip.set_audio(audio)
|
clip = clip.set_audio(audio)
|
||||||
clip.write_videofile(self.get_file_path("final.mp4"), fps=60)
|
clip.write_videofile(self.get_file_path("final.mp4"), fps=60)
|
||||||
|
|
||||||
30
src/engines/AssetsEngine/AssetsEngineSelector.py
Normal file
30
src/engines/AssetsEngine/AssetsEngineSelector.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from ...utils.prompting import get_prompt
|
||||||
|
from ...chore import GenerationContext
|
||||||
|
class AssetsEngineSelector:
    """Ask the powerful LLM which assets to create and dispatch them to the engines.

    The owning ``GenerationContext`` is expected to assign itself to ``ctx``
    after construction; ``get_assets`` reads the selected assets engines, the
    timed script and the powerful LLM engine from it.
    """

    def __init__(self):
        # Declaration only: no value is bound here, ``ctx`` is set externally.
        self.ctx: GenerationContext

    def get_assets(self):
        """Build the prompts, query the LLM and collect the clips of every engine.

        Returns:
            list: The clips returned by each assets engine, concatenated in
            the order of ``ctx.assetsengine``.
        """
        system_prompt, chat_prompt = get_prompt("assets", by_file_location=__file__)

        # One "name + JSON specification" entry per engine, so the LLM knows
        # which engines exist and what arguments each one expects.
        engines_descriptors = "".join(
            f"name: '{engine.name}'\n{json.dumps(engine.specification)}\n"
            for engine in self.ctx.assetsengine
        )

        system_prompt = system_prompt.replace("{engines}", engines_descriptors)
        chat_prompt = chat_prompt.replace("{caption}", json.dumps(self.ctx.timed_script))

        response = self.ctx.powerfulllmengine.generate(
            system_prompt=system_prompt,
            chat_prompt=chat_prompt,
            max_tokens=4096,
            json_mode=True,
        )
        assets = response["assets"]

        clips: list = []
        for engine in self.ctx.assetsengine:
            # Hand each engine only the argument dicts addressed to it.
            engine_args = [
                asset["args"] for asset in assets if asset["engine"] == engine.name
            ]
            clips.extend(engine.get_assets(engine_args))
        return clips
|
||||||
24
src/engines/AssetsEngine/BaseAssetsEngine.py
Normal file
24
src/engines/AssetsEngine/BaseAssetsEngine.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from ..BaseEngine import BaseEngine
|
||||||
|
from typing import TypedDict
|
||||||
|
from moviepy.editor import ImageClip, VideoFileClip
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAssetsEngine(BaseEngine):
    """
    Common base for every assets engine.

    Attributes:
        specification (dict): Describes the arguments the LLM has to provide for one asset of this engine, i.e. what an object returned by the LLM should look like.
        spec_name (str): A comprehensive name for the specification, used purely for LLM purposes.
        spec_description (str): A comprehensive description of the specification, used purely for LLM purposes.
    """

    specification: dict
    spec_name: str
    spec_description: str

    # NOTE(review): @abstractmethod is only enforced when the class uses an
    # ABCMeta metaclass; confirm that BaseEngine provides it.
    @abstractmethod
    def get_assets(self, options: list) -> list:
        """Turn a list of per-asset argument objects into a list of assets."""
|
||||||
84
src/engines/AssetsEngine/DallEAssetsEngine.py
Normal file
84
src/engines/AssetsEngine/DallEAssetsEngine.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
import gradio as gr
|
||||||
|
import openai
|
||||||
|
import moviepy.editor as mp
|
||||||
|
import io
|
||||||
|
import base64
|
||||||
|
import time
|
||||||
|
import requests
|
||||||
|
import os
|
||||||
|
|
||||||
|
from typing import Literal, TypedDict
|
||||||
|
|
||||||
|
from . import BaseAssetsEngine
|
||||||
|
|
||||||
|
class Spec(TypedDict):
    """Arguments describing a single DALL-E asset."""

    # Text prompt the image is generated from.
    prompt: str
    # Start of the clip.
    start: float
    # End of the clip.
    end: float
    # Rendering style forwarded to the image API.
    style: Literal["vivid", "natural"]
|
||||||
|
|
||||||
|
class DallEAssetsEngine(BaseAssetsEngine):
    """Assets engine that generates still images with OpenAI's DALL-E 3.

    Each requested asset becomes a moviepy ``ImageClip`` placed on the
    timeline between the asset's ``start`` and ``end``.
    """

    name = "DALL-E"
    description = "A powerful image generation model by OpenAI."
    spec_name = "dalle"
    spec_description = "Use the dall-e 3 model to generate images from a detailed prompt."
    specification = {
        "prompt": "A detailed prompt to generate the image from. Describe every subtle detail of the image you want to generate. [str]",
        "start": "The starting time of the video clip. [float]",
        "end": "The ending time of the video clip. [float]",
        "style": "The style of the generated images. Must be one of vivid or natural. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. [str]",
    }

    num_options = 1

    # Image size requested for each aspect ratio; anything else falls back to
    # the landscape size, as the original conditional chain did.
    _SIZES = {"square": "1024x1024", "portrait": "1024x1792"}
    _FALLBACK_SIZE = "1792x1024"

    # Where the image is anchored in the final composition.
    _POSITIONS = {
        "portrait": ("center", "top"),
        "landscape": ("center", "center"),
        "square": ("center", "center"),
    }

    # Seconds to wait for the generated image download before giving up, so a
    # stalled connection cannot hang the whole generation.
    _DOWNLOAD_TIMEOUT = 60

    def __init__(self, options: list):
        """Store the aspect ratio chosen in the UI.

        Args:
            options: Values of the components returned by ``get_options``, in
                the same order; only the first one (aspect ratio) is read.
        """
        self.aspect_ratio: Literal["portrait", "square", "landscape"] = options[0]

        super().__init__()

    @staticmethod
    def _download_image_clip(url: str) -> mp.ImageClip:
        """Download the image at *url* and load it into an ``ImageClip``.

        Raises:
            requests.HTTPError: If the download answers with an error status.
        """
        import tempfile

        response = requests.get(url, timeout=DallEAssetsEngine._DOWNLOAD_TIMEOUT)
        # Fail loudly instead of writing an error page to disk as a ".png".
        response.raise_for_status()

        # A unique temporary file avoids clashes between concurrent runs that
        # a fixed "temp.png" in the working directory would cause.
        fd, path = tempfile.mkstemp(suffix=".png")
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(response.content)
            return mp.ImageClip(path)
        finally:
            # Like the original code, this relies on the clip not needing the
            # file after construction; always clean up, even on failure.
            os.remove(path)

    def get_assets(self, options: list[Spec]) -> list[mp.ImageClip]:
        """Generate one image clip per requested asset.

        Args:
            options: Asset descriptions, each providing ``prompt``, ``start``,
                ``end`` and ``style``.

        Returns:
            The generated clips. Assets rejected by the content policy and
            assets with a non-positive duration are skipped.

        Raises:
            openai.BadRequestError: For any bad request other than a content
                policy violation.
            requests.HTTPError: If a generated image cannot be downloaded.
        """
        size = self._SIZES.get(self.aspect_ratio, self._FALLBACK_SIZE)
        position = self._POSITIONS.get(self.aspect_ratio)

        clips = []
        for option in options:
            start = option["start"]
            end = option["end"]
            if end <= start:
                # Nothing would be visible; skip before paying for an image.
                continue
            try:
                response = openai.images.generate(
                    model="dall-e-3",
                    prompt=option["prompt"],
                    size=size,
                    n=1,
                    style=option["style"],
                    response_format="url",
                )
            except openai.BadRequestError as e:
                if e.code == "content_policy_violation":
                    # we skip this prompt
                    continue
                raise

            img = self._download_image_clip(response.data[0].url)
            img = img.set_duration(end - start)
            img = img.set_start(start)
            if position is not None:
                img = img.set_position(position)
            clips.append(img)
        return clips

    @classmethod
    def get_options(cls):
        """Return the UI components used to configure this engine."""
        return [
            gr.Radio(["portrait", "square", "landscape"], label="Aspect Ratio", value="square")
        ]
|
||||||
3
src/engines/AssetsEngine/__init__.py
Normal file
3
src/engines/AssetsEngine/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .BaseAssetsEngine import BaseAssetsEngine
|
||||||
|
from .DallEAssetsEngine import DallEAssetsEngine
|
||||||
|
from .AssetsEngineSelector import AssetsEngineSelector
|
||||||
45
src/engines/AssetsEngine/prompts/assets.yaml
Normal file
45
src/engines/AssetsEngine/prompts/assets.yaml
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
system: |-
|
||||||
|
You will be receiving a video script in a json format, like the following:
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"text": "Hello",
|
||||||
|
"start": 0.00,
|
||||||
|
"end": 1.00
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "World",
|
||||||
|
"start": 1.00,
|
||||||
|
"end": 2.00
|
||||||
|
},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
|
||||||
|
Your job is to add assets for illustrating the video. At your disposal you will have one or more assets engines to use.
|
||||||
|
Each one of these engines will have a specification which will contain some arguments you will need to provide.
|
||||||
|
You cannot make two assets, even of different types, or even partially overlapping, to be used at the same time. This is VERY important.
|
||||||
|
Your output should be a json object as follows:
|
||||||
|
{
|
||||||
|
"assets": [
|
||||||
|
{
|
||||||
|
"engine": "engine_name", # The name of the engine you used, very important
|
||||||
|
"args": {
|
||||||
|
"arg1": "value1",
|
||||||
|
"arg2": "value2",
|
||||||
|
...
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"engine": "engine_name",
|
||||||
|
"args": {
|
||||||
|
"arg1": "value1",
|
||||||
|
"arg2": "value2",
|
||||||
|
...
|
||||||
|
}
|
||||||
|
},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
}
|
||||||
|
Here are each of the engines you can use, and their specifications:
|
||||||
|
{engines}
|
||||||
|
chat: |-
|
||||||
|
{caption}
|
||||||
@@ -12,6 +12,7 @@ class BaseLLMEngine(BaseEngine):
|
|||||||
chat_prompt: str,
|
chat_prompt: str,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
temperature: float,
|
temperature: float,
|
||||||
|
json_mode: bool,
|
||||||
top_p: float,
|
top_p: float,
|
||||||
frequency_penalty: float,
|
frequency_penalty: float,
|
||||||
presence_penalty: float,
|
presence_penalty: float,
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from . import TTSEngine
|
|||||||
from . import ScriptEngine
|
from . import ScriptEngine
|
||||||
from . import LLMEngine
|
from . import LLMEngine
|
||||||
from . import CaptioningEngine
|
from . import CaptioningEngine
|
||||||
|
from . import AssetsEngine
|
||||||
|
|
||||||
|
|
||||||
class EngineDict(TypedDict):
|
class EngineDict(TypedDict):
|
||||||
@@ -36,4 +37,8 @@ ENGINES: dict[str, EngineDict] = {
|
|||||||
"classes": [CaptioningEngine.SimpleCaptioningEngine, NoneEngine],
|
"classes": [CaptioningEngine.SimpleCaptioningEngine, NoneEngine],
|
||||||
"multiple": False,
|
"multiple": False,
|
||||||
},
|
},
|
||||||
|
"AssetsEngine": {
|
||||||
|
"classes": [AssetsEngine.DallEAssetsEngine, NoneEngine],
|
||||||
|
"multiple": True,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,10 +13,12 @@ class GenerateUI:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def get_switcher_func(self, engine_names: list[str]) -> list[gr.update]:
|
def get_switcher_func(self, engine_names: list[str]) -> list[gr.update]:
|
||||||
def switch(selected: str):
|
def switch(selected: str | list[str]):
|
||||||
|
if isinstance(selected, str):
|
||||||
|
selected = [selected]
|
||||||
returnable = []
|
returnable = []
|
||||||
for i, name in enumerate(engine_names):
|
for i, name in enumerate(engine_names):
|
||||||
returnable.append(gr.update(visible=name == selected))
|
returnable.append(gr.update(visible=name in selected))
|
||||||
|
|
||||||
return returnable
|
return returnable
|
||||||
|
|
||||||
@@ -51,14 +53,15 @@ class GenerateUI:
|
|||||||
choices=engine_names,
|
choices=engine_names,
|
||||||
value=engine_names[0],
|
value=engine_names[0],
|
||||||
multiselect=multiselect,
|
multiselect=multiselect,
|
||||||
label="Engine provider:"
|
label="Engine provider:" if not multiselect else "Engine providers:",
|
||||||
)
|
)
|
||||||
inputs.append(engine_dropdown)
|
inputs.append(engine_dropdown)
|
||||||
engine_rows = []
|
engine_rows = []
|
||||||
for i, engine in enumerate(engines):
|
for i, engine in enumerate(engines):
|
||||||
with gr.Row(
|
with gr.Group(
|
||||||
equal_height=True, visible=(i == 0)
|
visible=(i == 0)
|
||||||
) as engine_row:
|
) as engine_row:
|
||||||
|
gr.Label(engine.name)
|
||||||
engine_rows.append(engine_row)
|
engine_rows.append(engine_row)
|
||||||
options = engine.get_options()
|
options = engine.get_options()
|
||||||
inputs.extend(options)
|
inputs.extend(options)
|
||||||
|
|||||||
Reference in New Issue
Block a user