Refactor DallEAssetsEngine.py to improve image generation

This commit is contained in:
2024-02-20 14:55:25 +01:00
parent 6338f016d2
commit e6d58d2920

View File

@@ -1,34 +1,37 @@
import gradio as gr
import openai
import moviepy.editor as mp
from moviepy.video.fx.resize import resize
import io
import base64
import time
import requests
import os
from moviepy.video.fx.resize import resize
from typing import Literal, TypedDict
from . import BaseAssetsEngine
class Spec(TypedDict):
prompt: str
start: float
end: float
style: Literal["vivid", "natural"]
class DallEAssetsEngine(BaseAssetsEngine):
name = "DALL-E"
description = "A powerful image generation model by OpenAI."
spec_name = "dalle"
spec_description = "Use the dall-e 3 model to generate images from a detailed prompt."
spec_description = (
"Use the dall-e 3 model to generate images from a detailed prompt."
)
specification = {
"prompt": "A detailed prompt to generate the image from. Describe every subtle detail of the image you want to generate. [str]",
"start": "The starting time of the video clip. [float]",
"end": "The ending time of the video clip. [float]",
"style": "The style of the generated images. Must be one of vivid or natural. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. [str]"
"style": "The style of the generated images. Must be one of vivid or natural. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. [str]",
}
num_options = 1
@@ -37,7 +40,7 @@ class DallEAssetsEngine(BaseAssetsEngine):
self.aspect_ratio: Literal["portrait", "square", "landscape"] = options[0]
super().__init__()
def get_assets(self, options: list[Spec]) -> list[mp.ImageClip]:
max_width = self.ctx.width / 3 * 2
clips = []
@@ -46,7 +49,13 @@ class DallEAssetsEngine(BaseAssetsEngine):
start = option["start"]
end = option["end"]
style = option["style"]
size = "1024x1024" if self.aspect_ratio == "square" else "1024x1792" if self.aspect_ratio == "portrait" else "1792x1024"
size = (
"1024x1024"
if self.aspect_ratio == "square"
else "1024x1792"
if self.aspect_ratio == "portrait"
else "1792x1024"
)
try:
response = openai.images.generate(
model="dall-e-3",
@@ -54,11 +63,11 @@ class DallEAssetsEngine(BaseAssetsEngine):
size=size,
n=1,
style=style,
response_format="url"
response_format="url",
)
except openai.BadRequestError as e:
if e.code == "content_policy_violation":
#we skip this prompt
# we skip this prompt
continue
else:
raise
@@ -74,14 +83,18 @@ class DallEAssetsEngine(BaseAssetsEngine):
if self.aspect_ratio == "portrait":
img: mp.ImageClip = img.set_position(("center", "top"))
elif self.aspect_ratio == "landscape":
img = img.set_position(("center", "center"))
img: mp.ImageClip = img.set_position(("center", "top"))
elif self.aspect_ratio == "square":
img = img.set_position(("center", "center"))
img: mp.ImageClip = img.set_position(("center", "top"))
clips.append(img)
return clips
@classmethod
def get_options(cls):
return [
gr.Radio(["portrait", "square", "landscape"], label="Aspect Ratio", value="square")
]
gr.Radio(
["portrait", "square", "landscape"],
label="Aspect Ratio",
value="square",
)
]