From cced96d792a9d028f428112b6c09850452b136b5 Mon Sep 17 00:00:00 2001 From: Paillat Date: Tue, 20 Feb 2024 16:23:15 +0100 Subject: [PATCH] Formatting --- src/chore/GenerationContext.py | 6 +++--- .../BackgroundEngine/SimpleBackgroundEngine.py | 18 +++++++++++++----- src/engines/BaseEngine.py | 2 +- .../MetadataEngine/BaseMetadataEngine.py | 4 +++- .../MetadataEngine/ShortsMetadataEngine.py | 11 ++++++++--- src/engines/MetadataEngine/__init__.py | 2 +- src/engines/__init__.py | 1 + ui/gradio_ui.py | 8 +++++++- 8 files changed, 37 insertions(+), 15 deletions(-) diff --git a/src/chore/GenerationContext.py b/src/chore/GenerationContext.py index 6ff9230..4e740b7 100644 --- a/src/chore/GenerationContext.py +++ b/src/chore/GenerationContext.py @@ -18,7 +18,7 @@ class GenerationContext: assetsengine, settingsengine, backgroundengine, - progress + progress, ) -> None: self.progress = progress @@ -64,8 +64,8 @@ class GenerationContext: # ⚠️ IMPORTANT NOTE: All methods called here are expected to be defined as abstract methods in the base classes, if not there is an issue with the engine implementation. # we start by loading the settings - - self.progress(0.1,"Loading settings...") + + self.progress(0.1, "Loading settings...") self.settingsengine.load() self.setup_dir() diff --git a/src/engines/BackgroundEngine/SimpleBackgroundEngine.py b/src/engines/BackgroundEngine/SimpleBackgroundEngine.py index 92f4440..642a218 100644 --- a/src/engines/BackgroundEngine/SimpleBackgroundEngine.py +++ b/src/engines/BackgroundEngine/SimpleBackgroundEngine.py @@ -23,7 +23,11 @@ class SimpleBackgroundEngine(BaseBackgroundEngine): @classmethod def get_options(cls) -> list: assets = cls.get_assets(type="bcg_video") - choices=[asset.data["name"] for asset in assets] if len(assets) > 0 else ["No videos available"] + choices = ( + [asset.data["name"] for asset in assets] + if len(assets) > 0 + else ["No videos available"] + ) return [ gr.Dropdown( @@ -35,9 +39,7 @@ class SimpleBackgroundEngine(BaseBackgroundEngine): ] def get_background(self) -> mp.VideoClip: - background = mp.VideoFileClip( - f"{self.background_video}", audio=False - ) + background = mp.VideoFileClip(f"{self.background_video}", audio=False) background_max_start = background.duration - self.ctx.duration if background_max_start < 0: raise ValueError( @@ -46,7 +48,13 @@ class SimpleBackgroundEngine(BaseBackgroundEngine): start = random.uniform(0, background_max_start) clip = background.subclip(start, start + self.ctx.duration) w, h = clip.size - return crop(clip, width=self.ctx.width, height=self.ctx.height, x_center=w / 2, y_center=h / 2) + return crop( + clip, + width=self.ctx.width, + height=self.ctx.height, + x_center=w / 2, + y_center=h / 2, + ) @classmethod def get_settings(cls) -> list: diff --git a/src/engines/BaseEngine.py b/src/engines/BaseEngine.py index 3a56afd..fd589c5 100644 --- a/src/engines/BaseEngine.py +++ b/src/engines/BaseEngine.py @@ -47,7 +47,7 @@ class BaseEngine(ABC): .scalars() .all() ) - + @classmethod def add_asset(cls, *, path: str, metadata: dict, type: str = None): with SessionLocal() as db: diff --git a/src/engines/MetadataEngine/BaseMetadataEngine.py b/src/engines/MetadataEngine/BaseMetadataEngine.py index 5fa4098..2a64a72 100644 --- a/src/engines/MetadataEngine/BaseMetadataEngine.py +++ b/src/engines/MetadataEngine/BaseMetadataEngine.py @@ -3,14 +3,16 @@ from typing import TypedDict from .. import BaseEngine + class MetadataEngineSettings(TypedDict): title: str description: str + class BaseMetadataEngine(BaseEngine): def __init__(self, **kwargs) -> None: ... @abstractmethod def get_metadata(self, input: str) -> MetadataEngineSettings: - ... \ No newline at end of file + ... diff --git a/src/engines/MetadataEngine/ShortsMetadataEngine.py b/src/engines/MetadataEngine/ShortsMetadataEngine.py index 1a45726..99b2622 100644 --- a/src/engines/MetadataEngine/ShortsMetadataEngine.py +++ b/src/engines/MetadataEngine/ShortsMetadataEngine.py @@ -2,15 +2,20 @@ from . import BaseMetadataEngine from ...utils.prompting import get_prompt + class ShortsMetadataEngine(BaseMetadataEngine): def __init__(self, **kwargs) -> None: ... def get_metadata(self): - sytsem_prompt, chat_prompt = get_prompt("ShortsMetadata", by_file_location=__file__) + sytsem_prompt, chat_prompt = get_prompt( + "ShortsMetadata", by_file_location=__file__ + ) chat_prompt = chat_prompt.replace("{script}", self.ctx.script) - return self.ctx.simplellmengine.generate(chat_prompt=chat_prompt, system_prompt=sytsem_prompt, json_mode=True) + return self.ctx.simplellmengine.generate( + chat_prompt=chat_prompt, system_prompt=sytsem_prompt, json_mode=True + ) def get_options(self): - return [] \ No newline at end of file + return [] diff --git a/src/engines/MetadataEngine/__init__.py b/src/engines/MetadataEngine/__init__.py index 643ffc2..413824c 100644 --- a/src/engines/MetadataEngine/__init__.py +++ b/src/engines/MetadataEngine/__init__.py @@ -1,2 +1,2 @@ from .BaseMetadataEngine import BaseMetadataEngine -from .ShortsMetadataEngine import ShortsMetadataEngine \ No newline at end of file +from .ShortsMetadataEngine import ShortsMetadataEngine diff --git a/src/engines/__init__.py b/src/engines/__init__.py index 72f9553..3126315 100644 --- a/src/engines/__init__.py +++ b/src/engines/__init__.py @@ -10,6 +10,7 @@ from . import SettingsEngine from . import BackgroundEngine from . import MetadataEngine + class EngineDict(TypedDict): classes: list[BaseEngine] multiple: bool diff --git a/ui/gradio_ui.py b/ui/gradio_ui.py index 3c82150..7b9218c 100644 --- a/ui/gradio_ui.py +++ b/ui/gradio_ui.py @@ -4,6 +4,7 @@ import gradio as gr from src.engines import ENGINES, BaseEngine from src.chore import GenerationContext + class GenerateUI: def __init__(self): self.css = """.generate_button { @@ -95,7 +96,11 @@ class GenerateUI: elem_classes="generate_button", ) output_gallery = gr.Markdown("aaa", render=False) - button.click(self.run_generate_interface, inputs=inputs, outputs=output_gallery) + button.click( + self.run_generate_interface, + inputs=inputs, + outputs=output_gallery, + ) output_gallery.render() return interface @@ -106,6 +111,7 @@ class GenerateUI: ctx = GenerationContext(**arugments, progress=progress) ctx.process() # Here we go ! 🚀 return gr.update(value=ctx.get_file_path("final.mp4")) + def repack_options(self, *args) -> dict[BaseEngine]: """ Repacks the options provided as arguments into a dictionary based on the selected engine.