Add get_assets() method to retrieve a single file by ID

This commit is contained in:
2024-02-22 15:15:52 +01:00
parent 00536c4894
commit 1aa0a5b986

View File

@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from sqlalchemy.future import select from sqlalchemy.future import select
from ..chore import GenerationContext from ..chore import GenerationContext
from ..models import SessionLocal, File from ..models import SessionLocal, File, Setting
class BaseEngine(ABC): class BaseEngine(ABC):
@@ -29,7 +29,7 @@ class BaseEngine(ABC):
return mp.AudioFileClip(path).duration return mp.AudioFileClip(path).duration
@classmethod @classmethod
def get_assets(cls, *, type: str = None) -> list[File]: def get_assets(cls, *, type: str = None, by_id: int = None) -> list[File] | File | None:
with SessionLocal() as db: with SessionLocal() as db:
if type: if type:
return ( return (
@@ -41,6 +41,16 @@ class BaseEngine(ABC):
.scalars() .scalars()
.all() .all()
) )
elif by_id:
return (
db.execute(
select(File).filter(
File.id == by_id, File.provider == cls.name
)
)
.scalars()
.first()
)
else: else:
return ( return (
db.execute(select(File).filter(File.provider == cls.name)) db.execute(select(File).filter(File.provider == cls.name))
@@ -60,6 +70,87 @@ class BaseEngine(ABC):
db.execute(select(File).filter(File.path == path)).delete() db.execute(select(File).filter(File.path == path)).delete()
db.commit() db.commit()
@classmethod
def store_setting(cls, *, type: str = None, data: dict):
with SessionLocal() as db:
# check if setting exists
setting = db.execute(
select(Setting).filter(
Setting.provider == cls.name, Setting.type == type
)
).scalar()
if setting:
setting.data = data
else:
db.add(Setting(provider=cls.name, type=type, data=data))
db.commit()
@classmethod
def get_setting(cls, *args, **kwargs):
"""
This method is deprecated, use retrieve_setting instead
"""
return cls.retrieve_setting(*args, **kwargs)
@classmethod
def retrieve_setting(cls, *, identifier: str = None, type: str = None) -> dict | list[dict] | None:
"""
Retrieve a setting from the database based on the provided identifier or type.
Args:
identifier (str, optional): The identifier of the setting. Defaults to None.
type (str, optional): Deprecated. Now an alias for identifier, please use identifier instead. Defaults to None.
Returns:
str | list[str] | None: The retrieved setting data, or None if not found.
"""
with SessionLocal() as db:
if not identifier and type:
identifier = type
if identifier:
result = db.execute(
select(Setting).filter(
Setting.provider == cls.name, Setting.type == identifier
)
).scalar()
if result:
return result.data
return None
else:
return [
s.data
for s in db.execute(
select(Setting).filter(Setting.provider == cls.name)
)
.scalars()
.all()
]
@classmethod
def remove_setting(cls, *, identifier: str = None, type: str = None):
"""
Remove a setting from the database.
Args:
identifier (str, optional): The identifier of the setting to be removed. If not provided, the type will be used as the identifier. Defaults to None.
type (str, optional): Deprecated. Now an alias for identifier, please use identifier instead. Defaults to None.
"""
with SessionLocal() as db:
if not identifier and type:
identifier = type
if identifier:
db.execute(
select(Setting).filter(
Setting.provider == cls.name, Setting.type == identifier
)
).delete()
else:
db.execute(
select(Setting).filter(Setting.provider == cls.name)
).delete()
db.commit()
@classmethod @classmethod
def get_settings(cls): def get_settings(cls):
... ...