Refactor DallEAssetsEngine.py to improve image generation

commit e6d58d2920
parent 6338f016d2
date   2024-02-20 14:55:25 +01:00


@@ -1,34 +1,37 @@
 import gradio as gr
 import openai
 import moviepy.editor as mp
+from moviepy.video.fx.resize import resize
 import io
 import base64
 import time
 import requests
 import os
-from moviepy.video.fx.resize import resize
 from typing import Literal, TypedDict
 from . import BaseAssetsEngine
 
 
 class Spec(TypedDict):
     prompt: str
     start: float
     end: float
     style: Literal["vivid", "natural"]
 
 
 class DallEAssetsEngine(BaseAssetsEngine):
     name = "DALL-E"
     description = "A powerful image generation model by OpenAI."
     spec_name = "dalle"
-    spec_description = "Use the dall-e 3 model to generate images from a detailed prompt."
+    spec_description = (
+        "Use the dall-e 3 model to generate images from a detailed prompt."
+    )
     specification = {
         "prompt": "A detailed prompt to generate the image from. Describe every subtle detail of the image you want to generate. [str]",
         "start": "The starting time of the video clip. [float]",
         "end": "The ending time of the video clip. [float]",
-        "style": "The style of the generated images. Must be one of vivid or natural. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. [str]"
+        "style": "The style of the generated images. Must be one of vivid or natural. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. [str]",
     }
     num_options = 1
@@ -37,7 +40,7 @@ class DallEAssetsEngine(BaseAssetsEngine):
         self.aspect_ratio: Literal["portrait", "square", "landscape"] = options[0]
         super().__init__()
 
     def get_assets(self, options: list[Spec]) -> list[mp.ImageClip]:
         max_width = self.ctx.width / 3 * 2
         clips = []
 
@@ -46,7 +49,13 @@ class DallEAssetsEngine(BaseAssetsEngine):
             start = option["start"]
             end = option["end"]
             style = option["style"]
-            size = "1024x1024" if self.aspect_ratio == "square" else "1024x1792" if self.aspect_ratio == "portrait" else "1792x1024"
+            size = (
+                "1024x1024"
+                if self.aspect_ratio == "square"
+                else "1024x1792"
+                if self.aspect_ratio == "portrait"
+                else "1792x1024"
+            )
             try:
                 response = openai.images.generate(
                     model="dall-e-3",
@@ -54,11 +63,11 @@ class DallEAssetsEngine(BaseAssetsEngine):
                     size=size,
                     n=1,
                     style=style,
-                    response_format="url"
+                    response_format="url",
                 )
             except openai.BadRequestError as e:
                 if e.code == "content_policy_violation":
-                    #we skip this prompt
+                    # we skip this prompt
                     continue
                 else:
                     raise
@@ -74,14 +83,18 @@ class DallEAssetsEngine(BaseAssetsEngine):
             if self.aspect_ratio == "portrait":
                 img: mp.ImageClip = img.set_position(("center", "top"))
             elif self.aspect_ratio == "landscape":
-                img = img.set_position(("center", "center"))
+                img: mp.ImageClip = img.set_position(("center", "top"))
             elif self.aspect_ratio == "square":
-                img = img.set_position(("center", "center"))
+                img: mp.ImageClip = img.set_position(("center", "top"))
             clips.append(img)
         return clips
 
     @classmethod
     def get_options(cls):
         return [
-            gr.Radio(["portrait", "square", "landscape"], label="Aspect Ratio", value="square")
+            gr.Radio(
+                ["portrait", "square", "landscape"],
+                label="Aspect Ratio",
+                value="square",
+            )
         ]
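
For reference, below is a minimal, self-contained sketch of the generation flow this commit reformats: map the configured aspect ratio to a DALL-E 3 size, request a single image via openai.images.generate, skip prompts that hit a content-policy rejection, and wrap the result in a positioned moviepy ImageClip. The size mapping, the API call, and the BadRequestError handling mirror the diff; the generate_clip helper name, the requests/Pillow download-and-decode step, and the explicit max_width argument are illustrative assumptions, since the diff does not show how DallEAssetsEngine turns the returned URL into a clip.

import io
from typing import Literal, Optional

import moviepy.editor as mp
import numpy as np
import openai
import requests
from moviepy.video.fx.resize import resize
from PIL import Image


def generate_clip(
    prompt: str,
    start: float,
    end: float,
    style: Literal["vivid", "natural"],
    aspect_ratio: Literal["portrait", "square", "landscape"],
    max_width: float,
) -> Optional[mp.ImageClip]:
    # Same size mapping as the engine: square -> 1024x1024,
    # portrait -> 1024x1792, landscape -> 1792x1024.
    size = (
        "1024x1024"
        if aspect_ratio == "square"
        else "1024x1792"
        if aspect_ratio == "portrait"
        else "1792x1024"
    )
    try:
        response = openai.images.generate(
            model="dall-e-3",
            prompt=prompt,
            size=size,
            n=1,
            style=style,
            response_format="url",
        )
    except openai.BadRequestError as e:
        if e.code == "content_policy_violation":
            # Mirror the engine: prompts rejected by the content filter are skipped.
            return None
        raise

    # Assumed download/decode path (not shown in the diff): fetch the URL and
    # hand moviepy a numpy frame.
    image_bytes = requests.get(response.data[0].url, timeout=60).content
    frame = np.array(Image.open(io.BytesIO(image_bytes)))

    clip = mp.ImageClip(frame).set_start(start).set_end(end)
    # Cap the width, analogous to max_width = self.ctx.width / 3 * 2 in get_assets.
    clip = resize(clip, width=max_width)
    return clip.set_position(("center", "top"))

A caller would loop over the Spec entries, collect the non-None clips, and composite them, much as get_assets builds and returns its clips list.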