Files
viralfactory/src/engines/BaseEngine.py

66 lines
1.7 KiB
Python
Raw Normal View History

2024-02-13 14:15:27 +01:00
import gradio as gr
2024-02-20 14:55:58 +01:00
import moviepy.editor as mp
2024-02-15 17:53:02 +01:00
from abc import ABC, abstractmethod
2024-02-20 14:55:58 +01:00
from sqlalchemy.future import select
2024-02-15 17:53:02 +01:00
from ..chore import GenerationContext
2024-02-20 14:55:58 +01:00
from ..models import SessionLocal, File
2024-02-13 14:15:27 +01:00
2024-02-14 17:49:51 +01:00
2024-02-13 14:15:27 +01:00
class BaseEngine(ABC):
    """Abstract base class shared by all engines.

    Subclasses set the class attributes below and implement ``get_options``.
    The class-level ``name`` is used as the ``provider`` column when adding
    and listing ``File`` assets via ``add_asset`` / ``get_assets``.
    """

    num_options: int  # number of options the engine exposes — set by subclasses (TODO confirm semantics)
    name: str  # engine identifier; stored as File.provider for this engine's assets
    description: str  # human-readable description of the engine

    def __init__(self):
        # Declares the attribute for type checkers only; no value is assigned here.
        self.ctx: GenerationContext

    @classmethod
    @abstractmethod
    def get_options(cls):
        """Return the engine's options. Must be implemented by subclasses."""
        ...

    def get_video_duration(self, path: str) -> float:
        """Return the duration (in seconds, as reported by moviepy) of the video at *path*.

        The clip is opened in a ``with`` block so its reader is closed
        afterwards instead of being leaked.
        """
        with mp.VideoFileClip(path) as clip:
            return clip.duration

    def get_audio_duration(self, path: str) -> float:
        """Return the duration (in seconds, as reported by moviepy) of the audio at *path*.

        The clip is opened in a ``with`` block so its reader is closed
        afterwards instead of being leaked.
        """
        with mp.AudioFileClip(path) as clip:
            return clip.duration

    @classmethod
    def get_assets(cls, *, type: str = None) -> list[File]:
        """Return all ``File`` rows whose provider is this engine's ``name``.

        Args:
            type: If truthy, additionally restrict results to ``File.type == type``.

        Returns:
            A list of ``File`` objects. The session is closed before returning,
            so the objects are detached from it.
        """
        query = select(File).filter(File.provider == cls.name)
        if type:
            query = query.filter(File.type == type)
        with SessionLocal() as db:
            return db.execute(query).scalars().all()

    @classmethod
    def add_asset(cls, *, path: str, metadata: dict, type: str = None):
        """Insert a ``File`` row owned by this engine and commit it.

        Args:
            path: Value stored in ``File.path``.
            metadata: Value stored in ``File.data``.
            type: Optional value stored in ``File.type``.
        """
        with SessionLocal() as db:
            db.add(File(path=path, data=metadata, type=type, provider=cls.name))
            db.commit()

    @classmethod
    def remove_asset(cls, *, path: str):
        """Delete every ``File`` row whose ``path`` equals *path* and commit.

        Note: unlike ``get_assets``, this does not restrict by provider.
        """
        with SessionLocal() as db:
            # A Result returned by execute() has no .delete(); load the matching
            # rows and mark each one for deletion through the session instead.
            matches = db.execute(select(File).filter(File.path == path)).scalars().all()
            for file in matches:
                db.delete(file)
            db.commit()

    @classmethod
    def get_settings(cls):
        """Hook for engine settings; the base implementation returns None."""
        ...