Formatting and stuff

This commit is contained in:
2024-02-23 11:07:50 +01:00
parent 8951837d04
commit d7fa8945d1
13 changed files with 28 additions and 35 deletions

View File

@@ -39,7 +39,7 @@ class VideoBackgroundEngine(BaseBackgroundEngine):
)
]
def get_background(self) -> mp.VideoClip:
def get_background(self):
background = mp.VideoFileClip(f"{self.background_video}", audio=False)
background_max_start = background.duration - self.ctx.duration
if background_max_start < 0:

View File

@@ -28,6 +28,7 @@ class BaseEngine(ABC):
def get_audio_duration(self, path: str) -> float:
return mp.AudioFileClip(path).duration
# noinspection PyShadowingBuiltins
@classmethod
def get_assets(cls, *, type: str = None, by_id: int = None) -> list[File] | File | None:
with SessionLocal() as db:
@@ -58,6 +59,7 @@ class BaseEngine(ABC):
.all()
)
# noinspection PyShadowingBuiltins
@classmethod
def add_asset(cls, *, path: str, metadata: dict, type: str = None):
with SessionLocal() as db:
@@ -70,6 +72,7 @@ class BaseEngine(ABC):
db.execute(select(File).filter(File.path == path)).delete()
db.commit()
# noinspection PyShadowingBuiltins
@classmethod
def store_setting(cls, *, type: str = None, data: dict):
with SessionLocal() as db:
@@ -92,6 +95,7 @@ class BaseEngine(ABC):
"""
return cls.retrieve_setting(*args, **kwargs)
# noinspection PyShadowingBuiltins
@classmethod
def retrieve_setting(cls, *, identifier: str = None, type: str = None) -> dict | list[dict] | None:
"""
@@ -127,6 +131,7 @@ class BaseEngine(ABC):
.all()
]
# noinspection PyShadowingBuiltins
@classmethod
def remove_setting(cls, *, identifier: str = None, type: str = None):
"""

View File

@@ -1,11 +1,11 @@
from abc import abstractmethod
from typing import TypedDict
from .. import BaseEngine
from ..BaseEngine import BaseEngine
class BaseMetadataEngine(BaseEngine):
def __init__(self, **kwargs) -> None:
super().__init__()
...
@abstractmethod

View File

@@ -10,6 +10,7 @@ class ShortsMetadataEngine(BaseMetadataEngine):
num_options = 0
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
...
def get_metadata(self):

View File

@@ -7,7 +7,7 @@ class NoneEngine(BaseEngine):
description = "No engine selected"
def __init__(self):
pass
super().__init__()
@classmethod
def get_options(cls):

View File

@@ -15,7 +15,7 @@ class ShowerThoughtsScriptEngine(BaseScriptEngine):
self.n_sentences = options[0]
super().__init__()
def generate(self) -> str:
def generate(self):
sys_prompt, chat_prompt = get_prompt(
"shower_thoughts",
location=os.path.join(

View File

@@ -22,17 +22,6 @@ class BaseTTSEngine(BaseEngine):
def remove_punctuation(self, text: str) -> str:
return text.translate(str.maketrans("", "", ".,!?;:"))
def fix_captions(self, script: str, captions: list[Word]) -> list[Word]:
script = script.split(" ")
new_captions = []
for i, word in enumerate(script):
original_word = self.remove_punctuation(word.lower())
stt_word = self.remove_punctuation(word.lower())
if stt_word in original_word:
captions[i]["text"] = word
new_captions.append(captions[i])
# elif there is a word more in the stt than in the original, we
def time_with_whisper(self, path: str) -> list[Word]:
"""
Transcribes the audio file at the given path using a pre-trained model and returns a list of words.

View File

@@ -1,10 +1,11 @@
from abc import abstractmethod
from .. import BaseEngine
from ..BaseEngine import BaseEngine
class BaseUploadEngine(BaseEngine):
def __init__(self, **kwargs) -> None:
super().__init__()
...
@abstractmethod

View File

@@ -16,7 +16,7 @@ class TikTokUploadEngine(BaseUploadEngine):
def upload(self):
cookies = self.get_setting(type="cookies")["cookies"]
if cookies == None:
if cookies is None:
gr.Warning(
"Skipping upload to TikTok because no cookies were provided. Please provide cookies in the settings."
)

View File

@@ -12,7 +12,8 @@ class YouTubeUploadEngine(BaseUploadEngine):
num_options = 2
def __init__(self, options: list):
def __init__(self, options: list, **kwargs):
super().__init__(**kwargs)
self.oauth_name = options[0]
self.oauth = self.retrieve_setting(type="oauth_credentials")[self.oauth_name]
self.credentials = self.retrieve_setting(type="youtube_client_secrets")[self.oauth["client_secret"]]

View File

@@ -12,13 +12,7 @@ from . import UploadEngine
from .BaseEngine import BaseEngine
from .NoneEngine import NoneEngine
class EngineDict(TypedDict):
classes: list[BaseEngine]
multiple: bool
ENGINES: dict[str, EngineDict] = {
ENGINES: dict[str, dict[str, bool | list[BaseEngine]]] = {
"SettingsEngine": {
"classes": [SettingsEngine.SettingsEngine],
"multiple": False,

View File

@@ -1,6 +1,5 @@
import gradio as gr
import orjson
import sys
from src.engines import ENGINES, BaseEngine
from src.chore import GenerationContext
@@ -31,7 +30,7 @@ class GenerateUI:
def get_ui(self):
ui = gr.TabbedInterface(
*self.get_interfaces(), "Viral Factory", gr.themes.Soft(), css=self.css
*self.get_interfaces(), title="Viral Factory", theme=gr.themes.Soft(), css=self.css
)
return ui
@@ -116,6 +115,7 @@ class GenerateUI:
value=None
)
preset_button = gr.Button("Load")
def load_preset(preset_name, *inputs) -> list[gr.update]:
with open("local/presets.json", "r") as f:
presets = orjson.loads(f.read())
@@ -129,7 +129,8 @@ class GenerateUI:
for engine in engines:
if engine.name in preset.get(engine_type, {}).keys():
values[0].append(engine.name)
values.extend(gr.update(value=value) for value in preset[engine_type][engine.name])
values.extend(
gr.update(value=value) for value in preset[engine_type][engine.name])
else:
values.extend(gr.update() for _ in range(engine.num_options))
returnable.extend(values)
@@ -154,7 +155,8 @@ class GenerateUI:
presets[preset_name] = new_preset
f.write(orjson.dumps(presets))
return [gr.update(value=presets.keys()), *returnable]
preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], outputs=[preset_dropdown,*inputs])
preset_button.click(load_preset, inputs=[preset_dropdown, *inputs],
outputs=[preset_dropdown, *inputs])
output_gallery = gr.Markdown("aaa", render=False)
button.click(
self.run_generate_interface,
@@ -172,7 +174,7 @@ class GenerateUI:
ctx.process() # Here we go ! 🚀
return gr.update(value=ctx.get_file_path("final.mp4"))
def repack_options(self, *args) -> dict[BaseEngine]:
def repack_options(self, *args) -> dict[str, list[BaseEngine]]:
"""
Repacks the options provided as arguments into a dictionary based on the selected engine.