mirror of
https://github.com/Paillat-dev/viralfactory.git
synced 2026-01-02 09:16:19 +00:00
Some changes
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
import os
|
||||
from typing import Literal, TypedDict
|
||||
from typing import Literal, TypedDict, List
|
||||
|
||||
import gradio as gr
|
||||
import moviepy.editor as mp
|
||||
import openai
|
||||
from openai import OpenAI
|
||||
import requests
|
||||
from moviepy.video.fx.resize import resize
|
||||
|
||||
@@ -35,6 +36,10 @@ class DallEAssetsEngine(BaseAssetsEngine):
|
||||
|
||||
def __init__(self, options: dict):
|
||||
self.aspect_ratio: Literal["portrait", "square", "landscape"] = options[0]
|
||||
api_key = self.retrieve_setting(identifier="openai_api_key")
|
||||
if not api_key:
|
||||
raise ValueError("OpenAI API key is not set.")
|
||||
self.client = OpenAI(api_key=api_key["api_key"])
|
||||
|
||||
super().__init__()
|
||||
|
||||
@@ -54,7 +59,7 @@ class DallEAssetsEngine(BaseAssetsEngine):
|
||||
else "1792x1024"
|
||||
)
|
||||
try:
|
||||
response = openai.images.generate(
|
||||
response = self.client.images.generate(
|
||||
model="dall-e-3",
|
||||
prompt=prompt,
|
||||
size=size,
|
||||
@@ -95,3 +100,21 @@ class DallEAssetsEngine(BaseAssetsEngine):
|
||||
value="square",
|
||||
)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_settings(cls):
|
||||
current_api_key: dict | list[dict] | None = cls.retrieve_setting(identifier="openai_api_key")
|
||||
current_api_key = current_api_key["api_key"] if current_api_key else ""
|
||||
api_key_input = gr.Textbox(
|
||||
label="OpenAI API Key",
|
||||
type="password",
|
||||
value=current_api_key,
|
||||
)
|
||||
save = gr.Button("Save")
|
||||
|
||||
def save_api_key(api_key: str):
|
||||
cls.store_setting(identifier="openai_api_key", data={"api_key": api_key})
|
||||
gr.Info("API key saved successfully.")
|
||||
return gr.update(value=api_key)
|
||||
|
||||
save.click(save_api_key, inputs=[api_key_input])
|
||||
|
||||
@@ -49,15 +49,15 @@ class VideoBackgroundEngine(BaseBackgroundEngine):
|
||||
clip = background.subclip(start, start + self.ctx.duration)
|
||||
w, h = clip.size
|
||||
self.ctx.credits += f"\n{self.background_video.data['credits']}"
|
||||
self.ctx.index_0.append(
|
||||
crop(
|
||||
clip,
|
||||
width=self.ctx.width,
|
||||
height=self.ctx.height,
|
||||
x_center=w / 2,
|
||||
y_center=h / 2,
|
||||
)
|
||||
)
|
||||
if w == h:
|
||||
clip = clip.resize(width=self.ctx.width) if w > h else clip.resize(height=self.ctx.height)
|
||||
elif w > h:
|
||||
clip = clip.resize(width=self.ctx.width)
|
||||
clip = crop(clip, width=self.ctx.width, height=self.ctx.height, x_center=w / 2, y_center=h / 2)
|
||||
else:
|
||||
clip = clip.resize(height=self.ctx.height)
|
||||
clip = crop(clip, width=self.ctx.width, height=self.ctx.height, x_center=w / 2, y_center=h / 2)
|
||||
self.ctx.index_0.append(clip)
|
||||
|
||||
@classmethod
|
||||
def get_settings(cls):
|
||||
|
||||
@@ -76,19 +76,21 @@ class BaseEngine(ABC):
|
||||
|
||||
# noinspection PyShadowingBuiltins
|
||||
@classmethod
|
||||
def store_setting(cls, *, type: str = None, data: dict):
|
||||
def store_setting(cls, *, identifier: str = None, type: str = None, data: dict):
|
||||
if not identifier and type:
|
||||
identifier = type
|
||||
with SessionLocal() as db:
|
||||
# check if setting exists
|
||||
# noinspection PyTypeChecker
|
||||
setting = db.execute(
|
||||
select(Setting).filter(
|
||||
Setting.provider == cls.name, Setting.type == type
|
||||
Setting.provider == cls.name, Setting.type == identifier
|
||||
)
|
||||
).scalar()
|
||||
if setting:
|
||||
setting.data = data
|
||||
else:
|
||||
db.add(Setting(provider=cls.name, type=type, data=data))
|
||||
db.add(Setting(provider=cls.name, type=identifier, data=data))
|
||||
db.commit()
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import gradio as gr
|
||||
import openai
|
||||
from openai import OpenAI
|
||||
import orjson
|
||||
|
||||
from .BaseLLMEngine import BaseLLMEngine
|
||||
@@ -17,6 +18,10 @@ class OpenaiLLMEngine(BaseLLMEngine):
|
||||
|
||||
def __init__(self, options: list) -> None:
|
||||
self.model = options[0]
|
||||
api_key = self.retrieve_setting(identifier="openai_api_key")
|
||||
if not api_key:
|
||||
raise ValueError("OpenAI API key is not set.")
|
||||
self.client = OpenAI(api_key=api_key["api_key"])
|
||||
super().__init__()
|
||||
|
||||
def generate(
|
||||
@@ -30,7 +35,7 @@ class OpenaiLLMEngine(BaseLLMEngine):
|
||||
frequency_penalty: float = 0,
|
||||
presence_penalty: float = 0,
|
||||
) -> str | dict:
|
||||
response = openai.chat.completions.create(
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
@@ -61,3 +66,21 @@ class OpenaiLLMEngine(BaseLLMEngine):
|
||||
value=OPENAI_POSSIBLE_MODELS[0],
|
||||
)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_settings(cls):
|
||||
current_api_key = cls.retrieve_setting(identifier="openai_api_key")
|
||||
current_api_key = current_api_key["api_key"] if current_api_key else ""
|
||||
api_key_input = gr.Textbox(
|
||||
label="OpenAI API Key",
|
||||
type="password",
|
||||
value=current_api_key,
|
||||
)
|
||||
save = gr.Button("Save")
|
||||
|
||||
def save_api_key(api_key: str):
|
||||
cls.store_setting(identifier="openai_api_key", data={"api_key": api_key})
|
||||
gr.Info("API key saved successfully.")
|
||||
return gr.update(value=api_key)
|
||||
|
||||
save.click(save_api_key, inputs=[api_key_input])
|
||||
|
||||
@@ -57,13 +57,14 @@ class GenerateUI:
|
||||
...
|
||||
returnable.extend(values)
|
||||
else:
|
||||
raise ValueError("Preset not found")
|
||||
return [gr.update(choices=list(current_presets.keys()), value=preset_name), *returnable]
|
||||
raise gr.Error(f"Preset {preset_name} does not exist.")
|
||||
gr.Info(f"Preset {preset_name} loaded successfully.")
|
||||
return [gr.Dropdown(choices=list(current_presets.keys()), value=preset_name), *returnable]
|
||||
return load_preset
|
||||
|
||||
def get_save_preset_func(self):
|
||||
def save_preset(preset_name, *selected_inputs) -> list[gr.update]:
|
||||
with open("local/presets.json", "r") as f:
|
||||
with open("local/presets.json", "rb") as f:
|
||||
current_presets = orjson.loads(f.read())
|
||||
returnable = []
|
||||
poppable_inputs = list(selected_inputs)
|
||||
@@ -85,17 +86,20 @@ class GenerateUI:
|
||||
with open("local/presets.json", "wb") as f:
|
||||
current_presets[preset_name] = new_preset
|
||||
f.write(orjson.dumps(current_presets))
|
||||
return [gr.update(choices=list(current_presets.keys()), value=preset_name), *returnable]
|
||||
gr.Info(f"Preset {preset_name} saved successfully.")
|
||||
return [gr.Dropdown(choices=list(current_presets.keys()), value=preset_name), *returnable]
|
||||
return save_preset
|
||||
|
||||
def get_delete_preset_func(self):
|
||||
def delete_preset(preset_name) -> list[gr.update]:
|
||||
with open("local/presets.json", "r") as f:
|
||||
current_presets = orjson.loads(f.read())
|
||||
if not current_presets.get(preset_name):
|
||||
raise ValueError("You cannot delete a non-existing preset.")
|
||||
current_presets.pop(preset_name)
|
||||
with open("local/presets.json", "wb") as f:
|
||||
f.write(orjson.dumps(current_presets))
|
||||
return [gr.update(choices=list(current_presets.keys()), value=None)]
|
||||
return gr.Dropdown(choices=list(current_presets.keys()), value=None)
|
||||
return delete_preset
|
||||
def get_ui(self):
|
||||
ui = gr.TabbedInterface(
|
||||
@@ -180,9 +184,9 @@ class GenerateUI:
|
||||
preset_dropdown = gr.Dropdown(
|
||||
choices=list(presets.keys()),
|
||||
show_label=False,
|
||||
label="dd",
|
||||
label="",
|
||||
allow_custom_value=True,
|
||||
value=None,
|
||||
value="",
|
||||
)
|
||||
load_preset_button = gr.Button("📂", size="sm", variant="primary")
|
||||
save_preset_button = gr.Button("💾", size="sm", variant="secondary")
|
||||
@@ -195,8 +199,8 @@ class GenerateUI:
|
||||
outputs=[preset_dropdown, *inputs])
|
||||
save_preset_button.click(save_preset, inputs=[preset_dropdown, *inputs],
|
||||
outputs=[preset_dropdown, *inputs])
|
||||
delete_preset_button.click(delete_preset, inputs=[preset_dropdown],
|
||||
outputs=[preset_dropdown])
|
||||
delete_preset_button.click(delete_preset, inputs=preset_dropdown,
|
||||
outputs=preset_dropdown)
|
||||
output_title = gr.Markdown(visible=True, render=False)
|
||||
output_description = gr.Markdown(visible=True, render=False)
|
||||
output_video = gr.Video(visible=True, render=False)
|
||||
|
||||
Reference in New Issue
Block a user