mirror of
https://github.com/Paillat-dev/viralfactory.git
synced 2026-01-02 17:24:54 +00:00
Add get_assets() method to retrieve a single file by ID
This commit is contained in:
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from ..chore import GenerationContext
|
||||
from ..models import SessionLocal, File
|
||||
from ..models import SessionLocal, File, Setting
|
||||
|
||||
|
||||
class BaseEngine(ABC):
|
||||
@@ -29,7 +29,7 @@ class BaseEngine(ABC):
|
||||
return mp.AudioFileClip(path).duration
|
||||
|
||||
@classmethod
|
||||
def get_assets(cls, *, type: str = None) -> list[File]:
|
||||
def get_assets(cls, *, type: str = None, by_id: int = None) -> list[File] | File | None:
|
||||
with SessionLocal() as db:
|
||||
if type:
|
||||
return (
|
||||
@@ -41,6 +41,16 @@ class BaseEngine(ABC):
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
elif by_id:
|
||||
return (
|
||||
db.execute(
|
||||
select(File).filter(
|
||||
File.id == by_id, File.provider == cls.name
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
return (
|
||||
db.execute(select(File).filter(File.provider == cls.name))
|
||||
@@ -60,6 +70,87 @@ class BaseEngine(ABC):
|
||||
db.execute(select(File).filter(File.path == path)).delete()
|
||||
db.commit()
|
||||
|
||||
@classmethod
|
||||
def store_setting(cls, *, type: str = None, data: dict):
|
||||
with SessionLocal() as db:
|
||||
# check if setting exists
|
||||
setting = db.execute(
|
||||
select(Setting).filter(
|
||||
Setting.provider == cls.name, Setting.type == type
|
||||
)
|
||||
).scalar()
|
||||
if setting:
|
||||
setting.data = data
|
||||
else:
|
||||
db.add(Setting(provider=cls.name, type=type, data=data))
|
||||
db.commit()
|
||||
|
||||
@classmethod
|
||||
def get_setting(cls, *args, **kwargs):
|
||||
"""
|
||||
This method is deprecated, use retrieve_setting instead
|
||||
"""
|
||||
return cls.retrieve_setting(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def retrieve_setting(cls, *, identifier: str = None, type: str = None) -> dict | list[dict] | None:
|
||||
"""
|
||||
Retrieve a setting from the database based on the provided identifier or type.
|
||||
|
||||
Args:
|
||||
identifier (str, optional): The identifier of the setting. Defaults to None.
|
||||
type (str, optional): Deprecated. Now an alias for identifier, please use identifier instead. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str | list[str] | None: The retrieved setting data, or None if not found.
|
||||
"""
|
||||
with SessionLocal() as db:
|
||||
if not identifier and type:
|
||||
identifier = type
|
||||
if identifier:
|
||||
result = db.execute(
|
||||
select(Setting).filter(
|
||||
Setting.provider == cls.name, Setting.type == identifier
|
||||
)
|
||||
).scalar()
|
||||
|
||||
if result:
|
||||
return result.data
|
||||
return None
|
||||
else:
|
||||
return [
|
||||
s.data
|
||||
for s in db.execute(
|
||||
select(Setting).filter(Setting.provider == cls.name)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def remove_setting(cls, *, identifier: str = None, type: str = None):
|
||||
"""
|
||||
Remove a setting from the database.
|
||||
|
||||
Args:
|
||||
identifier (str, optional): The identifier of the setting to be removed. If not provided, the type will be used as the identifier. Defaults to None.
|
||||
type (str, optional): Deprecated. Now an alias for identifier, please use identifier instead. Defaults to None.
|
||||
"""
|
||||
with SessionLocal() as db:
|
||||
if not identifier and type:
|
||||
identifier = type
|
||||
if identifier:
|
||||
db.execute(
|
||||
select(Setting).filter(
|
||||
Setting.provider == cls.name, Setting.type == identifier
|
||||
)
|
||||
).delete()
|
||||
else:
|
||||
db.execute(
|
||||
select(Setting).filter(Setting.provider == cls.name)
|
||||
).delete()
|
||||
db.commit()
|
||||
|
||||
@classmethod
|
||||
def get_settings(cls):
|
||||
...
|
||||
|
||||
Reference in New Issue
Block a user