Formatting and stuff

This commit is contained in:
2024-02-23 11:07:50 +01:00
parent 8951837d04
commit d7fa8945d1
13 changed files with 28 additions and 35 deletions

View File

@@ -39,7 +39,7 @@ class VideoBackgroundEngine(BaseBackgroundEngine):
) )
] ]
def get_background(self) -> mp.VideoClip: def get_background(self):
background = mp.VideoFileClip(f"{self.background_video}", audio=False) background = mp.VideoFileClip(f"{self.background_video}", audio=False)
background_max_start = background.duration - self.ctx.duration background_max_start = background.duration - self.ctx.duration
if background_max_start < 0: if background_max_start < 0:

View File

@@ -28,6 +28,7 @@ class BaseEngine(ABC):
def get_audio_duration(self, path: str) -> float: def get_audio_duration(self, path: str) -> float:
return mp.AudioFileClip(path).duration return mp.AudioFileClip(path).duration
# noinspection PyShadowingBuiltins
@classmethod @classmethod
def get_assets(cls, *, type: str = None, by_id: int = None) -> list[File] | File | None: def get_assets(cls, *, type: str = None, by_id: int = None) -> list[File] | File | None:
with SessionLocal() as db: with SessionLocal() as db:
@@ -58,6 +59,7 @@ class BaseEngine(ABC):
.all() .all()
) )
# noinspection PyShadowingBuiltins
@classmethod @classmethod
def add_asset(cls, *, path: str, metadata: dict, type: str = None): def add_asset(cls, *, path: str, metadata: dict, type: str = None):
with SessionLocal() as db: with SessionLocal() as db:
@@ -70,6 +72,7 @@ class BaseEngine(ABC):
db.execute(select(File).filter(File.path == path)).delete() db.execute(select(File).filter(File.path == path)).delete()
db.commit() db.commit()
# noinspection PyShadowingBuiltins
@classmethod @classmethod
def store_setting(cls, *, type: str = None, data: dict): def store_setting(cls, *, type: str = None, data: dict):
with SessionLocal() as db: with SessionLocal() as db:
@@ -92,6 +95,7 @@ class BaseEngine(ABC):
""" """
return cls.retrieve_setting(*args, **kwargs) return cls.retrieve_setting(*args, **kwargs)
# noinspection PyShadowingBuiltins
@classmethod @classmethod
def retrieve_setting(cls, *, identifier: str = None, type: str = None) -> dict | list[dict] | None: def retrieve_setting(cls, *, identifier: str = None, type: str = None) -> dict | list[dict] | None:
""" """
@@ -127,6 +131,7 @@ class BaseEngine(ABC):
.all() .all()
] ]
# noinspection PyShadowingBuiltins
@classmethod @classmethod
def remove_setting(cls, *, identifier: str = None, type: str = None): def remove_setting(cls, *, identifier: str = None, type: str = None):
""" """

View File

@@ -1,11 +1,11 @@
from abc import abstractmethod from abc import abstractmethod
from typing import TypedDict
from .. import BaseEngine from ..BaseEngine import BaseEngine
class BaseMetadataEngine(BaseEngine): class BaseMetadataEngine(BaseEngine):
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs) -> None:
super().__init__()
... ...
@abstractmethod @abstractmethod

View File

@@ -10,6 +10,7 @@ class ShortsMetadataEngine(BaseMetadataEngine):
num_options = 0 num_options = 0
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
... ...
def get_metadata(self): def get_metadata(self):

View File

@@ -7,7 +7,7 @@ class NoneEngine(BaseEngine):
description = "No engine selected" description = "No engine selected"
def __init__(self): def __init__(self):
pass super().__init__()
@classmethod @classmethod
def get_options(cls): def get_options(cls):

View File

@@ -15,7 +15,7 @@ class ShowerThoughtsScriptEngine(BaseScriptEngine):
self.n_sentences = options[0] self.n_sentences = options[0]
super().__init__() super().__init__()
def generate(self) -> str: def generate(self):
sys_prompt, chat_prompt = get_prompt( sys_prompt, chat_prompt = get_prompt(
"shower_thoughts", "shower_thoughts",
location=os.path.join( location=os.path.join(

View File

@@ -22,17 +22,6 @@ class BaseTTSEngine(BaseEngine):
def remove_punctuation(self, text: str) -> str: def remove_punctuation(self, text: str) -> str:
return text.translate(str.maketrans("", "", ".,!?;:")) return text.translate(str.maketrans("", "", ".,!?;:"))
def fix_captions(self, script: str, captions: list[Word]) -> list[Word]:
script = script.split(" ")
new_captions = []
for i, word in enumerate(script):
original_word = self.remove_punctuation(word.lower())
stt_word = self.remove_punctuation(word.lower())
if stt_word in original_word:
captions[i]["text"] = word
new_captions.append(captions[i])
# elif there is a word more in the stt than in the original, we
def time_with_whisper(self, path: str) -> list[Word]: def time_with_whisper(self, path: str) -> list[Word]:
""" """
Transcribes the audio file at the given path using a pre-trained model and returns a list of words. Transcribes the audio file at the given path using a pre-trained model and returns a list of words.

View File

@@ -1,10 +1,11 @@
from abc import abstractmethod from abc import abstractmethod
from .. import BaseEngine from ..BaseEngine import BaseEngine
class BaseUploadEngine(BaseEngine): class BaseUploadEngine(BaseEngine):
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs) -> None:
super().__init__()
... ...
@abstractmethod @abstractmethod

View File

@@ -16,7 +16,7 @@ class TikTokUploadEngine(BaseUploadEngine):
def upload(self): def upload(self):
cookies = self.get_setting(type="cookies")["cookies"] cookies = self.get_setting(type="cookies")["cookies"]
if cookies == None: if cookies is None:
gr.Warning( gr.Warning(
"Skipping upload to TikTok because no cookies were provided. Please provide cookies in the settings." "Skipping upload to TikTok because no cookies were provided. Please provide cookies in the settings."
) )

View File

@@ -12,7 +12,8 @@ class YouTubeUploadEngine(BaseUploadEngine):
num_options = 2 num_options = 2
def __init__(self, options: list): def __init__(self, options: list, **kwargs):
super().__init__(**kwargs)
self.oauth_name = options[0] self.oauth_name = options[0]
self.oauth = self.retrieve_setting(type="oauth_credentials")[self.oauth_name] self.oauth = self.retrieve_setting(type="oauth_credentials")[self.oauth_name]
self.credentials = self.retrieve_setting(type="youtube_client_secrets")[self.oauth["client_secret"]] self.credentials = self.retrieve_setting(type="youtube_client_secrets")[self.oauth["client_secret"]]

View File

@@ -12,13 +12,7 @@ from . import UploadEngine
from .BaseEngine import BaseEngine from .BaseEngine import BaseEngine
from .NoneEngine import NoneEngine from .NoneEngine import NoneEngine
ENGINES: dict[str, dict[str, bool | list[BaseEngine]]] = {
class EngineDict(TypedDict):
classes: list[BaseEngine]
multiple: bool
ENGINES: dict[str, EngineDict] = {
"SettingsEngine": { "SettingsEngine": {
"classes": [SettingsEngine.SettingsEngine], "classes": [SettingsEngine.SettingsEngine],
"multiple": False, "multiple": False,

View File

@@ -1,6 +1,5 @@
import gradio as gr import gradio as gr
import orjson import orjson
import sys
from src.engines import ENGINES, BaseEngine from src.engines import ENGINES, BaseEngine
from src.chore import GenerationContext from src.chore import GenerationContext
@@ -31,7 +30,7 @@ class GenerateUI:
def get_ui(self): def get_ui(self):
ui = gr.TabbedInterface( ui = gr.TabbedInterface(
*self.get_interfaces(), "Viral Factory", gr.themes.Soft(), css=self.css *self.get_interfaces(), title="Viral Factory", theme=gr.themes.Soft(), css=self.css
) )
return ui return ui
@@ -116,12 +115,13 @@ class GenerateUI:
value=None value=None
) )
preset_button = gr.Button("Load") preset_button = gr.Button("Load")
def load_preset(preset_name, *inputs) -> list[gr.update]: def load_preset(preset_name, *inputs) -> list[gr.update]:
with open("local/presets.json", "r") as f: with open("local/presets.json", "r") as f:
presets = orjson.loads(f.read()) presets = orjson.loads(f.read())
returnable = [] returnable = []
if preset_name in presets.keys(): if preset_name in presets.keys():
# If the preset exists # If the preset exists
preset = presets[preset_name] preset = presets[preset_name]
for engine_type, engines in ENGINES.items(): for engine_type, engines in ENGINES.items():
engines = engines["classes"] engines = engines["classes"]
@@ -129,7 +129,8 @@ class GenerateUI:
for engine in engines: for engine in engines:
if engine.name in preset.get(engine_type, {}).keys(): if engine.name in preset.get(engine_type, {}).keys():
values[0].append(engine.name) values[0].append(engine.name)
values.extend(gr.update(value=value) for value in preset[engine_type][engine.name]) values.extend(
gr.update(value=value) for value in preset[engine_type][engine.name])
else: else:
values.extend(gr.update() for _ in range(engine.num_options)) values.extend(gr.update() for _ in range(engine.num_options))
returnable.extend(values) returnable.extend(values)
@@ -154,7 +155,8 @@ class GenerateUI:
presets[preset_name] = new_preset presets[preset_name] = new_preset
f.write(orjson.dumps(presets)) f.write(orjson.dumps(presets))
return [gr.update(value=presets.keys()), *returnable] return [gr.update(value=presets.keys()), *returnable]
preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], outputs=[preset_dropdown,*inputs]) preset_button.click(load_preset, inputs=[preset_dropdown, *inputs],
outputs=[preset_dropdown, *inputs])
output_gallery = gr.Markdown("aaa", render=False) output_gallery = gr.Markdown("aaa", render=False)
button.click( button.click(
self.run_generate_interface, self.run_generate_interface,
@@ -172,7 +174,7 @@ class GenerateUI:
ctx.process() # Here we go ! 🚀 ctx.process() # Here we go ! 🚀
return gr.update(value=ctx.get_file_path("final.mp4")) return gr.update(value=ctx.get_file_path("final.mp4"))
def repack_options(self, *args) -> dict[BaseEngine]: def repack_options(self, *args) -> dict[str, list[BaseEngine]]:
""" """
Repacks the options provided as arguments into a dictionary based on the selected engine. Repacks the options provided as arguments into a dictionary based on the selected engine.
@@ -198,10 +200,10 @@ class GenerateUI:
options[engine_type].append( options[engine_type].append(
engine(options=args[: engine.num_options]) engine(options=args[: engine.num_options])
) )
args = args[engine.num_options :] args = args[engine.num_options:]
else: else:
# we don't care about this, it's not the selected engine, we throw it away # we don't care about this, it's not the selected engine, we throw it away
args = args[engine.num_options :] args = args[engine.num_options:]
return options return options