From 6c48e19af0f187e4a9ca10902fca5825b1f7c742 Mon Sep 17 00:00:00 2001 From: Paillat Date: Sun, 18 Feb 2024 00:56:49 +0100 Subject: [PATCH] Add AssetsEngine and related files --- src/chore/GenerationContext.py | 21 ++++- .../AssetsEngine/AssetsEngineSelector.py | 30 +++++++ src/engines/AssetsEngine/BaseAssetsEngine.py | 24 ++++++ src/engines/AssetsEngine/DallEAssetsEngine.py | 84 +++++++++++++++++++ src/engines/AssetsEngine/__init__.py | 3 + src/engines/AssetsEngine/prompts/assets.yaml | 45 ++++++++++ src/engines/LLMEngine/BaseLLMEngine.py | 1 + src/engines/__init__.py | 5 ++ ui/gradio_ui.py | 13 +-- 9 files changed, 219 insertions(+), 7 deletions(-) create mode 100644 src/engines/AssetsEngine/AssetsEngineSelector.py create mode 100644 src/engines/AssetsEngine/BaseAssetsEngine.py create mode 100644 src/engines/AssetsEngine/DallEAssetsEngine.py create mode 100644 src/engines/AssetsEngine/__init__.py create mode 100644 src/engines/AssetsEngine/prompts/assets.yaml diff --git a/src/chore/GenerationContext.py b/src/chore/GenerationContext.py index 5701692..49e66eb 100644 --- a/src/chore/GenerationContext.py +++ b/src/chore/GenerationContext.py @@ -15,6 +15,7 @@ class GenerationContext: scriptengine, ttsengine, captioningengine, + assetsengine, ) -> None: self.powerfulllmengine: engines.LLMEngine.BaseLLMEngine = powerfulllmengine[0] self.powerfulllmengine.ctx = self @@ -33,6 +34,12 @@ class GenerationContext: ) self.captioningengine.ctx = self + self.assetsengine: list[engines.AssetsEngine.BaseAssetsEngine] = assetsengine + for eng in self.assetsengine: + eng.ctx = self + self.assetsengineselector = engines.AssetsEngine.AssetsEngineSelector() + self.assetsengineselector.ctx = self + def setup_dir(self): self.dir = f"output/{time.time()}" os.makedirs(self.dir) @@ -56,6 +63,15 @@ class GenerationContext: self.script, self.get_file_path("tts.wav") ) + self.assetsengine = [ + engine for engine in self.assetsengine if not isinstance(engine, engines.NoneEngine) + ] + if 
len(self.assetsengine) > 0: + self.assets = self.assetsengineselector.get_assets() + else: + self.assets = [] + + if not isinstance(self.captioningengine, engines.NoneEngine): self.captions = self.captioningengine.get_captions() else: @@ -65,8 +81,9 @@ class GenerationContext: # we render to a file called final.mp4 # using moviepy CompositeVideoClip - - clip = mp.CompositeVideoClip(self.captions, size=(self.width, self.height)) + clips = [*self.assets, *self.captions] + clip = mp.CompositeVideoClip(clips, size=(self.width, self.height)) audio = mp.AudioFileClip(self.get_file_path("tts.wav")) clip = clip.set_audio(audio) clip.write_videofile(self.get_file_path("final.mp4"), fps=60) + \ No newline at end of file diff --git a/src/engines/AssetsEngine/AssetsEngineSelector.py b/src/engines/AssetsEngine/AssetsEngineSelector.py new file mode 100644 index 0000000..b01cd40 --- /dev/null +++ b/src/engines/AssetsEngine/AssetsEngineSelector.py @@ -0,0 +1,30 @@ +import json + +from ...utils.prompting import get_prompt +from ...chore import GenerationContext +class AssetsEngineSelector: + def __init__(self): + self.ctx: GenerationContext + + def get_assets(self): + system_prompt, chat_prompt = get_prompt("assets", by_file_location=__file__) + engines_descriptors = "" + + for engine in self.ctx.assetsengine: + engines_descriptors += f"name: '{engine.name}'\n{json.dumps(engine.specification)}\n" + + system_prompt = system_prompt.replace("{engines}", engines_descriptors) + chat_prompt = chat_prompt.replace("{caption}", json.dumps(self.ctx.timed_script)) + + assets = self.ctx.powerfulllmengine.generate( + system_prompt=system_prompt, + chat_prompt=chat_prompt, + max_tokens=4096, + json_mode=True, + )["assets"] + clips: list = [] + for engine in self.ctx.assetsengine: + assets_opts = [asset for asset in assets if asset["engine"] == engine.name] + assets_opts = [asset["args"] for asset in assets_opts] + clips.extend(engine.get_assets(assets_opts)) + return clips \ No newline at end of 
file diff --git a/src/engines/AssetsEngine/BaseAssetsEngine.py b/src/engines/AssetsEngine/BaseAssetsEngine.py new file mode 100644 index 0000000..4b7f95b --- /dev/null +++ b/src/engines/AssetsEngine/BaseAssetsEngine.py @@ -0,0 +1,24 @@ +from abc import ABC, abstractmethod +from ..BaseEngine import BaseEngine +from typing import TypedDict +from moviepy.editor import ImageClip, VideoFileClip + + + +class BaseAssetsEngine(BaseEngine): + """ + The base class for all assets engines. + + Attributes: + specification (dict): A dictionary containing the specification of the engine, especially what an object returned by the llm should look like. + spec_name (str): A comprehensive name for the specification for purely llm purposes. + spec_description (str): A comprehensive description for the specification for purely llm purposes. + """ + + specification: dict + spec_name: str + spec_description: str + + @abstractmethod + def get_assets(self, options: list) -> list: + ... diff --git a/src/engines/AssetsEngine/DallEAssetsEngine.py b/src/engines/AssetsEngine/DallEAssetsEngine.py new file mode 100644 index 0000000..4826e07 --- /dev/null +++ b/src/engines/AssetsEngine/DallEAssetsEngine.py @@ -0,0 +1,84 @@ +import gradio as gr +import openai +import moviepy.editor as mp +import io +import base64 +import time +import requests +import os + +from typing import Literal, TypedDict + +from . import BaseAssetsEngine + +class Spec(TypedDict): + prompt: str + start: float + end: float + style: Literal["vivid", "natural"] + +class DallEAssetsEngine(BaseAssetsEngine): + + name = "DALL-E" + description = "A powerful image generation model by OpenAI." + spec_name = "dalle" + spec_description = "Use the dall-e 3 model to generate images from a detailed prompt." + specification = { + "prompt": "A detailed prompt to generate the image from. Describe every subtle detail of the image you want to generate. [str]", + "start": "The starting time of the video clip. 
[float]", + "end": "The ending time of the video clip. [float]", + "style": "The style of the generated images. Must be one of vivid or natural. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. [str]" + } + + num_options = 1 + + def __init__(self, options: dict): + self.aspect_ratio: Literal["portrait", "square", "landscape"] = options[0] + + super().__init__() + + def get_assets(self, options: list[Spec]) -> list[mp.ImageClip]: + clips = [] + for option in options: + prompt = option["prompt"] + start = option["start"] + end = option["end"] + style = option["style"] + size = "1024x1024" if self.aspect_ratio == "square" else "1024x1792" if self.aspect_ratio == "portrait" else "1792x1024" + try: + response = openai.images.generate( + model="dall-e-3", + prompt=prompt, + size=size, + n=1, + style=style, + response_format="url" + ) + except openai.BadRequestError as e: + if e.code == "content_policy_violation": + #we skip this prompt + continue + else: + raise + img = requests.get(response.data[0].url) + with open("temp.png", "wb") as f: + f.write(img.content) + img = mp.ImageClip("temp.png") + os.remove("temp.png") + + img: mp.ImageClip = img.set_duration(end - start) + img = img.set_start(start) + if self.aspect_ratio == "portrait": + img = img.set_position(("center", "top")) + elif self.aspect_ratio == "landscape": + img = img.set_position(("center", "center")) + elif self.aspect_ratio == "square": + img = img.set_position(("center", "center")) + clips.append(img) + return clips + + @classmethod + def get_options(cls): + return [ + gr.Radio(["portrait", "square", "landscape"], label="Aspect Ratio", value="square") + ] \ No newline at end of file diff --git a/src/engines/AssetsEngine/__init__.py b/src/engines/AssetsEngine/__init__.py new file mode 100644 index 0000000..c714cff --- /dev/null +++ b/src/engines/AssetsEngine/__init__.py @@ -0,0 +1,3 @@ +from 
.BaseAssetsEngine import BaseAssetsEngine +from .DallEAssetsEngine import DallEAssetsEngine +from .AssetsEngineSelector import AssetsEngineSelector \ No newline at end of file diff --git a/src/engines/AssetsEngine/prompts/assets.yaml b/src/engines/AssetsEngine/prompts/assets.yaml new file mode 100644 index 0000000..d6df9b9 --- /dev/null +++ b/src/engines/AssetsEngine/prompts/assets.yaml @@ -0,0 +1,45 @@ +system: |- + You will be receiving a video script in a json format, like the following: + [ + { + "text": "Hello", + "start": 0.00, + "end": 1.00 + }, + { + "text": "World", + "start": 1.00, + "end": 2.00 + }, + ... + ] + + Your job is to add assets for illustrating the video. At your disposal you will have one or more assets engines to use. + Each one of these engines will have a specification which will contain some arguments you will need to provide. + You cannot make two assets, even of different types, or even partially overlapping, to be used at the same time. This is VERY important. + Your output should be a json object as follows: + { + "assets": [ + { + "engine": "engine_name", # The name of the engine you used, very important + "args": { + "arg1": "value1", + "arg2": "value2", + ... + }, + { + "engine": "engine_name", + "args": { + "arg1": "value1", + "arg2": "value2", + ... + } + }, + ... 
+ } + ] + } + Here are each of the engines you can use, and their specifications: + {engines} +chat: |- + {caption} \ No newline at end of file diff --git a/src/engines/LLMEngine/BaseLLMEngine.py b/src/engines/LLMEngine/BaseLLMEngine.py index da95196..b2aa5a7 100644 --- a/src/engines/LLMEngine/BaseLLMEngine.py +++ b/src/engines/LLMEngine/BaseLLMEngine.py @@ -12,6 +12,7 @@ class BaseLLMEngine(BaseEngine): chat_prompt: str, max_tokens: int, temperature: float, + json_mode: bool, top_p: float, frequency_penalty: float, presence_penalty: float, diff --git a/src/engines/__init__.py b/src/engines/__init__.py index 4cfce71..fa5f670 100644 --- a/src/engines/__init__.py +++ b/src/engines/__init__.py @@ -5,6 +5,7 @@ from . import TTSEngine from . import ScriptEngine from . import LLMEngine from . import CaptioningEngine +from . import AssetsEngine class EngineDict(TypedDict): @@ -36,4 +37,8 @@ ENGINES: dict[str, EngineDict] = { "classes": [CaptioningEngine.SimpleCaptioningEngine, NoneEngine], "multiple": False, }, + "AssetsEngine": { + "classes": [AssetsEngine.DallEAssetsEngine, NoneEngine], + "multiple": True, + }, } diff --git a/ui/gradio_ui.py b/ui/gradio_ui.py index 329cb30..7588a3e 100644 --- a/ui/gradio_ui.py +++ b/ui/gradio_ui.py @@ -13,10 +13,12 @@ class GenerateUI: """ def get_switcher_func(self, engine_names: list[str]) -> list[gr.update]: - def switch(selected: str): + def switch(selected: str | list[str]): + if isinstance(selected, str): + selected = [selected] returnable = [] for i, name in enumerate(engine_names): - returnable.append(gr.update(visible=name == selected)) + returnable.append(gr.update(visible=name in selected)) return returnable @@ -51,14 +53,15 @@ class GenerateUI: choices=engine_names, value=engine_names[0], multiselect=multiselect, - label="Engine provider:" + label="Engine provider:" if not multiselect else "Engine providers:", ) inputs.append(engine_dropdown) engine_rows = [] for i, engine in enumerate(engines): - with gr.Row( - 
equal_height=True, visible=(i == 0) + with gr.Group( + visible=(i == 0) ) as engine_row: + gr.Label(engine.name) engine_rows.append(engine_row) options = engine.get_options() inputs.extend(options)