mirror of
https://github.com/Paillat-dev/viralfactory.git
synced 2026-01-02 01:06:19 +00:00
Some changes
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
import os
|
||||
from typing import Literal, TypedDict
|
||||
from typing import Literal, TypedDict, List
|
||||
|
||||
import gradio as gr
|
||||
import moviepy.editor as mp
|
||||
import openai
|
||||
from openai import OpenAI
|
||||
import requests
|
||||
from moviepy.video.fx.resize import resize
|
||||
|
||||
@@ -35,6 +36,10 @@ class DallEAssetsEngine(BaseAssetsEngine):
|
||||
|
||||
def __init__(self, options: dict):
|
||||
self.aspect_ratio: Literal["portrait", "square", "landscape"] = options[0]
|
||||
api_key = self.retrieve_setting(identifier="openai_api_key")
|
||||
if not api_key:
|
||||
raise ValueError("OpenAI API key is not set.")
|
||||
self.client = OpenAI(api_key=api_key["api_key"])
|
||||
|
||||
super().__init__()
|
||||
|
||||
@@ -54,7 +59,7 @@ class DallEAssetsEngine(BaseAssetsEngine):
|
||||
else "1792x1024"
|
||||
)
|
||||
try:
|
||||
response = openai.images.generate(
|
||||
response = self.client.images.generate(
|
||||
model="dall-e-3",
|
||||
prompt=prompt,
|
||||
size=size,
|
||||
@@ -95,3 +100,21 @@ class DallEAssetsEngine(BaseAssetsEngine):
|
||||
value="square",
|
||||
)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_settings(cls):
|
||||
current_api_key: dict | list[dict] | None = cls.retrieve_setting(identifier="openai_api_key")
|
||||
current_api_key = current_api_key["api_key"] if current_api_key else ""
|
||||
api_key_input = gr.Textbox(
|
||||
label="OpenAI API Key",
|
||||
type="password",
|
||||
value=current_api_key,
|
||||
)
|
||||
save = gr.Button("Save")
|
||||
|
||||
def save_api_key(api_key: str):
|
||||
cls.store_setting(identifier="openai_api_key", data={"api_key": api_key})
|
||||
gr.Info("API key saved successfully.")
|
||||
return gr.update(value=api_key)
|
||||
|
||||
save.click(save_api_key, inputs=[api_key_input])
|
||||
|
||||
@@ -49,15 +49,15 @@ class VideoBackgroundEngine(BaseBackgroundEngine):
|
||||
clip = background.subclip(start, start + self.ctx.duration)
|
||||
w, h = clip.size
|
||||
self.ctx.credits += f"\n{self.background_video.data['credits']}"
|
||||
self.ctx.index_0.append(
|
||||
crop(
|
||||
clip,
|
||||
width=self.ctx.width,
|
||||
height=self.ctx.height,
|
||||
x_center=w / 2,
|
||||
y_center=h / 2,
|
||||
)
|
||||
)
|
||||
if w == h:
|
||||
clip = clip.resize(width=self.ctx.width) if w > h else clip.resize(height=self.ctx.height)
|
||||
elif w > h:
|
||||
clip = clip.resize(width=self.ctx.width)
|
||||
clip = crop(clip, width=self.ctx.width, height=self.ctx.height, x_center=w / 2, y_center=h / 2)
|
||||
else:
|
||||
clip = clip.resize(height=self.ctx.height)
|
||||
clip = crop(clip, width=self.ctx.width, height=self.ctx.height, x_center=w / 2, y_center=h / 2)
|
||||
self.ctx.index_0.append(clip)
|
||||
|
||||
@classmethod
|
||||
def get_settings(cls):
|
||||
|
||||
@@ -76,19 +76,21 @@ class BaseEngine(ABC):
|
||||
|
||||
# noinspection PyShadowingBuiltins
|
||||
@classmethod
|
||||
def store_setting(cls, *, type: str = None, data: dict):
|
||||
def store_setting(cls, *, identifier: str = None, type: str = None, data: dict):
|
||||
if not identifier and type:
|
||||
identifier = type
|
||||
with SessionLocal() as db:
|
||||
# check if setting exists
|
||||
# noinspection PyTypeChecker
|
||||
setting = db.execute(
|
||||
select(Setting).filter(
|
||||
Setting.provider == cls.name, Setting.type == type
|
||||
Setting.provider == cls.name, Setting.type == identifier
|
||||
)
|
||||
).scalar()
|
||||
if setting:
|
||||
setting.data = data
|
||||
else:
|
||||
db.add(Setting(provider=cls.name, type=type, data=data))
|
||||
db.add(Setting(provider=cls.name, type=identifier, data=data))
|
||||
db.commit()
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import gradio as gr
|
||||
import openai
|
||||
from openai import OpenAI
|
||||
import orjson
|
||||
|
||||
from .BaseLLMEngine import BaseLLMEngine
|
||||
@@ -17,6 +18,10 @@ class OpenaiLLMEngine(BaseLLMEngine):
|
||||
|
||||
def __init__(self, options: list) -> None:
|
||||
self.model = options[0]
|
||||
api_key = self.retrieve_setting(identifier="openai_api_key")
|
||||
if not api_key:
|
||||
raise ValueError("OpenAI API key is not set.")
|
||||
self.client = OpenAI(api_key=api_key["api_key"])
|
||||
super().__init__()
|
||||
|
||||
def generate(
|
||||
@@ -30,7 +35,7 @@ class OpenaiLLMEngine(BaseLLMEngine):
|
||||
frequency_penalty: float = 0,
|
||||
presence_penalty: float = 0,
|
||||
) -> str | dict:
|
||||
response = openai.chat.completions.create(
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
@@ -61,3 +66,21 @@ class OpenaiLLMEngine(BaseLLMEngine):
|
||||
value=OPENAI_POSSIBLE_MODELS[0],
|
||||
)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_settings(cls):
|
||||
current_api_key = cls.retrieve_setting(identifier="openai_api_key")
|
||||
current_api_key = current_api_key["api_key"] if current_api_key else ""
|
||||
api_key_input = gr.Textbox(
|
||||
label="OpenAI API Key",
|
||||
type="password",
|
||||
value=current_api_key,
|
||||
)
|
||||
save = gr.Button("Save")
|
||||
|
||||
def save_api_key(api_key: str):
|
||||
cls.store_setting(identifier="openai_api_key", data={"api_key": api_key})
|
||||
gr.Info("API key saved successfully.")
|
||||
return gr.update(value=api_key)
|
||||
|
||||
save.click(save_api_key, inputs=[api_key_input])
|
||||
|
||||
Reference in New Issue
Block a user