mirror of
https://github.com/Paillat-dev/viralfactory.git
synced 2026-01-02 01:06:19 +00:00
Formatting and stuff
This commit is contained in:
@@ -39,7 +39,7 @@ class VideoBackgroundEngine(BaseBackgroundEngine):
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_background(self) -> mp.VideoClip:
|
def get_background(self):
|
||||||
background = mp.VideoFileClip(f"{self.background_video}", audio=False)
|
background = mp.VideoFileClip(f"{self.background_video}", audio=False)
|
||||||
background_max_start = background.duration - self.ctx.duration
|
background_max_start = background.duration - self.ctx.duration
|
||||||
if background_max_start < 0:
|
if background_max_start < 0:
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ class BaseEngine(ABC):
|
|||||||
def get_audio_duration(self, path: str) -> float:
|
def get_audio_duration(self, path: str) -> float:
|
||||||
return mp.AudioFileClip(path).duration
|
return mp.AudioFileClip(path).duration
|
||||||
|
|
||||||
|
# noinspection PyShadowingBuiltins
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_assets(cls, *, type: str = None, by_id: int = None) -> list[File] | File | None:
|
def get_assets(cls, *, type: str = None, by_id: int = None) -> list[File] | File | None:
|
||||||
with SessionLocal() as db:
|
with SessionLocal() as db:
|
||||||
@@ -58,6 +59,7 @@ class BaseEngine(ABC):
|
|||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# noinspection PyShadowingBuiltins
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_asset(cls, *, path: str, metadata: dict, type: str = None):
|
def add_asset(cls, *, path: str, metadata: dict, type: str = None):
|
||||||
with SessionLocal() as db:
|
with SessionLocal() as db:
|
||||||
@@ -70,6 +72,7 @@ class BaseEngine(ABC):
|
|||||||
db.execute(select(File).filter(File.path == path)).delete()
|
db.execute(select(File).filter(File.path == path)).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
# noinspection PyShadowingBuiltins
|
||||||
@classmethod
|
@classmethod
|
||||||
def store_setting(cls, *, type: str = None, data: dict):
|
def store_setting(cls, *, type: str = None, data: dict):
|
||||||
with SessionLocal() as db:
|
with SessionLocal() as db:
|
||||||
@@ -92,6 +95,7 @@ class BaseEngine(ABC):
|
|||||||
"""
|
"""
|
||||||
return cls.retrieve_setting(*args, **kwargs)
|
return cls.retrieve_setting(*args, **kwargs)
|
||||||
|
|
||||||
|
# noinspection PyShadowingBuiltins
|
||||||
@classmethod
|
@classmethod
|
||||||
def retrieve_setting(cls, *, identifier: str = None, type: str = None) -> dict | list[dict] | None:
|
def retrieve_setting(cls, *, identifier: str = None, type: str = None) -> dict | list[dict] | None:
|
||||||
"""
|
"""
|
||||||
@@ -127,6 +131,7 @@ class BaseEngine(ABC):
|
|||||||
.all()
|
.all()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# noinspection PyShadowingBuiltins
|
||||||
@classmethod
|
@classmethod
|
||||||
def remove_setting(cls, *, identifier: str = None, type: str = None):
|
def remove_setting(cls, *, identifier: str = None, type: str = None):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import TypedDict
|
|
||||||
|
|
||||||
from .. import BaseEngine
|
from ..BaseEngine import BaseEngine
|
||||||
|
|
||||||
|
|
||||||
class BaseMetadataEngine(BaseEngine):
|
class BaseMetadataEngine(BaseEngine):
|
||||||
def __init__(self, **kwargs) -> None:
|
def __init__(self, **kwargs) -> None:
|
||||||
|
super().__init__()
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ class ShortsMetadataEngine(BaseMetadataEngine):
|
|||||||
num_options = 0
|
num_options = 0
|
||||||
|
|
||||||
def __init__(self, **kwargs) -> None:
|
def __init__(self, **kwargs) -> None:
|
||||||
|
super().__init__(**kwargs)
|
||||||
...
|
...
|
||||||
|
|
||||||
def get_metadata(self):
|
def get_metadata(self):
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ class NoneEngine(BaseEngine):
|
|||||||
description = "No engine selected"
|
description = "No engine selected"
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
super().__init__()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_options(cls):
|
def get_options(cls):
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ class ShowerThoughtsScriptEngine(BaseScriptEngine):
|
|||||||
self.n_sentences = options[0]
|
self.n_sentences = options[0]
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def generate(self) -> str:
|
def generate(self):
|
||||||
sys_prompt, chat_prompt = get_prompt(
|
sys_prompt, chat_prompt = get_prompt(
|
||||||
"shower_thoughts",
|
"shower_thoughts",
|
||||||
location=os.path.join(
|
location=os.path.join(
|
||||||
|
|||||||
@@ -22,17 +22,6 @@ class BaseTTSEngine(BaseEngine):
|
|||||||
def remove_punctuation(self, text: str) -> str:
|
def remove_punctuation(self, text: str) -> str:
|
||||||
return text.translate(str.maketrans("", "", ".,!?;:"))
|
return text.translate(str.maketrans("", "", ".,!?;:"))
|
||||||
|
|
||||||
def fix_captions(self, script: str, captions: list[Word]) -> list[Word]:
|
|
||||||
script = script.split(" ")
|
|
||||||
new_captions = []
|
|
||||||
for i, word in enumerate(script):
|
|
||||||
original_word = self.remove_punctuation(word.lower())
|
|
||||||
stt_word = self.remove_punctuation(word.lower())
|
|
||||||
if stt_word in original_word:
|
|
||||||
captions[i]["text"] = word
|
|
||||||
new_captions.append(captions[i])
|
|
||||||
# elif there is a word more in the stt than in the original, we
|
|
||||||
|
|
||||||
def time_with_whisper(self, path: str) -> list[Word]:
|
def time_with_whisper(self, path: str) -> list[Word]:
|
||||||
"""
|
"""
|
||||||
Transcribes the audio file at the given path using a pre-trained model and returns a list of words.
|
Transcribes the audio file at the given path using a pre-trained model and returns a list of words.
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
|
||||||
from .. import BaseEngine
|
from ..BaseEngine import BaseEngine
|
||||||
|
|
||||||
|
|
||||||
class BaseUploadEngine(BaseEngine):
|
class BaseUploadEngine(BaseEngine):
|
||||||
def __init__(self, **kwargs) -> None:
|
def __init__(self, **kwargs) -> None:
|
||||||
|
super().__init__()
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ class TikTokUploadEngine(BaseUploadEngine):
|
|||||||
|
|
||||||
def upload(self):
|
def upload(self):
|
||||||
cookies = self.get_setting(type="cookies")["cookies"]
|
cookies = self.get_setting(type="cookies")["cookies"]
|
||||||
if cookies == None:
|
if cookies is None:
|
||||||
gr.Warning(
|
gr.Warning(
|
||||||
"Skipping upload to TikTok because no cookies were provided. Please provide cookies in the settings."
|
"Skipping upload to TikTok because no cookies were provided. Please provide cookies in the settings."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -12,7 +12,8 @@ class YouTubeUploadEngine(BaseUploadEngine):
|
|||||||
|
|
||||||
num_options = 2
|
num_options = 2
|
||||||
|
|
||||||
def __init__(self, options: list):
|
def __init__(self, options: list, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
self.oauth_name = options[0]
|
self.oauth_name = options[0]
|
||||||
self.oauth = self.retrieve_setting(type="oauth_credentials")[self.oauth_name]
|
self.oauth = self.retrieve_setting(type="oauth_credentials")[self.oauth_name]
|
||||||
self.credentials = self.retrieve_setting(type="youtube_client_secrets")[self.oauth["client_secret"]]
|
self.credentials = self.retrieve_setting(type="youtube_client_secrets")[self.oauth["client_secret"]]
|
||||||
|
|||||||
@@ -12,13 +12,7 @@ from . import UploadEngine
|
|||||||
from .BaseEngine import BaseEngine
|
from .BaseEngine import BaseEngine
|
||||||
from .NoneEngine import NoneEngine
|
from .NoneEngine import NoneEngine
|
||||||
|
|
||||||
|
ENGINES: dict[str, dict[str, bool | list[BaseEngine]]] = {
|
||||||
class EngineDict(TypedDict):
|
|
||||||
classes: list[BaseEngine]
|
|
||||||
multiple: bool
|
|
||||||
|
|
||||||
|
|
||||||
ENGINES: dict[str, EngineDict] = {
|
|
||||||
"SettingsEngine": {
|
"SettingsEngine": {
|
||||||
"classes": [SettingsEngine.SettingsEngine],
|
"classes": [SettingsEngine.SettingsEngine],
|
||||||
"multiple": False,
|
"multiple": False,
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
import orjson
|
import orjson
|
||||||
import sys
|
|
||||||
|
|
||||||
from src.engines import ENGINES, BaseEngine
|
from src.engines import ENGINES, BaseEngine
|
||||||
from src.chore import GenerationContext
|
from src.chore import GenerationContext
|
||||||
@@ -31,7 +30,7 @@ class GenerateUI:
|
|||||||
|
|
||||||
def get_ui(self):
|
def get_ui(self):
|
||||||
ui = gr.TabbedInterface(
|
ui = gr.TabbedInterface(
|
||||||
*self.get_interfaces(), "Viral Factory", gr.themes.Soft(), css=self.css
|
*self.get_interfaces(), title="Viral Factory", theme=gr.themes.Soft(), css=self.css
|
||||||
)
|
)
|
||||||
return ui
|
return ui
|
||||||
|
|
||||||
@@ -116,12 +115,13 @@ class GenerateUI:
|
|||||||
value=None
|
value=None
|
||||||
)
|
)
|
||||||
preset_button = gr.Button("Load")
|
preset_button = gr.Button("Load")
|
||||||
|
|
||||||
def load_preset(preset_name, *inputs) -> list[gr.update]:
|
def load_preset(preset_name, *inputs) -> list[gr.update]:
|
||||||
with open("local/presets.json", "r") as f:
|
with open("local/presets.json", "r") as f:
|
||||||
presets = orjson.loads(f.read())
|
presets = orjson.loads(f.read())
|
||||||
returnable = []
|
returnable = []
|
||||||
if preset_name in presets.keys():
|
if preset_name in presets.keys():
|
||||||
# If the preset exists
|
# If the preset exists
|
||||||
preset = presets[preset_name]
|
preset = presets[preset_name]
|
||||||
for engine_type, engines in ENGINES.items():
|
for engine_type, engines in ENGINES.items():
|
||||||
engines = engines["classes"]
|
engines = engines["classes"]
|
||||||
@@ -129,7 +129,8 @@ class GenerateUI:
|
|||||||
for engine in engines:
|
for engine in engines:
|
||||||
if engine.name in preset.get(engine_type, {}).keys():
|
if engine.name in preset.get(engine_type, {}).keys():
|
||||||
values[0].append(engine.name)
|
values[0].append(engine.name)
|
||||||
values.extend(gr.update(value=value) for value in preset[engine_type][engine.name])
|
values.extend(
|
||||||
|
gr.update(value=value) for value in preset[engine_type][engine.name])
|
||||||
else:
|
else:
|
||||||
values.extend(gr.update() for _ in range(engine.num_options))
|
values.extend(gr.update() for _ in range(engine.num_options))
|
||||||
returnable.extend(values)
|
returnable.extend(values)
|
||||||
@@ -154,7 +155,8 @@ class GenerateUI:
|
|||||||
presets[preset_name] = new_preset
|
presets[preset_name] = new_preset
|
||||||
f.write(orjson.dumps(presets))
|
f.write(orjson.dumps(presets))
|
||||||
return [gr.update(value=presets.keys()), *returnable]
|
return [gr.update(value=presets.keys()), *returnable]
|
||||||
preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], outputs=[preset_dropdown,*inputs])
|
preset_button.click(load_preset, inputs=[preset_dropdown, *inputs],
|
||||||
|
outputs=[preset_dropdown, *inputs])
|
||||||
output_gallery = gr.Markdown("aaa", render=False)
|
output_gallery = gr.Markdown("aaa", render=False)
|
||||||
button.click(
|
button.click(
|
||||||
self.run_generate_interface,
|
self.run_generate_interface,
|
||||||
@@ -172,7 +174,7 @@ class GenerateUI:
|
|||||||
ctx.process() # Here we go ! 🚀
|
ctx.process() # Here we go ! 🚀
|
||||||
return gr.update(value=ctx.get_file_path("final.mp4"))
|
return gr.update(value=ctx.get_file_path("final.mp4"))
|
||||||
|
|
||||||
def repack_options(self, *args) -> dict[BaseEngine]:
|
def repack_options(self, *args) -> dict[str, list[BaseEngine]]:
|
||||||
"""
|
"""
|
||||||
Repacks the options provided as arguments into a dictionary based on the selected engine.
|
Repacks the options provided as arguments into a dictionary based on the selected engine.
|
||||||
|
|
||||||
@@ -198,10 +200,10 @@ class GenerateUI:
|
|||||||
options[engine_type].append(
|
options[engine_type].append(
|
||||||
engine(options=args[: engine.num_options])
|
engine(options=args[: engine.num_options])
|
||||||
)
|
)
|
||||||
args = args[engine.num_options :]
|
args = args[engine.num_options:]
|
||||||
else:
|
else:
|
||||||
# we don't care about this, it's not the selected engine, we throw it away
|
# we don't care about this, it's not the selected engine, we throw it away
|
||||||
args = args[engine.num_options :]
|
args = args[engine.num_options:]
|
||||||
return options
|
return options
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user