Files
viralfactory/src/engines/AssetsEngine/AssetsEngineSelector.py

38 lines
1.2 KiB
Python
Raw Normal View History

2024-02-18 00:56:49 +01:00
import json
from ...chore import GenerationContext
2024-02-23 09:50:43 +01:00
from ...utils.prompting import get_prompt
2024-02-20 14:47:54 +01:00
2024-02-18 00:56:49 +01:00
class AssetsEngineSelector:
    """Ask the LLM which assets to use and collect them from the assets engines.

    The selector works entirely through ``self.ctx``. ``__init__`` only
    declares that attribute, so it has to be assigned on the instance
    before ``get_assets`` is called.
    """

    def __init__(self):
        # Annotation only: no value is bound here.
        self.ctx: GenerationContext

    def get_assets(self):
        """Query the LLM for asset requests and append the resulting clips.

        Loads the "assets" prompt pair, substitutes ``{engines}`` in the
        system prompt with one descriptor per entry of
        ``self.ctx.assetsengine`` and ``{caption}`` in the chat prompt with
        the JSON-encoded ``self.ctx.timed_script``, then reads the
        ``"assets"`` entry of the response from
        ``self.ctx.powerfulllmengine.generate``. Every engine receives the
        ``"args"`` of the entries whose ``"engine"`` equals its name, and
        everything the engines return is appended to ``self.ctx.index_3``.

        Returns:
            None. The result is the mutation of ``self.ctx.index_3``.
        """
        system_template, chat_template = get_prompt(
            "assets", by_file_location=__file__
        )

        # One "name + JSON specification" descriptor per available engine.
        descriptor_block = "".join(
            f"name: '{eng.name}'\n{json.dumps(eng.specification)}\n"
            for eng in self.ctx.assetsengine
        )
        system_text = system_template.replace("{engines}", descriptor_block)

        caption_json = json.dumps(self.ctx.timed_script)
        chat_text = chat_template.replace("{caption}", caption_json)

        response = self.ctx.powerfulllmengine.generate(
            system_prompt=system_text,
            chat_prompt=chat_text,
            max_tokens=4096,
            json_mode=True,
        )
        requested = response["assets"]

        # Gather everything first so index_3 is only touched once, at the end.
        collected: list = []
        for eng in self.ctx.assetsengine:
            matching_args = []
            for entry in requested:
                if entry["engine"] == eng.name:
                    matching_args.append(entry["args"])
            collected.extend(eng.get_assets(matching_args))

        self.ctx.index_3.extend(collected)