diff --git a/main.py b/main.py index bfeeccc..77362d0 100644 --- a/main.py +++ b/main.py @@ -6,4 +6,4 @@ load_dotenv() if __name__ == "__main__": init_db() - launch() \ No newline at end of file + launch() diff --git a/src/engines/BackgroundEngine/VideoBackgroundEngine.py b/src/engines/BackgroundEngine/VideoBackgroundEngine.py index 74ff072..b25d634 100644 --- a/src/engines/BackgroundEngine/VideoBackgroundEngine.py +++ b/src/engines/BackgroundEngine/VideoBackgroundEngine.py @@ -39,7 +39,7 @@ class VideoBackgroundEngine(BaseBackgroundEngine): ) ] - def get_background(self) -> mp.VideoClip: + def get_background(self): background = mp.VideoFileClip(f"{self.background_video}", audio=False) background_max_start = background.duration - self.ctx.duration if background_max_start < 0: diff --git a/src/engines/BaseEngine.py b/src/engines/BaseEngine.py index 8c591a3..58980ab 100644 --- a/src/engines/BaseEngine.py +++ b/src/engines/BaseEngine.py @@ -28,6 +28,7 @@ class BaseEngine(ABC): def get_audio_duration(self, path: str) -> float: return mp.AudioFileClip(path).duration + # noinspection PyShadowingBuiltins @classmethod def get_assets(cls, *, type: str = None, by_id: int = None) -> list[File] | File | None: with SessionLocal() as db: @@ -58,6 +59,7 @@ class BaseEngine(ABC): .all() ) + # noinspection PyShadowingBuiltins @classmethod def add_asset(cls, *, path: str, metadata: dict, type: str = None): with SessionLocal() as db: @@ -70,6 +72,7 @@ class BaseEngine(ABC): db.execute(select(File).filter(File.path == path)).delete() db.commit() + # noinspection PyShadowingBuiltins @classmethod def store_setting(cls, *, type: str = None, data: dict): with SessionLocal() as db: @@ -92,6 +95,7 @@ class BaseEngine(ABC): """ return cls.retrieve_setting(*args, **kwargs) + # noinspection PyShadowingBuiltins @classmethod def retrieve_setting(cls, *, identifier: str = None, type: str = None) -> dict | list[dict] | None: """ @@ -127,6 +131,7 @@ class BaseEngine(ABC): .all() ] + # noinspection PyShadowingBuiltins @classmethod def remove_setting(cls, *, identifier: str = None, type: str = None): """ diff --git a/src/engines/MetadataEngine/BaseMetadataEngine.py b/src/engines/MetadataEngine/BaseMetadataEngine.py index 5317909..9f06231 100644 --- a/src/engines/MetadataEngine/BaseMetadataEngine.py +++ b/src/engines/MetadataEngine/BaseMetadataEngine.py @@ -1,11 +1,11 @@ from abc import abstractmethod -from typing import TypedDict -from .. import BaseEngine +from ..BaseEngine import BaseEngine class BaseMetadataEngine(BaseEngine): def __init__(self, **kwargs) -> None: + super().__init__() ... @abstractmethod diff --git a/src/engines/MetadataEngine/ShortsMetadataEngine.py b/src/engines/MetadataEngine/ShortsMetadataEngine.py index 357a302..0b7335f 100644 --- a/src/engines/MetadataEngine/ShortsMetadataEngine.py +++ b/src/engines/MetadataEngine/ShortsMetadataEngine.py @@ -10,6 +10,7 @@ class ShortsMetadataEngine(BaseMetadataEngine): num_options = 0 def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) ... def get_metadata(self): diff --git a/src/engines/NoneEngine.py b/src/engines/NoneEngine.py index 6c828b5..1f304da 100644 --- a/src/engines/NoneEngine.py +++ b/src/engines/NoneEngine.py @@ -7,7 +7,7 @@ class NoneEngine(BaseEngine): description = "No engine selected" def __init__(self): - pass + super().__init__() @classmethod def get_options(cls): diff --git a/src/engines/ScriptEngine/ShowerThoughtsScriptEngine.py b/src/engines/ScriptEngine/ShowerThoughtsScriptEngine.py index 7fb9525..1e6109f 100644 --- a/src/engines/ScriptEngine/ShowerThoughtsScriptEngine.py +++ b/src/engines/ScriptEngine/ShowerThoughtsScriptEngine.py @@ -15,7 +15,7 @@ class ShowerThoughtsScriptEngine(BaseScriptEngine): self.n_sentences = options[0] super().__init__() - def generate(self) -> str: + def generate(self): sys_prompt, chat_prompt = get_prompt( "shower_thoughts", location=os.path.join( diff --git a/src/engines/TTSEngine/BaseTTSEngine.py b/src/engines/TTSEngine/BaseTTSEngine.py index dbb5c65..9389c33 100644 --- a/src/engines/TTSEngine/BaseTTSEngine.py +++ b/src/engines/TTSEngine/BaseTTSEngine.py @@ -22,17 +22,6 @@ class BaseTTSEngine(BaseEngine): def remove_punctuation(self, text: str) -> str: return text.translate(str.maketrans("", "", ".,!?;:")) - def fix_captions(self, script: str, captions: list[Word]) -> list[Word]: - script = script.split(" ") - new_captions = [] - for i, word in enumerate(script): - original_word = self.remove_punctuation(word.lower()) - stt_word = self.remove_punctuation(word.lower()) - if stt_word in original_word: - captions[i]["text"] = word - new_captions.append(captions[i]) - # elif there is a word more in the stt than in the original, we - def time_with_whisper(self, path: str) -> list[Word]: """ Transcribes the audio file at the given path using a pre-trained model and returns a list of words. diff --git a/src/engines/UploadEngine/BaseUploadEngine.py b/src/engines/UploadEngine/BaseUploadEngine.py index 479ccb9..d5de7b7 100644 --- a/src/engines/UploadEngine/BaseUploadEngine.py +++ b/src/engines/UploadEngine/BaseUploadEngine.py @@ -1,10 +1,11 @@ from abc import abstractmethod -from .. import BaseEngine +from ..BaseEngine import BaseEngine class BaseUploadEngine(BaseEngine): def __init__(self, **kwargs) -> None: + super().__init__() ... @abstractmethod diff --git a/src/engines/UploadEngine/TikTokUploadEngine.py b/src/engines/UploadEngine/TikTokUploadEngine.py index bf956fa..da1bdbb 100644 --- a/src/engines/UploadEngine/TikTokUploadEngine.py +++ b/src/engines/UploadEngine/TikTokUploadEngine.py @@ -16,7 +16,7 @@ class TikTokUploadEngine(BaseUploadEngine): def upload(self): cookies = self.get_setting(type="cookies")["cookies"] - if cookies == None: + if cookies is None: gr.Warning( "Skipping upload to TikTok because no cookies were provided. Please provide cookies in the settings." ) diff --git a/src/engines/UploadEngine/YouTubeUploadEngine.py b/src/engines/UploadEngine/YouTubeUploadEngine.py index efa9cdd..87a153c 100644 --- a/src/engines/UploadEngine/YouTubeUploadEngine.py +++ b/src/engines/UploadEngine/YouTubeUploadEngine.py @@ -12,7 +12,8 @@ class YouTubeUploadEngine(BaseUploadEngine): num_options = 2 - def __init__(self, options: list): + def __init__(self, options: list, **kwargs): + super().__init__(**kwargs) self.oauth_name = options[0] self.oauth = self.retrieve_setting(type="oauth_credentials")[self.oauth_name] self.credentials = self.retrieve_setting(type="youtube_client_secrets")[self.oauth["client_secret"]] diff --git a/src/engines/__init__.py b/src/engines/__init__.py index 777487b..ad3a5c5 100644 --- a/src/engines/__init__.py +++ b/src/engines/__init__.py @@ -12,13 +12,7 @@ from . import UploadEngine from .BaseEngine import BaseEngine from .NoneEngine import NoneEngine - -class EngineDict(TypedDict): - classes: list[BaseEngine] - multiple: bool - - -ENGINES: dict[str, EngineDict] = { +ENGINES: dict[str, dict[str, bool | list[BaseEngine]]] = { "SettingsEngine": { "classes": [SettingsEngine.SettingsEngine], "multiple": False, diff --git a/ui/gradio_ui.py b/ui/gradio_ui.py index d5d1e1c..342d5e6 100644 --- a/ui/gradio_ui.py +++ b/ui/gradio_ui.py @@ -1,6 +1,5 @@ import gradio as gr import orjson -import sys from src.engines import ENGINES, BaseEngine from src.chore import GenerationContext @@ -31,7 +30,7 @@ class GenerateUI: def get_ui(self): ui = gr.TabbedInterface( - *self.get_interfaces(), "Viral Factory", gr.themes.Soft(), css=self.css + *self.get_interfaces(), title="Viral Factory", theme=gr.themes.Soft(), css=self.css ) return ui @@ -116,12 +115,13 @@ class GenerateUI: value=None ) preset_button = gr.Button("Load") + def load_preset(preset_name, *inputs) -> list[gr.update]: with open("local/presets.json", "r") as f: presets = orjson.loads(f.read()) returnable = [] if preset_name in presets.keys(): - # If the preset exists + # If the preset exists preset = presets[preset_name] for engine_type, engines in ENGINES.items(): engines = engines["classes"] @@ -129,7 +129,8 @@ class GenerateUI: for engine in engines: if engine.name in preset.get(engine_type, {}).keys(): values[0].append(engine.name) - values.extend(gr.update(value=value) for value in preset[engine_type][engine.name]) + values.extend( + gr.update(value=value) for value in preset[engine_type][engine.name]) else: values.extend(gr.update() for _ in range(engine.num_options)) returnable.extend(values) @@ -154,7 +155,8 @@ class GenerateUI: presets[preset_name] = new_preset f.write(orjson.dumps(presets)) return [gr.update(value=presets.keys()), *returnable] - preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], outputs=[preset_dropdown,*inputs]) + preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], + outputs=[preset_dropdown, *inputs]) output_gallery = gr.Markdown("aaa", render=False) button.click( self.run_generate_interface, @@ -172,7 +174,7 @@ class GenerateUI: ctx.process() # Here we go ! 🚀 return gr.update(value=ctx.get_file_path("final.mp4")) - def repack_options(self, *args) -> dict[BaseEngine]: + def repack_options(self, *args) -> dict[str, list[BaseEngine]]: """ Repacks the options provided as arguments into a dictionary based on the selected engine. @@ -198,10 +200,10 @@ class GenerateUI: options[engine_type].append( engine(options=args[: engine.num_options]) ) - args = args[engine.num_options :] + args = args[engine.num_options:] else: # we don't care about this, it's not the selected engine, we throw it away - args = args[engine.num_options :] + args = args[engine.num_options:] return options