From be3776e1ce2d0e8df982cea6d530af7b14d52b43 Mon Sep 17 00:00:00 2001 From: Paillat Date: Thu, 29 Feb 2024 16:51:40 +0100 Subject: [PATCH] Some changes --- src/engines/AssetsEngine/DallEAssetsEngine.py | 27 +++++++++++++++++-- .../BackgroundEngine/VideoBackgroundEngine.py | 18 ++++++------- src/engines/BaseEngine.py | 8 +++--- src/engines/LLMEngine/OpenaiLLMEngine.py | 25 ++++++++++++++++- ui/gradio_ui.py | 22 ++++++++------- 5 files changed, 76 insertions(+), 24 deletions(-) diff --git a/src/engines/AssetsEngine/DallEAssetsEngine.py b/src/engines/AssetsEngine/DallEAssetsEngine.py index 46f9d83..8be64fb 100644 --- a/src/engines/AssetsEngine/DallEAssetsEngine.py +++ b/src/engines/AssetsEngine/DallEAssetsEngine.py @@ -1,9 +1,10 @@ import os -from typing import Literal, TypedDict +from typing import Literal, TypedDict, List import gradio as gr import moviepy.editor as mp import openai +from openai import OpenAI import requests from moviepy.video.fx.resize import resize @@ -35,6 +36,10 @@ class DallEAssetsEngine(BaseAssetsEngine): def __init__(self, options: dict): self.aspect_ratio: Literal["portrait", "square", "landscape"] = options[0] + api_key = self.retrieve_setting(identifier="openai_api_key") + if not api_key: + raise ValueError("OpenAI API key is not set.") + self.client = OpenAI(api_key=api_key["api_key"]) super().__init__() @@ -54,7 +59,7 @@ class DallEAssetsEngine(BaseAssetsEngine): else "1792x1024" ) try: - response = openai.images.generate( + response = self.client.images.generate( model="dall-e-3", prompt=prompt, size=size, @@ -95,3 +100,21 @@ class DallEAssetsEngine(BaseAssetsEngine): value="square", ) ] + + @classmethod + def get_settings(cls): + current_api_key: dict | list[dict] | None = cls.retrieve_setting(identifier="openai_api_key") + current_api_key = current_api_key["api_key"] if current_api_key else "" + api_key_input = gr.Textbox( + label="OpenAI API Key", + type="password", + value=current_api_key, + ) + save = gr.Button("Save") + + def save_api_key(api_key: str): + cls.store_setting(identifier="openai_api_key", data={"api_key": api_key}) + gr.Info("API key saved successfully.") + return gr.update(value=api_key) + + save.click(save_api_key, inputs=[api_key_input]) diff --git a/src/engines/BackgroundEngine/VideoBackgroundEngine.py b/src/engines/BackgroundEngine/VideoBackgroundEngine.py index 381f1df..3669e3b 100644 --- a/src/engines/BackgroundEngine/VideoBackgroundEngine.py +++ b/src/engines/BackgroundEngine/VideoBackgroundEngine.py @@ -49,15 +49,15 @@ class VideoBackgroundEngine(BaseBackgroundEngine): clip = background.subclip(start, start + self.ctx.duration) w, h = clip.size self.ctx.credits += f"\n{self.background_video.data['credits']}" - self.ctx.index_0.append( - crop( - clip, - width=self.ctx.width, - height=self.ctx.height, - x_center=w / 2, - y_center=h / 2, - ) - ) + if w == h: + clip = clip.resize(width=self.ctx.width) if w > h else clip.resize(height=self.ctx.height) + elif w > h: + clip = clip.resize(width=self.ctx.width) + clip = crop(clip, width=self.ctx.width, height=self.ctx.height, x_center=w / 2, y_center=h / 2) + else: + clip = clip.resize(height=self.ctx.height) + clip = crop(clip, width=self.ctx.width, height=self.ctx.height, x_center=w / 2, y_center=h / 2) + self.ctx.index_0.append(clip) @classmethod def get_settings(cls): diff --git a/src/engines/BaseEngine.py b/src/engines/BaseEngine.py index 586758c..05fecbf 100644 --- a/src/engines/BaseEngine.py +++ b/src/engines/BaseEngine.py @@ -76,19 +76,21 @@ class BaseEngine(ABC): # noinspection PyShadowingBuiltins @classmethod - def store_setting(cls, *, type: str = None, data: dict): + def store_setting(cls, *, identifier: str = None, type: str = None, data: dict): + if not identifier and type: + identifier = type with SessionLocal() as db: # check if setting exists # noinspection PyTypeChecker setting = db.execute( select(Setting).filter( - Setting.provider == cls.name, Setting.type == type + Setting.provider == cls.name, Setting.type == identifier ) ).scalar() if setting: setting.data = data else: - db.add(Setting(provider=cls.name, type=type, data=data)) + db.add(Setting(provider=cls.name, type=identifier, data=data)) db.commit() @classmethod diff --git a/src/engines/LLMEngine/OpenaiLLMEngine.py b/src/engines/LLMEngine/OpenaiLLMEngine.py index 8751f9a..7677140 100644 --- a/src/engines/LLMEngine/OpenaiLLMEngine.py +++ b/src/engines/LLMEngine/OpenaiLLMEngine.py @@ -1,5 +1,6 @@ import gradio as gr import openai +from openai import OpenAI import orjson from .BaseLLMEngine import BaseLLMEngine @@ -17,6 +18,10 @@ class OpenaiLLMEngine(BaseLLMEngine): def __init__(self, options: list) -> None: self.model = options[0] + api_key = self.retrieve_setting(identifier="openai_api_key") + if not api_key: + raise ValueError("OpenAI API key is not set.") + self.client = OpenAI(api_key=api_key["api_key"]) super().__init__() def generate( @@ -30,7 +35,7 @@ class OpenaiLLMEngine(BaseLLMEngine): frequency_penalty: float = 0, presence_penalty: float = 0, ) -> str | dict: - response = openai.chat.completions.create( + response = self.client.chat.completions.create( model=self.model, messages=[ {"role": "system", "content": system_prompt}, @@ -61,3 +66,21 @@ class OpenaiLLMEngine(BaseLLMEngine): value=OPENAI_POSSIBLE_MODELS[0], ) ] + + @classmethod + def get_settings(cls): + current_api_key = cls.retrieve_setting(identifier="openai_api_key") + current_api_key = current_api_key["api_key"] if current_api_key else "" + api_key_input = gr.Textbox( + label="OpenAI API Key", + type="password", + value=current_api_key, + ) + save = gr.Button("Save") + + def save_api_key(api_key: str): + cls.store_setting(identifier="openai_api_key", data={"api_key": api_key}) + gr.Info("API key saved successfully.") + return gr.update(value=api_key) + + save.click(save_api_key, inputs=[api_key_input]) diff --git a/ui/gradio_ui.py b/ui/gradio_ui.py index bd21ea1..1b5a423 100644 --- a/ui/gradio_ui.py +++ b/ui/gradio_ui.py @@ -57,13 +57,14 @@ class GenerateUI: ... returnable.extend(values) else: - raise ValueError("Preset not found") - return [gr.update(choices=list(current_presets.keys()), value=preset_name), *returnable] + raise gr.Error(f"Preset {preset_name} does not exist.") + gr.Info(f"Preset {preset_name} loaded successfully.") + return [gr.Dropdown(choices=list(current_presets.keys()), value=preset_name), *returnable] return load_preset def get_save_preset_func(self): def save_preset(preset_name, *selected_inputs) -> list[gr.update]: - with open("local/presets.json", "r") as f: + with open("local/presets.json", "rb") as f: current_presets = orjson.loads(f.read()) returnable = [] poppable_inputs = list(selected_inputs) @@ -85,17 +86,20 @@ class GenerateUI: with open("local/presets.json", "wb") as f: current_presets[preset_name] = new_preset f.write(orjson.dumps(current_presets)) - return [gr.update(choices=list(current_presets.keys()), value=preset_name), *returnable] + gr.Info(f"Preset {preset_name} saved successfully.") + return [gr.Dropdown(choices=list(current_presets.keys()), value=preset_name), *returnable] return save_preset def get_delete_preset_func(self): def delete_preset(preset_name) -> list[gr.update]: with open("local/presets.json", "r") as f: current_presets = orjson.loads(f.read()) + if not current_presets.get(preset_name): + raise ValueError("You cannot delete a non-existing preset.") current_presets.pop(preset_name) with open("local/presets.json", "wb") as f: f.write(orjson.dumps(current_presets)) - return [gr.update(choices=list(current_presets.keys()), value=None)] + return gr.Dropdown(choices=list(current_presets.keys()), value=None) return delete_preset def get_ui(self): ui = gr.TabbedInterface( @@ -180,9 +184,9 @@ class GenerateUI: preset_dropdown = gr.Dropdown( choices=list(presets.keys()), show_label=False, - label="dd", + label="", allow_custom_value=True, - value=None, + value="", ) load_preset_button = gr.Button("📂", size="sm", variant="primary") save_preset_button = gr.Button("💾", size="sm", variant="secondary") @@ -195,8 +199,8 @@ class GenerateUI: outputs=[preset_dropdown, *inputs]) save_preset_button.click(save_preset, inputs=[preset_dropdown, *inputs], outputs=[preset_dropdown, *inputs]) - delete_preset_button.click(delete_preset, inputs=[preset_dropdown], - outputs=[preset_dropdown]) + delete_preset_button.click(delete_preset, inputs=preset_dropdown, + outputs=preset_dropdown) output_title = gr.Markdown(visible=True, render=False) output_description = gr.Markdown(visible=True, render=False) output_video = gr.Video(visible=True, render=False)