diff --git a/src/engines/BaseEngine.py b/src/engines/BaseEngine.py index 41520d1..0a6839e 100644 --- a/src/engines/BaseEngine.py +++ b/src/engines/BaseEngine.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from sqlalchemy.future import select from ..chore import GenerationContext -from ..models import SessionLocal, File +from ..models import SessionLocal, File, Setting class BaseEngine(ABC): @@ -29,7 +29,7 @@ class BaseEngine(ABC): return mp.AudioFileClip(path).duration @classmethod - def get_assets(cls, *, type: str = None) -> list[File]: + def get_assets(cls, *, type: str = None, by_id: int = None) -> list[File] | File | None: with SessionLocal() as db: if type: return ( @@ -41,6 +41,16 @@ class BaseEngine(ABC): .scalars() .all() ) + elif by_id: + return ( + db.execute( + select(File).filter( + File.id == by_id, File.provider == cls.name + ) + ) + .scalars() + .first() + ) else: return ( db.execute(select(File).filter(File.provider == cls.name)) @@ -60,6 +70,87 @@ class BaseEngine(ABC): db.execute(select(File).filter(File.path == path)).delete() db.commit() + @classmethod + def store_setting(cls, *, type: str = None, data: dict): + with SessionLocal() as db: + # check if setting exists + setting = db.execute( + select(Setting).filter( + Setting.provider == cls.name, Setting.type == type + ) + ).scalar() + if setting: + setting.data = data + else: + db.add(Setting(provider=cls.name, type=type, data=data)) + db.commit() + + @classmethod + def get_setting(cls, *args, **kwargs): + """ + This method is deprecated, use retrieve_setting instead + """ + return cls.retrieve_setting(*args, **kwargs) + + @classmethod + def retrieve_setting(cls, *, identifier: str = None, type: str = None) -> dict | list[dict] | None: + """ + Retrieve a setting from the database based on the provided identifier or type. + + Args: + identifier (str, optional): The identifier of the setting. Defaults to None. + type (str, optional): Deprecated. Now an alias for identifier, please use identifier instead. Defaults to None. + + Returns: + str | list[str] | None: The retrieved setting data, or None if not found. + """ + with SessionLocal() as db: + if not identifier and type: + identifier = type + if identifier: + result = db.execute( + select(Setting).filter( + Setting.provider == cls.name, Setting.type == identifier + ) + ).scalar() + + if result: + return result.data + return None + else: + return [ + s.data + for s in db.execute( + select(Setting).filter(Setting.provider == cls.name) + ) + .scalars() + .all() + ] + + @classmethod + def remove_setting(cls, *, identifier: str = None, type: str = None): + """ + Remove a setting from the database. + + Args: + identifier (str, optional): The identifier of the setting to be removed. If not provided, the type will be used as the identifier. Defaults to None. + type (str, optional): Deprecated. Now an alias for identifier, please use identifier instead. Defaults to None. + """ + with SessionLocal() as db: + if not identifier and type: + identifier = type + if identifier: + db.execute( + select(Setting).filter( + Setting.provider == cls.name, Setting.type == identifier + ) + ).delete() + else: + db.execute( + select(Setting).filter(Setting.provider == cls.name) + ).delete() + db.commit() + @classmethod def get_settings(cls): ...