2024-02-13 14:15:27 +01:00
|
|
|
import gradio as gr
|
2024-02-20 14:55:58 +01:00
|
|
|
import moviepy.editor as mp
|
|
|
|
|
|
2024-02-15 17:53:02 +01:00
|
|
|
from abc import ABC, abstractmethod
|
2024-02-20 14:55:58 +01:00
|
|
|
from sqlalchemy.future import select
|
2024-02-15 17:53:02 +01:00
|
|
|
|
|
|
|
|
from ..chore import GenerationContext
|
2024-02-20 14:55:58 +01:00
|
|
|
from ..models import SessionLocal, File
|
2024-02-13 14:15:27 +01:00
|
|
|
|
2024-02-14 17:49:51 +01:00
|
|
|
|
2024-02-13 14:15:27 +01:00
|
|
|
class BaseEngine(ABC):
    """Abstract base class for engines.

    Provides shared helpers for probing media durations via moviepy and for
    adding, listing and removing ``File`` asset records in the database.
    Records created or listed by these helpers are tagged with the engine's
    ``name`` as their ``provider``.
    """

    # NOTE(review): presumably the number of options this engine exposes —
    # not read in this class; confirm against subclasses.
    num_options: int
    # Engine identifier; stored as ``File.provider`` for this engine's assets.
    name: str
    # NOTE(review): assumed to be a human-readable description; not read here.
    description: str

    def __init__(self):
        # Annotation only: declares the type of ``self.ctx`` for type checkers
        # without assigning it. NOTE(review): the attribute must be set
        # elsewhere before it is read — confirm where.
        self.ctx: GenerationContext

    @classmethod
    @abstractmethod
    def get_options(cls):
        """Return this engine's options; subclasses must implement this.

        ``cls`` is required because the method is a classmethod — without it
        any call raises ``TypeError``.
        NOTE(review): the expected return type is not visible in this file.
        """
        ...

    def get_video_duration(self, path: str) -> float:
        """Return the ``duration`` moviepy reports for the video at *path*.

        The clip is used as a context manager so its reader is closed after
        the duration is read, instead of being leaked on every call.
        """
        with mp.VideoFileClip(path) as clip:
            return clip.duration

    def get_audio_duration(self, path: str) -> float:
        """Return the ``duration`` moviepy reports for the audio at *path*.

        The clip is used as a context manager so its reader is closed after
        the duration is read, instead of being leaked on every call.
        """
        with mp.AudioFileClip(path) as clip:
            return clip.duration

    @classmethod
    def get_assets(cls, *, type: str = None) -> list[File]:
        """Return the ``File`` records whose ``provider`` is ``cls.name``.

        Args:
            type: When truthy, only records with ``File.type == type`` are
                returned. ``None`` (or any falsy value) returns all of this
                engine's records.

        Returns:
            The matching ``File`` instances, fetched before the session closes.
        """
        with SessionLocal() as db:
            # Build the query once; the type filter is optional.
            query = select(File).where(File.provider == cls.name)
            if type:
                query = query.where(File.type == type)
            return db.execute(query).scalars().all()

    @classmethod
    def add_asset(cls, *, path: str, metadata: dict, type: str = None):
        """Insert a ``File`` record for *path* owned by this engine and commit.

        Args:
            path: Value stored in ``File.path``.
            metadata: Value stored in ``File.data``.
            type: Value stored in ``File.type`` (may be ``None``).
        """
        with SessionLocal() as db:
            db.add(File(path=path, data=metadata, type=type, provider=cls.name))
            db.commit()

    @classmethod
    def remove_asset(cls, *, path: str):
        """Delete every ``File`` record whose ``path`` equals *path* and commit.

        NOTE(review): unlike ``get_assets``/``add_asset`` this does not filter
        on ``provider``; confirm whether that is intended.
        """
        with SessionLocal() as db:
            # A ``Result`` has no ``delete()`` method; load the matching rows
            # and delete them through the session instead.
            matches = db.execute(select(File).where(File.path == path)).scalars().all()
            for asset in matches:
                db.delete(asset)
            db.commit()

    @classmethod
    def get_settings(cls):
        """Return engine settings; the base implementation returns ``None``.

        Subclasses may override this.
        """
        ...
|