Add AssetsEngine and related files

This commit is contained in:
2024-02-18 00:56:49 +01:00
parent e3229518d4
commit 6c48e19af0
9 changed files with 219 additions and 7 deletions

View File

@@ -15,6 +15,7 @@ class GenerationContext:
scriptengine, scriptengine,
ttsengine, ttsengine,
captioningengine, captioningengine,
assetsengine,
) -> None: ) -> None:
self.powerfulllmengine: engines.LLMEngine.BaseLLMEngine = powerfulllmengine[0] self.powerfulllmengine: engines.LLMEngine.BaseLLMEngine = powerfulllmengine[0]
self.powerfulllmengine.ctx = self self.powerfulllmengine.ctx = self
@@ -33,6 +34,12 @@ class GenerationContext:
) )
self.captioningengine.ctx = self self.captioningengine.ctx = self
self.assetsengine: list[engines.AssetsEngine.BaseAssetsEngine] = assetsengine
for eng in self.assetsengine:
eng.ctx = self
self.assetsengineselector = engines.AssetsEngine.AssetsEngineSelector()
self.assetsengineselector.ctx = self
def setup_dir(self): def setup_dir(self):
self.dir = f"output/{time.time()}" self.dir = f"output/{time.time()}"
os.makedirs(self.dir) os.makedirs(self.dir)
@@ -56,6 +63,15 @@ class GenerationContext:
self.script, self.get_file_path("tts.wav") self.script, self.get_file_path("tts.wav")
) )
self.assetsengine = [
engine for engine in self.assetsengine if not isinstance(engine, engines.NoneEngine)
]
if len(self.assetsengine) > 0:
self.assets = self.assetsengineselector.get_assets()
else:
self.assets = []
if not isinstance(self.captioningengine, engines.NoneEngine): if not isinstance(self.captioningengine, engines.NoneEngine):
self.captions = self.captioningengine.get_captions() self.captions = self.captioningengine.get_captions()
else: else:
@@ -65,8 +81,9 @@ class GenerationContext:
# we render to a file called final.mp4 # we render to a file called final.mp4
# using moviepy CompositeVideoClip # using moviepy CompositeVideoClip
clips = [*self.assets, *self.captions]
clip = mp.CompositeVideoClip(self.captions, size=(self.width, self.height)) clip = mp.CompositeVideoClip(clips, size=(self.width, self.height))
audio = mp.AudioFileClip(self.get_file_path("tts.wav")) audio = mp.AudioFileClip(self.get_file_path("tts.wav"))
clip = clip.set_audio(audio) clip = clip.set_audio(audio)
clip.write_videofile(self.get_file_path("final.mp4"), fps=60) clip.write_videofile(self.get_file_path("final.mp4"), fps=60)

View File

@@ -0,0 +1,30 @@
import json
from ...utils.prompting import get_prompt
from ...chore import GenerationContext
class AssetsEngineSelector:
    """Asks the powerful LLM to plan assets, then dispatches each planned
    asset to the engine that was named for it."""

    def __init__(self):
        # Populated by GenerationContext right after construction.
        self.ctx: GenerationContext

    def get_assets(self):
        """Return a flat list of clips produced by every configured assets engine.

        Builds a prompt describing each engine's specification, asks the LLM
        (in JSON mode) to plan assets over the timed script, and forwards each
        planned asset's args to the engine whose name it carries.
        """
        system_prompt, chat_prompt = get_prompt("assets", by_file_location=__file__)

        # One descriptor per engine: its name plus its argument specification.
        descriptors = "".join(
            f"name: '{eng.name}'\n{json.dumps(eng.specification)}\n"
            for eng in self.ctx.assetsengine
        )
        system_prompt = system_prompt.replace("{engines}", descriptors)
        chat_prompt = chat_prompt.replace("{caption}", json.dumps(self.ctx.timed_script))

        response = self.ctx.powerfulllmengine.generate(
            system_prompt=system_prompt,
            chat_prompt=chat_prompt,
            max_tokens=4096,
            json_mode=True,
        )
        planned = response["assets"]

        clips: list = []
        for eng in self.ctx.assetsengine:
            # Hand each engine only the args of the assets planned for it.
            matching_args = [a["args"] for a in planned if a["engine"] == eng.name]
            clips.extend(eng.get_assets(matching_args))
        return clips

View File

@@ -0,0 +1,24 @@
from abc import ABC, abstractmethod
from ..BaseEngine import BaseEngine
from typing import TypedDict
from moviepy.editor import ImageClip, VideoFileClip
class BaseAssetsEngine(BaseEngine):
    """
    The base class for all assets engines.

    Attributes:
        specification (dict): A dictionary containing the specification of the engine, especially what an object returned by the llm should look like.
        spec_name (str): A comprehensive name for the specification for purely llm purposes.
        spec_description (str): A comprehensive description for the specification for purely llm purposes.
    """

    # JSON-serializable mapping of argument name -> human-readable description;
    # it is embedded verbatim into the LLM prompt by AssetsEngineSelector.
    specification: dict
    spec_name: str
    spec_description: str

    @abstractmethod
    def get_assets(self, options: list) -> list:
        """Build and return a list of clips from the LLM-provided option dicts.

        Args:
            options: One dict per planned asset, matching ``specification``.

        Returns:
            A list of moviepy clips ready for compositing.
        """
        ...

View File

@@ -0,0 +1,84 @@
import gradio as gr
import openai
import moviepy.editor as mp
import io
import base64
import time
import requests
import os
from typing import Literal, TypedDict
from . import BaseAssetsEngine
class Spec(TypedDict):
    """Shape of a single LLM-planned asset request for the DALL-E engine."""

    prompt: str  # detailed image-generation prompt forwarded to DALL-E 3
    start: float  # clip start time (moviepy clip time, presumably seconds — TODO confirm)
    end: float  # clip end time; duration is computed as end - start
    style: Literal["vivid", "natural"]  # DALL-E 3 "style" API parameter
class DallEAssetsEngine(BaseAssetsEngine):
    """Assets engine that turns LLM-planned prompts into timed still-image
    clips using OpenAI's DALL-E 3 image generation API."""

    name = "DALL-E"
    description = "A powerful image generation model by OpenAI."
    spec_name = "dalle"
    spec_description = "Use the dall-e 3 model to generate images from a detailed prompt."
    specification = {
        "prompt": "A detailed prompt to generate the image from. Describe every subtle detail of the image you want to generate. [str]",
        "start": "The starting time of the video clip. [float]",
        "end": "The ending time of the video clip. [float]",
        "style": "The style of the generated images. Must be one of vivid or natural. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. [str]"
    }
    num_options = 1

    def __init__(self, options: dict):
        # options[0] is the aspect-ratio value coming from the gr.Radio in get_options().
        self.aspect_ratio: Literal["portrait", "square", "landscape"] = options[0]
        super().__init__()

    def get_assets(self, options: list[Spec]) -> list[mp.ImageClip]:
        """Generate one positioned, timed ImageClip per planned asset.

        Args:
            options: LLM-planned asset requests (see Spec).

        Returns:
            moviepy ImageClips with start/duration/position already applied.
            Prompts rejected for content-policy reasons are silently skipped.

        Raises:
            openai.BadRequestError: for any API rejection other than a
                content-policy violation.
            requests.HTTPError: if downloading a generated image fails.
        """
        import tempfile  # local import: only needed when assets are actually generated

        # The requested size depends only on the configured aspect ratio,
        # so compute it once outside the loop.
        size = (
            "1024x1024" if self.aspect_ratio == "square"
            else "1024x1792" if self.aspect_ratio == "portrait"
            else "1792x1024"
        )

        clips = []
        for option in options:
            try:
                response = openai.images.generate(
                    model="dall-e-3",
                    prompt=option["prompt"],
                    size=size,
                    n=1,
                    style=option["style"],
                    response_format="url",
                )
            except openai.BadRequestError as e:
                if e.code == "content_policy_violation":
                    # Skip prompts the API refuses instead of failing the batch.
                    continue
                raise

            download = requests.get(response.data[0].url)
            # Fail loudly on a bad download instead of writing an error page to disk.
            download.raise_for_status()

            # Use a unique temp file rather than a fixed "temp.png" in the CWD,
            # which could collide between concurrent generations and leaves
            # litter behind on a crash.
            tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
            try:
                tmp.write(download.content)
                tmp.close()
                img = mp.ImageClip(tmp.name)  # ImageClip reads the file eagerly
            finally:
                os.remove(tmp.name)

            start, end = option["start"], option["end"]
            img: mp.ImageClip = img.set_duration(end - start).set_start(start)
            # Portrait images hug the top edge; square and landscape are centered
            # (same placement for both, kept as in the original behavior).
            if self.aspect_ratio == "portrait":
                img = img.set_position(("center", "top"))
            else:
                img = img.set_position(("center", "center"))
            clips.append(img)
        return clips

    @classmethod
    def get_options(cls):
        """Gradio input components shown in the UI to configure this engine."""
        return [
            gr.Radio(["portrait", "square", "landscape"], label="Aspect Ratio", value="square")
        ]

View File

@@ -0,0 +1,3 @@
from .BaseAssetsEngine import BaseAssetsEngine
from .DallEAssetsEngine import DallEAssetsEngine
from .AssetsEngineSelector import AssetsEngineSelector

View File

@@ -0,0 +1,45 @@
system: |-
You will be receiving a video script in JSON format, like the following:
[
{
"text": "Hello",
"start": 0.00,
"end": 1.00
},
{
"text": "World",
"start": 1.00,
"end": 2.00
},
...
]
Your job is to add assets for illustrating the video. At your disposition you will have one or more assets engines to use.
Each one of these engines will have a specification which will contain some arguments you will need to provide.
Two assets must never be active at the same time — not even assets of different types, and not even partially overlapping ones. This is VERY important.
Your output should be a json object as follows:
{
  "assets": [
    {
      "engine": "engine_name", # The name of the engine you used, very important
      "args": {
        "arg1": "value1",
        "arg2": "value2",
        ...
      }
    },
    {
      "engine": "engine_name",
      "args": {
        "arg1": "value1",
        "arg2": "value2",
        ...
      }
    },
    ...
  ]
}
Here are each of the engines you can use, and their specifications:
{engines}
chat: |-
{caption}

View File

@@ -12,6 +12,7 @@ class BaseLLMEngine(BaseEngine):
chat_prompt: str, chat_prompt: str,
max_tokens: int, max_tokens: int,
temperature: float, temperature: float,
json_mode: bool,
top_p: float, top_p: float,
frequency_penalty: float, frequency_penalty: float,
presence_penalty: float, presence_penalty: float,

View File

@@ -5,6 +5,7 @@ from . import TTSEngine
from . import ScriptEngine from . import ScriptEngine
from . import LLMEngine from . import LLMEngine
from . import CaptioningEngine from . import CaptioningEngine
from . import AssetsEngine
class EngineDict(TypedDict): class EngineDict(TypedDict):
@@ -36,4 +37,8 @@ ENGINES: dict[str, EngineDict] = {
"classes": [CaptioningEngine.SimpleCaptioningEngine, NoneEngine], "classes": [CaptioningEngine.SimpleCaptioningEngine, NoneEngine],
"multiple": False, "multiple": False,
}, },
"AssetsEngine": {
"classes": [AssetsEngine.DallEAssetsEngine, NoneEngine],
"multiple": True,
},
} }

View File

@@ -13,10 +13,12 @@ class GenerateUI:
""" """
def get_switcher_func(self, engine_names: list[str]) -> list[gr.update]: def get_switcher_func(self, engine_names: list[str]) -> list[gr.update]:
def switch(selected: str): def switch(selected: str | list[str]):
if isinstance(selected, str):
selected = [selected]
returnable = [] returnable = []
for i, name in enumerate(engine_names): for i, name in enumerate(engine_names):
returnable.append(gr.update(visible=name == selected)) returnable.append(gr.update(visible=name in selected))
return returnable return returnable
@@ -51,14 +53,15 @@ class GenerateUI:
choices=engine_names, choices=engine_names,
value=engine_names[0], value=engine_names[0],
multiselect=multiselect, multiselect=multiselect,
label="Engine provider:" label="Engine provider:" if not multiselect else "Engine providers:",
) )
inputs.append(engine_dropdown) inputs.append(engine_dropdown)
engine_rows = [] engine_rows = []
for i, engine in enumerate(engines): for i, engine in enumerate(engines):
with gr.Row( with gr.Group(
equal_height=True, visible=(i == 0) visible=(i == 0)
) as engine_row: ) as engine_row:
gr.Label(engine.name)
engine_rows.append(engine_row) engine_rows.append(engine_row)
options = engine.get_options() options = engine.get_options()
inputs.extend(options) inputs.extend(options)