Formatting

This commit is contained in:
2024-02-20 16:23:15 +01:00
parent ce128e400a
commit cced96d792
8 changed files with 37 additions and 15 deletions

View File

@@ -18,7 +18,7 @@ class GenerationContext:
assetsengine, assetsengine,
settingsengine, settingsengine,
backgroundengine, backgroundengine,
progress progress,
) -> None: ) -> None:
self.progress = progress self.progress = progress
@@ -64,8 +64,8 @@ class GenerationContext:
# ⚠️ IMPORTANT NOTE: All methods called here are expected to be defined as abstract methods in the base classes, if not there is an issue with the engine implementation. # ⚠️ IMPORTANT NOTE: All methods called here are expected to be defined as abstract methods in the base classes, if not there is an issue with the engine implementation.
# we start by loading the settings # we start by loading the settings
self.progress(0.1,"Loading settings...") self.progress(0.1, "Loading settings...")
self.settingsengine.load() self.settingsengine.load()
self.setup_dir() self.setup_dir()

View File

@@ -23,7 +23,11 @@ class SimpleBackgroundEngine(BaseBackgroundEngine):
@classmethod @classmethod
def get_options(cls) -> list: def get_options(cls) -> list:
assets = cls.get_assets(type="bcg_video") assets = cls.get_assets(type="bcg_video")
choices=[asset.data["name"] for asset in assets] if len(assets) > 0 else ["No videos available"] choices = (
[asset.data["name"] for asset in assets]
if len(assets) > 0
else ["No videos available"]
)
return [ return [
gr.Dropdown( gr.Dropdown(
@@ -35,9 +39,7 @@ class SimpleBackgroundEngine(BaseBackgroundEngine):
] ]
def get_background(self) -> mp.VideoClip: def get_background(self) -> mp.VideoClip:
background = mp.VideoFileClip( background = mp.VideoFileClip(f"{self.background_video}", audio=False)
f"{self.background_video}", audio=False
)
background_max_start = background.duration - self.ctx.duration background_max_start = background.duration - self.ctx.duration
if background_max_start < 0: if background_max_start < 0:
raise ValueError( raise ValueError(
@@ -46,7 +48,13 @@ class SimpleBackgroundEngine(BaseBackgroundEngine):
start = random.uniform(0, background_max_start) start = random.uniform(0, background_max_start)
clip = background.subclip(start, start + self.ctx.duration) clip = background.subclip(start, start + self.ctx.duration)
w, h = clip.size w, h = clip.size
return crop(clip, width=self.ctx.width, height=self.ctx.height, x_center=w / 2, y_center=h / 2) return crop(
clip,
width=self.ctx.width,
height=self.ctx.height,
x_center=w / 2,
y_center=h / 2,
)
@classmethod @classmethod
def get_settings(cls) -> list: def get_settings(cls) -> list:

View File

@@ -47,7 +47,7 @@ class BaseEngine(ABC):
.scalars() .scalars()
.all() .all()
) )
@classmethod @classmethod
def add_asset(cls, *, path: str, metadata: dict, type: str = None): def add_asset(cls, *, path: str, metadata: dict, type: str = None):
with SessionLocal() as db: with SessionLocal() as db:

View File

@@ -3,14 +3,16 @@ from typing import TypedDict
from .. import BaseEngine from .. import BaseEngine
class MetadataEngineSettings(TypedDict): class MetadataEngineSettings(TypedDict):
title: str title: str
description: str description: str
class BaseMetadataEngine(BaseEngine): class BaseMetadataEngine(BaseEngine):
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs) -> None:
... ...
@abstractmethod @abstractmethod
def get_metadata(self, input: str) -> MetadataEngineSettings: def get_metadata(self, input: str) -> MetadataEngineSettings:
... ...

View File

@@ -2,15 +2,20 @@ from . import BaseMetadataEngine
from ...utils.prompting import get_prompt from ...utils.prompting import get_prompt
class ShortsMetadataEngine(BaseMetadataEngine): class ShortsMetadataEngine(BaseMetadataEngine):
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs) -> None:
... ...
def get_metadata(self): def get_metadata(self):
sytsem_prompt, chat_prompt = get_prompt("ShortsMetadata", by_file_location=__file__) sytsem_prompt, chat_prompt = get_prompt(
"ShortsMetadata", by_file_location=__file__
)
chat_prompt = chat_prompt.replace("{script}", self.ctx.script) chat_prompt = chat_prompt.replace("{script}", self.ctx.script)
return self.ctx.simplellmengine.generate(chat_prompt=chat_prompt, system_prompt=sytsem_prompt, json_mode=True) return self.ctx.simplellmengine.generate(
chat_prompt=chat_prompt, system_prompt=sytsem_prompt, json_mode=True
)
def get_options(self): def get_options(self):
return [] return []

View File

@@ -1,2 +1,2 @@
from .BaseMetadataEngine import BaseMetadataEngine from .BaseMetadataEngine import BaseMetadataEngine
from .ShortsMetadataEngine import ShortsMetadataEngine from .ShortsMetadataEngine import ShortsMetadataEngine

View File

@@ -10,6 +10,7 @@ from . import SettingsEngine
from . import BackgroundEngine from . import BackgroundEngine
from . import MetadataEngine from . import MetadataEngine
class EngineDict(TypedDict): class EngineDict(TypedDict):
classes: list[BaseEngine] classes: list[BaseEngine]
multiple: bool multiple: bool

View File

@@ -4,6 +4,7 @@ import gradio as gr
from src.engines import ENGINES, BaseEngine from src.engines import ENGINES, BaseEngine
from src.chore import GenerationContext from src.chore import GenerationContext
class GenerateUI: class GenerateUI:
def __init__(self): def __init__(self):
self.css = """.generate_button { self.css = """.generate_button {
@@ -95,7 +96,11 @@ class GenerateUI:
elem_classes="generate_button", elem_classes="generate_button",
) )
output_gallery = gr.Markdown("aaa", render=False) output_gallery = gr.Markdown("aaa", render=False)
button.click(self.run_generate_interface, inputs=inputs, outputs=output_gallery) button.click(
self.run_generate_interface,
inputs=inputs,
outputs=output_gallery,
)
output_gallery.render() output_gallery.render()
return interface return interface
@@ -106,6 +111,7 @@ class GenerateUI:
ctx = GenerationContext(**arugments, progress=progress) ctx = GenerationContext(**arugments, progress=progress)
ctx.process() # Here we go ! 🚀 ctx.process() # Here we go ! 🚀
return gr.update(value=ctx.get_file_path("final.mp4")) return gr.update(value=ctx.get_file_path("final.mp4"))
def repack_options(self, *args) -> dict[BaseEngine]: def repack_options(self, *args) -> dict[BaseEngine]:
""" """
Repacks the options provided as arguments into a dictionary based on the selected engine. Repacks the options provided as arguments into a dictionary based on the selected engine.