From e90ec054d54eede444fff778e082d0d9af8eaba2 Mon Sep 17 00:00:00 2001 From: Paillat Date: Thu, 15 Feb 2024 17:48:47 +0100 Subject: [PATCH] Refactor GenerationContext class --- src/chore/GenerationContext.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/chore/GenerationContext.py b/src/chore/GenerationContext.py index dcd30cd..a6f99c5 100644 --- a/src/chore/GenerationContext.py +++ b/src/chore/GenerationContext.py @@ -5,27 +5,31 @@ import os from .. import engines from ..utils.prompting import get_prompt -class GenerationContext: - def __init__(self, powerfulllmengine: engines.LLMEngine.BaseLLMEngine, simplellmengine: engines.LLMEngine.BaseLLMEngine, scriptengine: engines.ScriptEngine.BaseScriptEngine, ttsengine: engines.TTSEngine.BaseTTSEngine) -> None: - self.powerfulllmengine = powerfulllmengine +class GenerationContext: + def __init__( + self, powerfulllmengine, simplellmengine, scriptengine, ttsengine + ) -> None: + self.powerfulllmengine: engines.LLMEngine.BaseLLMEngine = powerfulllmengine self.powerfulllmengine.ctx = self - - self.simplellmengine = simplellmengine + + self.simplellmengine: engines.LLMEngine.BaseLLMEngine = simplellmengine self.simplellmengine.ctx = self - self.scriptengine = scriptengine + self.scriptengine: engines.ScriptEngine.BaseScriptEngine = scriptengine self.scriptengine.ctx = self - self.ttsengine = ttsengine + self.ttsengine: engines.TTSEngine.BaseTTSEngine = ttsengine self.ttsengine.ctx = self + def setup_dir(self): self.dir = f"output/{time.time()}" os.makedirs(self.dir) def process(self): + # IMPORTANT NOTE: All methods called here are expected to be defined as abstract methods in the base classes, if not there is an issue with the engine implementation. self.setup_dir() script = self.scriptengine.generate() - timed_script = self.ttsengine.synthesize(script, self.dir) \ No newline at end of file + timed_script = self.ttsengine.synthesize(script, self.dir)