mirror of
https://github.com/Paillat-dev/viralfactory.git
synced 2026-01-02 01:06:19 +00:00
Refactor GenerationContext class
This commit is contained in:
@@ -5,25 +5,29 @@ import os
|
||||
from .. import engines
|
||||
from ..utils.prompting import get_prompt
|
||||
|
||||
class GenerationContext:
|
||||
|
||||
def __init__(self, powerfulllmengine: engines.LLMEngine.BaseLLMEngine, simplellmengine: engines.LLMEngine.BaseLLMEngine, scriptengine: engines.ScriptEngine.BaseScriptEngine, ttsengine: engines.TTSEngine.BaseTTSEngine) -> None:
|
||||
self.powerfulllmengine = powerfulllmengine
|
||||
class GenerationContext:
|
||||
def __init__(
|
||||
self, powerfulllmengine, simplellmengine, scriptengine, ttsengine
|
||||
) -> None:
|
||||
self.powerfulllmengine: engines.LLMEngine.BaseLLMEngine = powerfulllmengine
|
||||
self.powerfulllmengine.ctx = self
|
||||
|
||||
self.simplellmengine = simplellmengine
|
||||
self.simplellmengine: engines.LLMEngine.BaseLLMEngine = simplellmengine
|
||||
self.simplellmengine.ctx = self
|
||||
|
||||
self.scriptengine = scriptengine
|
||||
self.scriptengine: engines.ScriptEngine.BaseScriptEngine = scriptengine
|
||||
self.scriptengine.ctx = self
|
||||
|
||||
self.ttsengine = ttsengine
|
||||
self.ttsengine: engines.TTSEngine.BaseTTSEngine = ttsengine
|
||||
self.ttsengine.ctx = self
|
||||
|
||||
def setup_dir(self):
|
||||
self.dir = f"output/{time.time()}"
|
||||
os.makedirs(self.dir)
|
||||
|
||||
def process(self):
|
||||
# IMPORTANT NOTE: All methods called here are expected to be defined as abstract methods in the base classes, if not there is an issue with the engine implementation.
|
||||
self.setup_dir()
|
||||
|
||||
script = self.scriptengine.generate()
|
||||
|
||||
Reference in New Issue
Block a user