Files
viralfactory/src/engines/BaseEngine.py
2024-02-23 09:56:54 +01:00

157 lines
5.0 KiB
Python

from abc import ABC, abstractmethod
import gradio as gr
import moviepy.editor as mp
from sqlalchemy.future import select
from ..chore import GenerationContext
from ..models import SessionLocal, File, Setting
class BaseEngine(ABC):
    """Abstract base class for engines.

    Besides the abstract :meth:`get_options`, it provides helpers to read and
    write two kinds of database rows scoped to the engine: assets (``File``)
    and settings (``Setting``). Rows are tied to an engine through their
    ``provider`` column, which is set to / compared against ``cls.name``.
    """

    # Class attributes declared (annotation only) for subclasses to define.
    num_options: int
    name: str  # stored as the ``provider`` of this engine's File/Setting rows
    description: str

    def __init__(self):
        # Declares the type of ``ctx`` for type checkers only; nothing is
        # assigned here, so ``self.ctx`` raises AttributeError until
        # something else sets it.
        self.ctx: GenerationContext

    @classmethod
    @abstractmethod
    def get_options(cls):
        """Abstract: every concrete engine must implement this."""
        ...

    def get_video_duration(self, path: str) -> float:
        """Return the duration, in seconds, of the video file at ``path``.

        The clip is used as a context manager so its reader (file handle and
        ffmpeg subprocess) is closed instead of being leaked.
        """
        with mp.VideoFileClip(path) as clip:
            return clip.duration

    def get_audio_duration(self, path: str) -> float:
        """Return the duration, in seconds, of the audio file at ``path``.

        The clip is used as a context manager so its reader is closed
        instead of being leaked.
        """
        with mp.AudioFileClip(path) as clip:
            return clip.duration

    @classmethod
    def get_assets(
        cls, *, type: str | None = None, by_id: int | None = None
    ) -> list[File] | File | None:
        """Look up ``File`` rows whose ``provider`` is ``cls.name``.

        Args:
            type: If truthy, return every asset of this type (takes
                precedence over ``by_id``).
            by_id: If given, return the single asset with this id.

        Returns:
            A list of ``File`` rows when filtering by ``type`` or when no
            filter is given; a single ``File`` or None when filtering by
            ``by_id``.
        """
        with SessionLocal() as db:
            query = select(File).filter(File.provider == cls.name)
            if type:
                return db.execute(query.filter(File.type == type)).scalars().all()
            # ``is not None`` so that an id of 0 is still honoured.
            if by_id is not None:
                return db.execute(query.filter(File.id == by_id)).scalars().first()
            return db.execute(query).scalars().all()

    @classmethod
    def add_asset(cls, *, path: str, metadata: dict, type: str | None = None):
        """Insert a ``File`` row for ``path`` owned by this engine.

        Args:
            path: Path stored on the row.
            metadata: Stored in the row's ``data`` column.
            type: Optional asset type.
        """
        with SessionLocal() as db:
            db.add(File(path=path, data=metadata, type=type, provider=cls.name))
            db.commit()

    @classmethod
    def remove_asset(cls, *, path: str):
        """Delete every ``File`` row whose ``path`` equals ``path``.

        Note: unlike the other asset helpers, this matches on path alone and
        is not restricted to rows whose provider is ``cls.name``.
        """
        with SessionLocal() as db:
            # A ``Result`` has no ``.delete()``; load the matching rows and
            # delete them through the session instead.
            assets = (
                db.execute(select(File).filter(File.path == path)).scalars().all()
            )
            for asset in assets:
                db.delete(asset)
            db.commit()

    @classmethod
    def store_setting(
        cls,
        *,
        type: str | None = None,
        data: dict,
        identifier: str | None = None,
    ):
        """Create or update this engine's setting.

        If a ``Setting`` row with provider ``cls.name`` and the given
        identifier exists, its ``data`` is replaced; otherwise a new row is
        inserted.

        Args:
            type: Deprecated alias for ``identifier``; used when
                ``identifier`` is not given.
            data: Value stored in the row's ``data`` column.
            identifier: Key of the setting (stored in ``Setting.type``).
        """
        if identifier is None:
            identifier = type
        with SessionLocal() as db:
            setting = db.execute(
                select(Setting).filter(
                    Setting.provider == cls.name, Setting.type == identifier
                )
            ).scalar()
            if setting:
                setting.data = data
            else:
                db.add(Setting(provider=cls.name, type=identifier, data=data))
            db.commit()

    @classmethod
    def get_setting(cls, *args, **kwargs):
        """
        This method is deprecated, use retrieve_setting instead
        """
        return cls.retrieve_setting(*args, **kwargs)

    @classmethod
    def retrieve_setting(
        cls, *, identifier: str | None = None, type: str | None = None
    ) -> dict | list[dict] | None:
        """
        Retrieve a setting from the database based on the provided identifier or type.

        Args:
            identifier (str, optional): The identifier of the setting. Defaults to None.
            type (str, optional): Deprecated. Now an alias for identifier, please use
                identifier instead. Defaults to None.

        Returns:
            dict | list[dict] | None: The ``data`` of the matching setting (None if
            not found) when an identifier is given; otherwise the ``data`` of every
            setting belonging to this engine.
        """
        if not identifier and type:
            identifier = type
        with SessionLocal() as db:
            query = select(Setting).filter(Setting.provider == cls.name)
            if identifier:
                setting = db.execute(query.filter(Setting.type == identifier)).scalar()
                return setting.data if setting is not None else None
            return [s.data for s in db.execute(query).scalars().all()]

    @classmethod
    def remove_setting(cls, *, identifier: str | None = None, type: str | None = None):
        """
        Remove a setting from the database.

        Args:
            identifier (str, optional): The identifier of the setting to be removed.
                If not provided, the type will be used as the identifier. Defaults to None.
            type (str, optional): Deprecated. Now an alias for identifier, please use
                identifier instead. Defaults to None.

        Warning:
            If neither identifier nor type is given, ALL settings of this engine
            are deleted.
        """
        if not identifier and type:
            identifier = type
        with SessionLocal() as db:
            query = select(Setting).filter(Setting.provider == cls.name)
            if identifier:
                query = query.filter(Setting.type == identifier)
            # A ``Result`` has no ``.delete()``; delete loaded rows via the session.
            for setting in db.execute(query).scalars().all():
                db.delete(setting)
            db.commit()

    @classmethod
    def get_settings(cls):
        """Placeholder that returns None in this base class.

        NOTE(review): presumably meant to be overridden by engines - confirm.
        """
        ...