Add get_assets() method to retrieve a single file by ID

This commit is contained in:
2024-02-22 15:15:52 +01:00
parent 00536c4894
commit 1aa0a5b986

View File

@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from sqlalchemy.future import select
from ..chore import GenerationContext
from ..models import SessionLocal, File
from ..models import SessionLocal, File, Setting
class BaseEngine(ABC):
@@ -29,7 +29,7 @@ class BaseEngine(ABC):
return mp.AudioFileClip(path).duration
@classmethod
def get_assets(cls, *, type: str = None) -> list[File]:
def get_assets(cls, *, type: str = None, by_id: int = None) -> list[File] | File | None:
with SessionLocal() as db:
if type:
return (
@@ -41,6 +41,16 @@ class BaseEngine(ABC):
.scalars()
.all()
)
elif by_id:
return (
db.execute(
select(File).filter(
File.id == by_id, File.provider == cls.name
)
)
.scalars()
.first()
)
else:
return (
db.execute(select(File).filter(File.provider == cls.name))
@@ -60,6 +70,87 @@ class BaseEngine(ABC):
db.execute(select(File).filter(File.path == path)).delete()
db.commit()
@classmethod
def store_setting(cls, *, type: str = None, data: dict):
with SessionLocal() as db:
# check if setting exists
setting = db.execute(
select(Setting).filter(
Setting.provider == cls.name, Setting.type == type
)
).scalar()
if setting:
setting.data = data
else:
db.add(Setting(provider=cls.name, type=type, data=data))
db.commit()
@classmethod
def get_setting(cls, *args, **kwargs):
"""
This method is deprecated, use retrieve_setting instead
"""
return cls.retrieve_setting(*args, **kwargs)
@classmethod
def retrieve_setting(cls, *, identifier: str = None, type: str = None) -> dict | list[dict] | None:
"""
Retrieve a setting from the database based on the provided identifier or type.
Args:
identifier (str, optional): The identifier of the setting. Defaults to None.
type (str, optional): Deprecated. Now an alias for identifier, please use identifier instead. Defaults to None.
Returns:
str | list[str] | None: The retrieved setting data, or None if not found.
"""
with SessionLocal() as db:
if not identifier and type:
identifier = type
if identifier:
result = db.execute(
select(Setting).filter(
Setting.provider == cls.name, Setting.type == identifier
)
).scalar()
if result:
return result.data
return None
else:
return [
s.data
for s in db.execute(
select(Setting).filter(Setting.provider == cls.name)
)
.scalars()
.all()
]
@classmethod
def remove_setting(cls, *, identifier: str = None, type: str = None):
"""
Remove a setting from the database.
Args:
identifier (str, optional): The identifier of the setting to be removed. If not provided, the type will be used as the identifier. Defaults to None.
type (str, optional): Deprecated. Now an alias for identifier, please use identifier instead. Defaults to None.
"""
with SessionLocal() as db:
if not identifier and type:
identifier = type
if identifier:
db.execute(
select(Setting).filter(
Setting.provider == cls.name, Setting.type == identifier
)
).delete()
else:
db.execute(
select(Setting).filter(Setting.provider == cls.name)
).delete()
db.commit()
@classmethod
def get_settings(cls):
...