Files
viralfactory/src/engines/AssetsEngine/DallEAssetsEngine.py

117 lines
4.0 KiB
Python
Raw Normal View History

2024-02-23 09:50:43 +01:00
import os
2024-02-29 16:51:40 +01:00
from typing import Literal, TypedDict, List
2024-02-23 09:50:43 +01:00
2024-02-18 00:56:49 +01:00
import gradio as gr
2024-03-02 15:19:30 +01:00
import moviepy as mp
import moviepy.video.fx as vfx
2024-02-23 09:50:43 +01:00
import openai
2024-02-29 16:51:40 +01:00
from openai import OpenAI
2024-02-18 00:56:49 +01:00
import requests
from . import BaseAssetsEngine
2024-02-18 00:56:49 +01:00
class Spec(TypedDict):
    """A single image request handled by ``DallEAssetsEngine.generate``.

    The keys match the engine's ``specification`` mapping: a text prompt,
    the start/end times of the resulting clip, and the DALL-E 3 style.
    """

    # Text prompt forwarded to the image model.
    prompt: str
    # Start/end times of the clip; the clip lasts ``end - start``
    # (units are not fixed here — presumably seconds, confirm with callers).
    start: float
    end: float
    # Rendering style accepted by DALL-E 3.
    style: Literal["vivid", "natural"]
2024-02-18 00:56:49 +01:00
class DallEAssetsEngine(BaseAssetsEngine):
    """Assets engine that generates still images with OpenAI's DALL-E 3.

    Each spec becomes one ``mp.ImageClip`` that is shown from ``start`` to
    ``end``, resized to two thirds of the video width and anchored at the
    top center of the frame.
    """

    name = "DALL-E"
    description = "A powerful image generation model by OpenAI."
    spec_name = "dalle"
    spec_description = (
        "Use the dall-e 3 model to generate images from a detailed prompt."
    )

    specification = {
        "prompt": "A detailed prompt to generate the image from. Describe every subtle detail of the image you want to generate. [str]",
        "start": "The starting time of the video clip. [float]",
        "end": "The ending time of the video clip. [float]",
        "style": "The style of the generated images. Must be one of vivid or natural. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. [str]",
    }

    num_options = 1

    def __init__(self, options: list):
        """Create the engine.

        Args:
            options: Values of the components returned by ``get_options``;
                ``options[0]`` is the selected aspect ratio.

        Raises:
            ValueError: If no (or an empty) OpenAI API key has been saved.
        """
        self.aspect_ratio: Literal["portrait", "square", "landscape"] = options[0]

        api_key = self.retrieve_setting(identifier="openai_api_key")
        # A key saved as "" is stored as {"api_key": ""}; reject it here rather
        # than letting the first API request fail with an auth error.
        if not api_key or not api_key["api_key"]:
            raise ValueError("OpenAI API key is not set.")
        self.client = OpenAI(api_key=api_key["api_key"])

        super().__init__()

    @staticmethod
    def _image_clip_from_bytes(data: bytes) -> mp.ImageClip:
        """Build an ImageClip from encoded image bytes via a unique temp file."""
        import tempfile

        # A unique temp path avoids clobbering a shared "temp.png" in the
        # working directory, and ``finally`` removes it even if loading fails.
        fd, path = tempfile.mkstemp(suffix=".png")
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(data)
            return mp.ImageClip(path)
        finally:
            os.remove(path)

    def generate(self, options: list[Spec]) -> list[mp.ImageClip]:
        """Generate one positioned image clip per spec.

        Specs whose prompt is rejected for a content-policy violation are
        skipped, so the result may contain fewer clips than ``options``.

        Raises:
            openai.BadRequestError: For bad requests other than a
                content-policy violation.
            requests.HTTPError: If downloading a generated image fails.
        """
        max_width = self.ctx.width / 3 * 2
        # The aspect ratio is fixed for the engine instance, so map it onto one
        # of the three sizes DALL-E 3 accepts once, outside the loop.
        size: Literal["1024x1024", "1024x1792", "1792x1024"] = (
            "1024x1024"
            if self.aspect_ratio == "square"
            else "1024x1792"
            if self.aspect_ratio == "portrait"
            else "1792x1024"
        )

        clips = []
        for option in options:
            prompt = option["prompt"]
            start = option["start"]
            end = option["end"]
            style = option["style"]

            try:
                response = self.client.images.generate(
                    model="dall-e-3",
                    prompt=prompt,
                    size=size,
                    n=1,
                    style=style,
                    response_format="url",
                )
            except openai.BadRequestError as e:
                if e.code == "content_policy_violation":
                    # Skip prompts rejected by the content filter.
                    continue
                raise

            img_response = requests.get(response.data[0].url, timeout=60)
            # Fail loudly instead of saving an HTTP error page as an image.
            img_response.raise_for_status()

            img = self._image_clip_from_bytes(img_response.content)
            img = img.with_duration(end - start)
            img = img.with_start(start)
            img = img.with_effects([vfx.Resize(width=max_width)])
            # Pass the position tuple itself; previously a clip object was
            # passed here by mistake.
            img = img.with_position(("center", "top"))

            clips.append(img)
        return clips

    @classmethod
    def get_options(cls):
        """Return the Gradio components that configure this engine."""
        return [
            gr.Radio(
                ["portrait", "square", "landscape"],
                label="Aspect Ratio",
                value="square",
            )
        ]

    @classmethod
    def get_settings(cls):
        """Render the settings UI: an OpenAI API-key textbox and a Save button."""
        current_api_key: dict | list[dict] | None = cls.retrieve_setting(
            identifier="openai_api_key"
        )
        current_api_key = current_api_key["api_key"] if current_api_key else ""
        api_key_input = gr.Textbox(
            label="OpenAI API Key",
            type="password",
            value=current_api_key,
        )
        save = gr.Button("Save")

        def save_api_key(api_key: str):
            cls.store_setting(identifier="openai_api_key", data={"api_key": api_key})
            gr.Info("API key saved successfully.")
            return gr.update(value=api_key)

        # Wire the textbox as an output too, so the returned update is applied
        # instead of being discarded.
        save.click(save_api_key, inputs=[api_key_input], outputs=[api_key_input])