From 9f88e6d06998aab687fe3b6311a6dbad73af21e1 Mon Sep 17 00:00:00 2001 From: Paillat Date: Thu, 15 Feb 2024 11:23:36 +0100 Subject: [PATCH] feat(GenerationContext.py): add new file GenerationContext.py to handle the context of generation engines feat(OpenaiLLMEngine.py): add orjson library for JSON serialization and deserialization, and implement the generate method to make API call to OpenAI chat completions endpoint feat(__init__.py): import OpenaiLLMEngine in LLMEngine package feat(BaseScriptEngine.py): add time_script method to the BaseScriptEngine class feat(CustomScriptEngine.py): add new file CustomScriptEngine.py to handle custom script generation, implement generate method to return the provided script, and add get_options method to provide a textbox for the prompt input feat(__init__.py): import CustomScriptEngine in ScriptEngine package feat(__init__.py): import LLMEngine package and add OpenaiLLMEngine to the ENGINES dictionary refactor(gradio_ui.py): change equal_height attribute of Row to False to allow different heights for input blocks --- src/chore/GenerationContext.py | 17 +++++++++++++ src/engines/LLMEngine/OpenaiLLMEngine.py | 24 ++++++++++++++++-- src/engines/LLMEngine/__init__.py | 3 ++- src/engines/ScriptEngine/BaseScriptEngine.py | 3 +++ .../ScriptEngine/CustomScriptEngine.py | 25 +++++++++++++++++++ src/engines/ScriptEngine/__init__.py | 1 + src/engines/__init__.py | 4 ++- ui/gradio_ui.py | 2 +- 8 files changed, 74 insertions(+), 5 deletions(-) create mode 100644 src/chore/GenerationContext.py create mode 100644 src/engines/ScriptEngine/CustomScriptEngine.py diff --git a/src/chore/GenerationContext.py b/src/chore/GenerationContext.py new file mode 100644 index 0000000..bef49d5 --- /dev/null +++ b/src/chore/GenerationContext.py @@ -0,0 +1,17 @@ +import moviepy + +from .. import engines +class GenerationContext: + + def __init__(self, llmengine: engines.LLMEngine.BaseLLMEngine, scriptengine: engines.ScriptEngine.BaseScriptEngine, ttsengine: engines.TTSEngine.BaseTTSEngine) -> None: + self.llmengine = llmengine + self.llmengine.ctx = self + + self.scriptengine = scriptengine + self.scriptengine.ctx = self + + self.ttsengine = ttsengine + self.ttsengine.ctx = self + + def process(self): + timed_script = self.scriptengine.generate() \ No newline at end of file diff --git a/src/engines/LLMEngine/OpenaiLLMEngine.py b/src/engines/LLMEngine/OpenaiLLMEngine.py index 2447e36..93088a0 100644 --- a/src/engines/LLMEngine/OpenaiLLMEngine.py +++ b/src/engines/LLMEngine/OpenaiLLMEngine.py @@ -1,5 +1,6 @@ import openai import gradio as gr +import orjson from abc import ABC, abstractmethod @@ -15,10 +16,29 @@ class OpenaiLLMEngine(BaseLLMEngine): name = "OpenAI" description = "OpenAI language model engine." + def __init__(self, options: list) -> None: + self.model = options[0] + super().__init__() + def generate(self, system_prompt: str, chat_prompt: str, max_tokens: int = 512, temperature: float = 1.0, json_mode: bool= False, top_p: float = 1, frequency_penalty: float = 0, presence_penalty: float = 0) -> str: - ... # TODO: Implement this method + response = openai.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": chat_prompt}, + ], + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + response_format={ "type": "json_object" } if json_mode else openai._types.NOT_GIVEN + ) + return response.choices[0].message.content if not json_mode else orjson.loads(response.choices[0].message.content) - def get_options(self) -> list: + + @classmethod + def get_options(cls) -> list: return [ gr.Dropdown( label="Model", diff --git a/src/engines/LLMEngine/__init__.py b/src/engines/LLMEngine/__init__.py index f79e65d..6438acc 100644 --- a/src/engines/LLMEngine/__init__.py +++ b/src/engines/LLMEngine/__init__.py @@ -1 +1,2 @@ -from .BaseLLMEngine import BaseLLMEngine \ No newline at end of file +from .BaseLLMEngine import BaseLLMEngine +from .OpenaiLLMEngine import OpenaiLLMEngine \ No newline at end of file diff --git a/src/engines/ScriptEngine/BaseScriptEngine.py b/src/engines/ScriptEngine/BaseScriptEngine.py index a60f305..eb70094 100644 --- a/src/engines/ScriptEngine/BaseScriptEngine.py +++ b/src/engines/ScriptEngine/BaseScriptEngine.py @@ -8,3 +8,6 @@ class BaseScriptEngine(BaseEngine): @abstractmethod def generate(self) -> str: pass + + def time_script(self): + ... \ No newline at end of file diff --git a/src/engines/ScriptEngine/CustomScriptEngine.py b/src/engines/ScriptEngine/CustomScriptEngine.py new file mode 100644 index 0000000..13d5053 --- /dev/null +++ b/src/engines/ScriptEngine/CustomScriptEngine.py @@ -0,0 +1,25 @@ +from .BaseScriptEngine import BaseScriptEngine +import gradio as gr + + +class CustomScriptEngine(BaseScriptEngine): + name = "Custom Script Engine" + description = "Generate a script with a custom provided prompt" + num_options = 1 + + def __init__(self, options: list[list | tuple | str | int | float | bool | None]): + self.script = options[0] + super().__init__() + + def generate(self, *args, **kwargs) -> str: + return self.script + + @classmethod + def get_options(cls) -> list: + return [ + gr.Textbox( + label="Prompt", + placeholder="Enter your prompt here", + value="", + ) + ] \ No newline at end of file diff --git a/src/engines/ScriptEngine/__init__.py b/src/engines/ScriptEngine/__init__.py index c2e8e3f..7ffb5dc 100644 --- a/src/engines/ScriptEngine/__init__.py +++ b/src/engines/ScriptEngine/__init__.py @@ -1,2 +1,3 @@ from .BaseScriptEngine import BaseScriptEngine from .ShowerThoughtsScriptEngine import ShowerThoughtsScriptEngine +from .CustomScriptEngine import CustomScriptEngine \ No newline at end of file diff --git a/src/engines/__init__.py b/src/engines/__init__.py index 8562586..4b57911 100644 --- a/src/engines/__init__.py +++ b/src/engines/__init__.py @@ -1,8 +1,10 @@ from . import TTSEngine from .BaseEngine import BaseEngine from . import ScriptEngine +from . import LLMEngine ENGINES = { + "LLMEngine": [LLMEngine.OpenaiLLMEngine], "TTSEngine": [TTSEngine.CoquiTTSEngine, TTSEngine.ElevenLabsTTSEngine], - "ScriptEngine": [ScriptEngine.ShowerThoughtsScriptEngine], + "ScriptEngine": [ScriptEngine.ShowerThoughtsScriptEngine, ScriptEngine.CustomScriptEngine], } diff --git a/ui/gradio_ui.py b/ui/gradio_ui.py index cc4f5b6..b7f6129 100644 --- a/ui/gradio_ui.py +++ b/ui/gradio_ui.py @@ -40,7 +40,7 @@ class GenerateUI: def get_generate_interface(self) -> gr.Blocks: with gr.Blocks() as interface: - with gr.Row() as row: + with gr.Row(equal_height=False) as row: inputs = [] with gr.Blocks() as col1: for engine_type, engines in ENGINES.items():