mirror of
https://github.com/Paillat-dev/viralfactory.git
synced 2026-01-02 09:16:19 +00:00
Some changes
This commit is contained in:
@@ -1,9 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Literal, TypedDict
|
from typing import Literal, TypedDict, List
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import moviepy.editor as mp
|
import moviepy.editor as mp
|
||||||
import openai
|
import openai
|
||||||
|
from openai import OpenAI
|
||||||
import requests
|
import requests
|
||||||
from moviepy.video.fx.resize import resize
|
from moviepy.video.fx.resize import resize
|
||||||
|
|
||||||
@@ -35,6 +36,10 @@ class DallEAssetsEngine(BaseAssetsEngine):
|
|||||||
|
|
||||||
def __init__(self, options: dict):
|
def __init__(self, options: dict):
|
||||||
self.aspect_ratio: Literal["portrait", "square", "landscape"] = options[0]
|
self.aspect_ratio: Literal["portrait", "square", "landscape"] = options[0]
|
||||||
|
api_key = self.retrieve_setting(identifier="openai_api_key")
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError("OpenAI API key is not set.")
|
||||||
|
self.client = OpenAI(api_key=api_key["api_key"])
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -54,7 +59,7 @@ class DallEAssetsEngine(BaseAssetsEngine):
|
|||||||
else "1792x1024"
|
else "1792x1024"
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
response = openai.images.generate(
|
response = self.client.images.generate(
|
||||||
model="dall-e-3",
|
model="dall-e-3",
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
size=size,
|
size=size,
|
||||||
@@ -95,3 +100,21 @@ class DallEAssetsEngine(BaseAssetsEngine):
|
|||||||
value="square",
|
value="square",
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_settings(cls):
|
||||||
|
current_api_key: dict | list[dict] | None = cls.retrieve_setting(identifier="openai_api_key")
|
||||||
|
current_api_key = current_api_key["api_key"] if current_api_key else ""
|
||||||
|
api_key_input = gr.Textbox(
|
||||||
|
label="OpenAI API Key",
|
||||||
|
type="password",
|
||||||
|
value=current_api_key,
|
||||||
|
)
|
||||||
|
save = gr.Button("Save")
|
||||||
|
|
||||||
|
def save_api_key(api_key: str):
|
||||||
|
cls.store_setting(identifier="openai_api_key", data={"api_key": api_key})
|
||||||
|
gr.Info("API key saved successfully.")
|
||||||
|
return gr.update(value=api_key)
|
||||||
|
|
||||||
|
save.click(save_api_key, inputs=[api_key_input])
|
||||||
|
|||||||
@@ -49,15 +49,15 @@ class VideoBackgroundEngine(BaseBackgroundEngine):
|
|||||||
clip = background.subclip(start, start + self.ctx.duration)
|
clip = background.subclip(start, start + self.ctx.duration)
|
||||||
w, h = clip.size
|
w, h = clip.size
|
||||||
self.ctx.credits += f"\n{self.background_video.data['credits']}"
|
self.ctx.credits += f"\n{self.background_video.data['credits']}"
|
||||||
self.ctx.index_0.append(
|
if w == h:
|
||||||
crop(
|
clip = clip.resize(width=self.ctx.width) if w > h else clip.resize(height=self.ctx.height)
|
||||||
clip,
|
elif w > h:
|
||||||
width=self.ctx.width,
|
clip = clip.resize(width=self.ctx.width)
|
||||||
height=self.ctx.height,
|
clip = crop(clip, width=self.ctx.width, height=self.ctx.height, x_center=w / 2, y_center=h / 2)
|
||||||
x_center=w / 2,
|
else:
|
||||||
y_center=h / 2,
|
clip = clip.resize(height=self.ctx.height)
|
||||||
)
|
clip = crop(clip, width=self.ctx.width, height=self.ctx.height, x_center=w / 2, y_center=h / 2)
|
||||||
)
|
self.ctx.index_0.append(clip)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_settings(cls):
|
def get_settings(cls):
|
||||||
|
|||||||
@@ -76,19 +76,21 @@ class BaseEngine(ABC):
|
|||||||
|
|
||||||
# noinspection PyShadowingBuiltins
|
# noinspection PyShadowingBuiltins
|
||||||
@classmethod
|
@classmethod
|
||||||
def store_setting(cls, *, type: str = None, data: dict):
|
def store_setting(cls, *, identifier: str = None, type: str = None, data: dict):
|
||||||
|
if not identifier and type:
|
||||||
|
identifier = type
|
||||||
with SessionLocal() as db:
|
with SessionLocal() as db:
|
||||||
# check if setting exists
|
# check if setting exists
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
setting = db.execute(
|
setting = db.execute(
|
||||||
select(Setting).filter(
|
select(Setting).filter(
|
||||||
Setting.provider == cls.name, Setting.type == type
|
Setting.provider == cls.name, Setting.type == identifier
|
||||||
)
|
)
|
||||||
).scalar()
|
).scalar()
|
||||||
if setting:
|
if setting:
|
||||||
setting.data = data
|
setting.data = data
|
||||||
else:
|
else:
|
||||||
db.add(Setting(provider=cls.name, type=type, data=data))
|
db.add(Setting(provider=cls.name, type=identifier, data=data))
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
import openai
|
import openai
|
||||||
|
from openai import OpenAI
|
||||||
import orjson
|
import orjson
|
||||||
|
|
||||||
from .BaseLLMEngine import BaseLLMEngine
|
from .BaseLLMEngine import BaseLLMEngine
|
||||||
@@ -17,6 +18,10 @@ class OpenaiLLMEngine(BaseLLMEngine):
|
|||||||
|
|
||||||
def __init__(self, options: list) -> None:
|
def __init__(self, options: list) -> None:
|
||||||
self.model = options[0]
|
self.model = options[0]
|
||||||
|
api_key = self.retrieve_setting(identifier="openai_api_key")
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError("OpenAI API key is not set.")
|
||||||
|
self.client = OpenAI(api_key=api_key["api_key"])
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
@@ -30,7 +35,7 @@ class OpenaiLLMEngine(BaseLLMEngine):
|
|||||||
frequency_penalty: float = 0,
|
frequency_penalty: float = 0,
|
||||||
presence_penalty: float = 0,
|
presence_penalty: float = 0,
|
||||||
) -> str | dict:
|
) -> str | dict:
|
||||||
response = openai.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": system_prompt},
|
{"role": "system", "content": system_prompt},
|
||||||
@@ -61,3 +66,21 @@ class OpenaiLLMEngine(BaseLLMEngine):
|
|||||||
value=OPENAI_POSSIBLE_MODELS[0],
|
value=OPENAI_POSSIBLE_MODELS[0],
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_settings(cls):
|
||||||
|
current_api_key = cls.retrieve_setting(identifier="openai_api_key")
|
||||||
|
current_api_key = current_api_key["api_key"] if current_api_key else ""
|
||||||
|
api_key_input = gr.Textbox(
|
||||||
|
label="OpenAI API Key",
|
||||||
|
type="password",
|
||||||
|
value=current_api_key,
|
||||||
|
)
|
||||||
|
save = gr.Button("Save")
|
||||||
|
|
||||||
|
def save_api_key(api_key: str):
|
||||||
|
cls.store_setting(identifier="openai_api_key", data={"api_key": api_key})
|
||||||
|
gr.Info("API key saved successfully.")
|
||||||
|
return gr.update(value=api_key)
|
||||||
|
|
||||||
|
save.click(save_api_key, inputs=[api_key_input])
|
||||||
|
|||||||
@@ -57,13 +57,14 @@ class GenerateUI:
|
|||||||
...
|
...
|
||||||
returnable.extend(values)
|
returnable.extend(values)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Preset not found")
|
raise gr.Error(f"Preset {preset_name} does not exist.")
|
||||||
return [gr.update(choices=list(current_presets.keys()), value=preset_name), *returnable]
|
gr.Info(f"Preset {preset_name} loaded successfully.")
|
||||||
|
return [gr.Dropdown(choices=list(current_presets.keys()), value=preset_name), *returnable]
|
||||||
return load_preset
|
return load_preset
|
||||||
|
|
||||||
def get_save_preset_func(self):
|
def get_save_preset_func(self):
|
||||||
def save_preset(preset_name, *selected_inputs) -> list[gr.update]:
|
def save_preset(preset_name, *selected_inputs) -> list[gr.update]:
|
||||||
with open("local/presets.json", "r") as f:
|
with open("local/presets.json", "rb") as f:
|
||||||
current_presets = orjson.loads(f.read())
|
current_presets = orjson.loads(f.read())
|
||||||
returnable = []
|
returnable = []
|
||||||
poppable_inputs = list(selected_inputs)
|
poppable_inputs = list(selected_inputs)
|
||||||
@@ -85,17 +86,20 @@ class GenerateUI:
|
|||||||
with open("local/presets.json", "wb") as f:
|
with open("local/presets.json", "wb") as f:
|
||||||
current_presets[preset_name] = new_preset
|
current_presets[preset_name] = new_preset
|
||||||
f.write(orjson.dumps(current_presets))
|
f.write(orjson.dumps(current_presets))
|
||||||
return [gr.update(choices=list(current_presets.keys()), value=preset_name), *returnable]
|
gr.Info(f"Preset {preset_name} saved successfully.")
|
||||||
|
return [gr.Dropdown(choices=list(current_presets.keys()), value=preset_name), *returnable]
|
||||||
return save_preset
|
return save_preset
|
||||||
|
|
||||||
def get_delete_preset_func(self):
|
def get_delete_preset_func(self):
|
||||||
def delete_preset(preset_name) -> list[gr.update]:
|
def delete_preset(preset_name) -> list[gr.update]:
|
||||||
with open("local/presets.json", "r") as f:
|
with open("local/presets.json", "r") as f:
|
||||||
current_presets = orjson.loads(f.read())
|
current_presets = orjson.loads(f.read())
|
||||||
|
if not current_presets.get(preset_name):
|
||||||
|
raise ValueError("You cannot delete a non-existing preset.")
|
||||||
current_presets.pop(preset_name)
|
current_presets.pop(preset_name)
|
||||||
with open("local/presets.json", "wb") as f:
|
with open("local/presets.json", "wb") as f:
|
||||||
f.write(orjson.dumps(current_presets))
|
f.write(orjson.dumps(current_presets))
|
||||||
return [gr.update(choices=list(current_presets.keys()), value=None)]
|
return gr.Dropdown(choices=list(current_presets.keys()), value=None)
|
||||||
return delete_preset
|
return delete_preset
|
||||||
def get_ui(self):
|
def get_ui(self):
|
||||||
ui = gr.TabbedInterface(
|
ui = gr.TabbedInterface(
|
||||||
@@ -180,9 +184,9 @@ class GenerateUI:
|
|||||||
preset_dropdown = gr.Dropdown(
|
preset_dropdown = gr.Dropdown(
|
||||||
choices=list(presets.keys()),
|
choices=list(presets.keys()),
|
||||||
show_label=False,
|
show_label=False,
|
||||||
label="dd",
|
label="",
|
||||||
allow_custom_value=True,
|
allow_custom_value=True,
|
||||||
value=None,
|
value="",
|
||||||
)
|
)
|
||||||
load_preset_button = gr.Button("📂", size="sm", variant="primary")
|
load_preset_button = gr.Button("📂", size="sm", variant="primary")
|
||||||
save_preset_button = gr.Button("💾", size="sm", variant="secondary")
|
save_preset_button = gr.Button("💾", size="sm", variant="secondary")
|
||||||
@@ -195,8 +199,8 @@ class GenerateUI:
|
|||||||
outputs=[preset_dropdown, *inputs])
|
outputs=[preset_dropdown, *inputs])
|
||||||
save_preset_button.click(save_preset, inputs=[preset_dropdown, *inputs],
|
save_preset_button.click(save_preset, inputs=[preset_dropdown, *inputs],
|
||||||
outputs=[preset_dropdown, *inputs])
|
outputs=[preset_dropdown, *inputs])
|
||||||
delete_preset_button.click(delete_preset, inputs=[preset_dropdown],
|
delete_preset_button.click(delete_preset, inputs=preset_dropdown,
|
||||||
outputs=[preset_dropdown])
|
outputs=preset_dropdown)
|
||||||
output_title = gr.Markdown(visible=True, render=False)
|
output_title = gr.Markdown(visible=True, render=False)
|
||||||
output_description = gr.Markdown(visible=True, render=False)
|
output_description = gr.Markdown(visible=True, render=False)
|
||||||
output_video = gr.Video(visible=True, render=False)
|
output_video = gr.Video(visible=True, render=False)
|
||||||
|
|||||||
Reference in New Issue
Block a user