mirror of
https://github.com/Paillat-dev/viralfactory.git
synced 2026-01-02 09:16:19 +00:00
Add BackgroundEngine to engines.
This commit is contained in:
@@ -1,7 +1,11 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
import moviepy.editor as mp
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
from ..chore import GenerationContext
|
from ..chore import GenerationContext
|
||||||
|
from ..models import SessionLocal, File
|
||||||
|
|
||||||
|
|
||||||
class BaseEngine(ABC):
|
class BaseEngine(ABC):
|
||||||
@@ -17,3 +21,45 @@ class BaseEngine(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_options():
|
def get_options():
|
||||||
...
|
...
|
||||||
|
|
||||||
|
def get_video_duration(self, path: str) -> float:
|
||||||
|
return mp.VideoFileClip(path).duration
|
||||||
|
|
||||||
|
def get_audio_duration(self, path: str) -> float:
|
||||||
|
return mp.AudioFileClip(path).duration
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_assets(cls, *, type: str = None) -> list[File]:
|
||||||
|
with SessionLocal() as db:
|
||||||
|
if type:
|
||||||
|
return (
|
||||||
|
db.execute(
|
||||||
|
select(File).filter(
|
||||||
|
File.type == type, File.provider == cls.name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
db.execute(select(File).filter(File.provider == cls.name))
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_asset(cls, *, path: str, metadata: dict, type: str = None):
|
||||||
|
with SessionLocal() as db:
|
||||||
|
db.add(File(path=path, data=metadata, type=type, provider=cls.name))
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def remove_asset(cls, *, path: str):
|
||||||
|
with SessionLocal() as db:
|
||||||
|
db.execute(select(File).filter(File.path == path)).delete()
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_settings(cls):
|
||||||
|
...
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ from . import LLMEngine
|
|||||||
from . import CaptioningEngine
|
from . import CaptioningEngine
|
||||||
from . import AssetsEngine
|
from . import AssetsEngine
|
||||||
from . import SettingsEngine
|
from . import SettingsEngine
|
||||||
|
from . import BackgroundEngine
|
||||||
|
|
||||||
|
|
||||||
class EngineDict(TypedDict):
|
class EngineDict(TypedDict):
|
||||||
classes: list[BaseEngine]
|
classes: list[BaseEngine]
|
||||||
@@ -19,7 +21,6 @@ ENGINES: dict[str, EngineDict] = {
|
|||||||
"multiple": False,
|
"multiple": False,
|
||||||
"show_dropdown": False,
|
"show_dropdown": False,
|
||||||
},
|
},
|
||||||
|
|
||||||
"SimpleLLMEngine": {
|
"SimpleLLMEngine": {
|
||||||
"classes": [LLMEngine.OpenaiLLMEngine, LLMEngine.AnthropicLLMEngine],
|
"classes": [LLMEngine.OpenaiLLMEngine, LLMEngine.AnthropicLLMEngine],
|
||||||
"multiple": False,
|
"multiple": False,
|
||||||
@@ -47,4 +48,8 @@ ENGINES: dict[str, EngineDict] = {
|
|||||||
"classes": [AssetsEngine.DallEAssetsEngine, NoneEngine],
|
"classes": [AssetsEngine.DallEAssetsEngine, NoneEngine],
|
||||||
"multiple": True,
|
"multiple": True,
|
||||||
},
|
},
|
||||||
|
"BackgroundEngine": {
|
||||||
|
"classes": [BackgroundEngine.SimpleBackgroundEngine, NoneEngine],
|
||||||
|
"multiple": False,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user