mirror of
https://github.com/Paillat-dev/viralfactory.git
synced 2026-01-02 01:06:19 +00:00
Add BackgroundEngine to engines.
This commit is contained in:
@@ -1,7 +1,11 @@
|
||||
import gradio as gr
|
||||
import moviepy.editor as mp
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from ..chore import GenerationContext
|
||||
from ..models import SessionLocal, File
|
||||
|
||||
|
||||
class BaseEngine(ABC):
|
||||
@@ -17,3 +21,45 @@ class BaseEngine(ABC):
|
||||
@abstractmethod
|
||||
def get_options():
|
||||
...
|
||||
|
||||
def get_video_duration(self, path: str) -> float:
|
||||
return mp.VideoFileClip(path).duration
|
||||
|
||||
def get_audio_duration(self, path: str) -> float:
|
||||
return mp.AudioFileClip(path).duration
|
||||
|
||||
@classmethod
|
||||
def get_assets(cls, *, type: str = None) -> list[File]:
|
||||
with SessionLocal() as db:
|
||||
if type:
|
||||
return (
|
||||
db.execute(
|
||||
select(File).filter(
|
||||
File.type == type, File.provider == cls.name
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
return (
|
||||
db.execute(select(File).filter(File.provider == cls.name))
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def add_asset(cls, *, path: str, metadata: dict, type: str = None):
|
||||
with SessionLocal() as db:
|
||||
db.add(File(path=path, data=metadata, type=type, provider=cls.name))
|
||||
db.commit()
|
||||
|
||||
@classmethod
|
||||
def remove_asset(cls, *, path: str):
|
||||
with SessionLocal() as db:
|
||||
db.execute(select(File).filter(File.path == path)).delete()
|
||||
db.commit()
|
||||
|
||||
@classmethod
|
||||
def get_settings(cls):
|
||||
...
|
||||
|
||||
@@ -7,6 +7,8 @@ from . import LLMEngine
|
||||
from . import CaptioningEngine
|
||||
from . import AssetsEngine
|
||||
from . import SettingsEngine
|
||||
from . import BackgroundEngine
|
||||
|
||||
|
||||
class EngineDict(TypedDict):
|
||||
classes: list[BaseEngine]
|
||||
@@ -19,7 +21,6 @@ ENGINES: dict[str, EngineDict] = {
|
||||
"multiple": False,
|
||||
"show_dropdown": False,
|
||||
},
|
||||
|
||||
"SimpleLLMEngine": {
|
||||
"classes": [LLMEngine.OpenaiLLMEngine, LLMEngine.AnthropicLLMEngine],
|
||||
"multiple": False,
|
||||
@@ -47,4 +48,8 @@ ENGINES: dict[str, EngineDict] = {
|
||||
"classes": [AssetsEngine.DallEAssetsEngine, NoneEngine],
|
||||
"multiple": True,
|
||||
},
|
||||
"BackgroundEngine": {
|
||||
"classes": [BackgroundEngine.SimpleBackgroundEngine, NoneEngine],
|
||||
"multiple": False,
|
||||
},
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user