Refactor GenerationContext class

This commit is contained in:
2024-02-15 17:48:47 +01:00
parent 3cf9e46d53
commit e90ec054d5

View File

@@ -5,27 +5,31 @@ import os
from .. import engines from .. import engines
from ..utils.prompting import get_prompt from ..utils.prompting import get_prompt
class GenerationContext:
def __init__(self, powerfulllmengine: engines.LLMEngine.BaseLLMEngine, simplellmengine: engines.LLMEngine.BaseLLMEngine, scriptengine: engines.ScriptEngine.BaseScriptEngine, ttsengine: engines.TTSEngine.BaseTTSEngine) -> None: class GenerationContext:
self.powerfulllmengine = powerfulllmengine def __init__(
self, powerfulllmengine, simplellmengine, scriptengine, ttsengine
) -> None:
self.powerfulllmengine: engines.LLMEngine.BaseLLMEngine = powerfulllmengine
self.powerfulllmengine.ctx = self self.powerfulllmengine.ctx = self
self.simplellmengine = simplellmengine self.simplellmengine: engines.LLMEngine.BaseLLMEngine = simplellmengine
self.simplellmengine.ctx = self self.simplellmengine.ctx = self
self.scriptengine = scriptengine self.scriptengine: engines.ScriptEngine.BaseScriptEngine = scriptengine
self.scriptengine.ctx = self self.scriptengine.ctx = self
self.ttsengine = ttsengine self.ttsengine: engines.TTSEngine.BaseTTSEngine = ttsengine
self.ttsengine.ctx = self self.ttsengine.ctx = self
def setup_dir(self): def setup_dir(self):
self.dir = f"output/{time.time()}" self.dir = f"output/{time.time()}"
os.makedirs(self.dir) os.makedirs(self.dir)
def process(self): def process(self):
# IMPORTANT NOTE: All methods called here are expected to be defined as abstract methods in the base classes, if not there is an issue with the engine implementation.
self.setup_dir() self.setup_dir()
script = self.scriptengine.generate() script = self.scriptengine.generate()
timed_script = self.ttsengine.synthesize(script, self.dir) timed_script = self.ttsengine.synthesize(script, self.dir)