Some changes

This commit is contained in:
2024-02-29 16:51:40 +01:00
parent 5e7b32dc7a
commit be3776e1ce
5 changed files with 76 additions and 24 deletions

View File

@@ -1,9 +1,10 @@
import os import os
from typing import Literal, TypedDict from typing import Literal, TypedDict, List
import gradio as gr import gradio as gr
import moviepy.editor as mp import moviepy.editor as mp
import openai import openai
from openai import OpenAI
import requests import requests
from moviepy.video.fx.resize import resize from moviepy.video.fx.resize import resize
@@ -35,6 +36,10 @@ class DallEAssetsEngine(BaseAssetsEngine):
def __init__(self, options: dict): def __init__(self, options: dict):
self.aspect_ratio: Literal["portrait", "square", "landscape"] = options[0] self.aspect_ratio: Literal["portrait", "square", "landscape"] = options[0]
api_key = self.retrieve_setting(identifier="openai_api_key")
if not api_key:
raise ValueError("OpenAI API key is not set.")
self.client = OpenAI(api_key=api_key["api_key"])
super().__init__() super().__init__()
@@ -54,7 +59,7 @@ class DallEAssetsEngine(BaseAssetsEngine):
else "1792x1024" else "1792x1024"
) )
try: try:
response = openai.images.generate( response = self.client.images.generate(
model="dall-e-3", model="dall-e-3",
prompt=prompt, prompt=prompt,
size=size, size=size,
@@ -95,3 +100,21 @@ class DallEAssetsEngine(BaseAssetsEngine):
value="square", value="square",
) )
] ]
@classmethod
def get_settings(cls):
current_api_key: dict | list[dict] | None = cls.retrieve_setting(identifier="openai_api_key")
current_api_key = current_api_key["api_key"] if current_api_key else ""
api_key_input = gr.Textbox(
label="OpenAI API Key",
type="password",
value=current_api_key,
)
save = gr.Button("Save")
def save_api_key(api_key: str):
cls.store_setting(identifier="openai_api_key", data={"api_key": api_key})
gr.Info("API key saved successfully.")
return gr.update(value=api_key)
save.click(save_api_key, inputs=[api_key_input])

View File

@@ -49,15 +49,15 @@ class VideoBackgroundEngine(BaseBackgroundEngine):
clip = background.subclip(start, start + self.ctx.duration) clip = background.subclip(start, start + self.ctx.duration)
w, h = clip.size w, h = clip.size
self.ctx.credits += f"\n{self.background_video.data['credits']}" self.ctx.credits += f"\n{self.background_video.data['credits']}"
self.ctx.index_0.append( if w == h:
crop( clip = clip.resize(width=self.ctx.width) if w > h else clip.resize(height=self.ctx.height)
clip, elif w > h:
width=self.ctx.width, clip = clip.resize(width=self.ctx.width)
height=self.ctx.height, clip = crop(clip, width=self.ctx.width, height=self.ctx.height, x_center=w / 2, y_center=h / 2)
x_center=w / 2, else:
y_center=h / 2, clip = clip.resize(height=self.ctx.height)
) clip = crop(clip, width=self.ctx.width, height=self.ctx.height, x_center=w / 2, y_center=h / 2)
) self.ctx.index_0.append(clip)
@classmethod @classmethod
def get_settings(cls): def get_settings(cls):

View File

@@ -76,19 +76,21 @@ class BaseEngine(ABC):
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
@classmethod @classmethod
def store_setting(cls, *, type: str = None, data: dict): def store_setting(cls, *, identifier: str = None, type: str = None, data: dict):
if not identifier and type:
identifier = type
with SessionLocal() as db: with SessionLocal() as db:
# check if setting exists # check if setting exists
# noinspection PyTypeChecker # noinspection PyTypeChecker
setting = db.execute( setting = db.execute(
select(Setting).filter( select(Setting).filter(
Setting.provider == cls.name, Setting.type == type Setting.provider == cls.name, Setting.type == identifier
) )
).scalar() ).scalar()
if setting: if setting:
setting.data = data setting.data = data
else: else:
db.add(Setting(provider=cls.name, type=type, data=data)) db.add(Setting(provider=cls.name, type=identifier, data=data))
db.commit() db.commit()
@classmethod @classmethod

View File

@@ -1,5 +1,6 @@
import gradio as gr import gradio as gr
import openai import openai
from openai import OpenAI
import orjson import orjson
from .BaseLLMEngine import BaseLLMEngine from .BaseLLMEngine import BaseLLMEngine
@@ -17,6 +18,10 @@ class OpenaiLLMEngine(BaseLLMEngine):
def __init__(self, options: list) -> None: def __init__(self, options: list) -> None:
self.model = options[0] self.model = options[0]
api_key = self.retrieve_setting(identifier="openai_api_key")
if not api_key:
raise ValueError("OpenAI API key is not set.")
self.client = OpenAI(api_key=api_key["api_key"])
super().__init__() super().__init__()
def generate( def generate(
@@ -30,7 +35,7 @@ class OpenaiLLMEngine(BaseLLMEngine):
frequency_penalty: float = 0, frequency_penalty: float = 0,
presence_penalty: float = 0, presence_penalty: float = 0,
) -> str | dict: ) -> str | dict:
response = openai.chat.completions.create( response = self.client.chat.completions.create(
model=self.model, model=self.model,
messages=[ messages=[
{"role": "system", "content": system_prompt}, {"role": "system", "content": system_prompt},
@@ -61,3 +66,21 @@ class OpenaiLLMEngine(BaseLLMEngine):
value=OPENAI_POSSIBLE_MODELS[0], value=OPENAI_POSSIBLE_MODELS[0],
) )
] ]
@classmethod
def get_settings(cls):
current_api_key = cls.retrieve_setting(identifier="openai_api_key")
current_api_key = current_api_key["api_key"] if current_api_key else ""
api_key_input = gr.Textbox(
label="OpenAI API Key",
type="password",
value=current_api_key,
)
save = gr.Button("Save")
def save_api_key(api_key: str):
cls.store_setting(identifier="openai_api_key", data={"api_key": api_key})
gr.Info("API key saved successfully.")
return gr.update(value=api_key)
save.click(save_api_key, inputs=[api_key_input])

View File

@@ -57,13 +57,14 @@ class GenerateUI:
... ...
returnable.extend(values) returnable.extend(values)
else: else:
raise ValueError("Preset not found") raise gr.Error(f"Preset {preset_name} does not exist.")
return [gr.update(choices=list(current_presets.keys()), value=preset_name), *returnable] gr.Info(f"Preset {preset_name} loaded successfully.")
return [gr.Dropdown(choices=list(current_presets.keys()), value=preset_name), *returnable]
return load_preset return load_preset
def get_save_preset_func(self): def get_save_preset_func(self):
def save_preset(preset_name, *selected_inputs) -> list[gr.update]: def save_preset(preset_name, *selected_inputs) -> list[gr.update]:
with open("local/presets.json", "r") as f: with open("local/presets.json", "rb") as f:
current_presets = orjson.loads(f.read()) current_presets = orjson.loads(f.read())
returnable = [] returnable = []
poppable_inputs = list(selected_inputs) poppable_inputs = list(selected_inputs)
@@ -85,17 +86,20 @@ class GenerateUI:
with open("local/presets.json", "wb") as f: with open("local/presets.json", "wb") as f:
current_presets[preset_name] = new_preset current_presets[preset_name] = new_preset
f.write(orjson.dumps(current_presets)) f.write(orjson.dumps(current_presets))
return [gr.update(choices=list(current_presets.keys()), value=preset_name), *returnable] gr.Info(f"Preset {preset_name} saved successfully.")
return [gr.Dropdown(choices=list(current_presets.keys()), value=preset_name), *returnable]
return save_preset return save_preset
def get_delete_preset_func(self): def get_delete_preset_func(self):
def delete_preset(preset_name) -> list[gr.update]: def delete_preset(preset_name) -> list[gr.update]:
with open("local/presets.json", "r") as f: with open("local/presets.json", "r") as f:
current_presets = orjson.loads(f.read()) current_presets = orjson.loads(f.read())
if not current_presets.get(preset_name):
raise ValueError("You cannot delete a non-existing preset.")
current_presets.pop(preset_name) current_presets.pop(preset_name)
with open("local/presets.json", "wb") as f: with open("local/presets.json", "wb") as f:
f.write(orjson.dumps(current_presets)) f.write(orjson.dumps(current_presets))
return [gr.update(choices=list(current_presets.keys()), value=None)] return gr.Dropdown(choices=list(current_presets.keys()), value=None)
return delete_preset return delete_preset
def get_ui(self): def get_ui(self):
ui = gr.TabbedInterface( ui = gr.TabbedInterface(
@@ -180,9 +184,9 @@ class GenerateUI:
preset_dropdown = gr.Dropdown( preset_dropdown = gr.Dropdown(
choices=list(presets.keys()), choices=list(presets.keys()),
show_label=False, show_label=False,
label="dd", label="",
allow_custom_value=True, allow_custom_value=True,
value=None, value="",
) )
load_preset_button = gr.Button("📂", size="sm", variant="primary") load_preset_button = gr.Button("📂", size="sm", variant="primary")
save_preset_button = gr.Button("💾", size="sm", variant="secondary") save_preset_button = gr.Button("💾", size="sm", variant="secondary")
@@ -195,8 +199,8 @@ class GenerateUI:
outputs=[preset_dropdown, *inputs]) outputs=[preset_dropdown, *inputs])
save_preset_button.click(save_preset, inputs=[preset_dropdown, *inputs], save_preset_button.click(save_preset, inputs=[preset_dropdown, *inputs],
outputs=[preset_dropdown, *inputs]) outputs=[preset_dropdown, *inputs])
delete_preset_button.click(delete_preset, inputs=[preset_dropdown], delete_preset_button.click(delete_preset, inputs=preset_dropdown,
outputs=[preset_dropdown]) outputs=preset_dropdown)
output_title = gr.Markdown(visible=True, render=False) output_title = gr.Markdown(visible=True, render=False)
output_description = gr.Markdown(visible=True, render=False) output_description = gr.Markdown(visible=True, render=False)
output_video = gr.Video(visible=True, render=False) output_video = gr.Video(visible=True, render=False)