mirror of
https://github.com/Paillat-dev/viralfactory.git
synced 2026-01-02 09:16:19 +00:00
Refactor DallEAssetsEngine.py to improve image generation
This commit is contained in:
@@ -1,34 +1,37 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
import openai
|
import openai
|
||||||
import moviepy.editor as mp
|
import moviepy.editor as mp
|
||||||
from moviepy.video.fx.resize import resize
|
|
||||||
import io
|
import io
|
||||||
import base64
|
import base64
|
||||||
import time
|
import time
|
||||||
import requests
|
import requests
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from moviepy.video.fx.resize import resize
|
||||||
from typing import Literal, TypedDict
|
from typing import Literal, TypedDict
|
||||||
|
|
||||||
from . import BaseAssetsEngine
|
from . import BaseAssetsEngine
|
||||||
|
|
||||||
|
|
||||||
class Spec(TypedDict):
|
class Spec(TypedDict):
|
||||||
prompt: str
|
prompt: str
|
||||||
start: float
|
start: float
|
||||||
end: float
|
end: float
|
||||||
style: Literal["vivid", "natural"]
|
style: Literal["vivid", "natural"]
|
||||||
|
|
||||||
|
|
||||||
class DallEAssetsEngine(BaseAssetsEngine):
|
class DallEAssetsEngine(BaseAssetsEngine):
|
||||||
|
|
||||||
name = "DALL-E"
|
name = "DALL-E"
|
||||||
description = "A powerful image generation model by OpenAI."
|
description = "A powerful image generation model by OpenAI."
|
||||||
spec_name = "dalle"
|
spec_name = "dalle"
|
||||||
spec_description = "Use the dall-e 3 model to generate images from a detailed prompt."
|
spec_description = (
|
||||||
|
"Use the dall-e 3 model to generate images from a detailed prompt."
|
||||||
|
)
|
||||||
specification = {
|
specification = {
|
||||||
"prompt": "A detailed prompt to generate the image from. Describe every subtle detail of the image you want to generate. [str]",
|
"prompt": "A detailed prompt to generate the image from. Describe every subtle detail of the image you want to generate. [str]",
|
||||||
"start": "The starting time of the video clip. [float]",
|
"start": "The starting time of the video clip. [float]",
|
||||||
"end": "The ending time of the video clip. [float]",
|
"end": "The ending time of the video clip. [float]",
|
||||||
"style": "The style of the generated images. Must be one of vivid or natural. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. [str]"
|
"style": "The style of the generated images. Must be one of vivid or natural. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. [str]",
|
||||||
}
|
}
|
||||||
|
|
||||||
num_options = 1
|
num_options = 1
|
||||||
@@ -37,7 +40,7 @@ class DallEAssetsEngine(BaseAssetsEngine):
|
|||||||
self.aspect_ratio: Literal["portrait", "square", "landscape"] = options[0]
|
self.aspect_ratio: Literal["portrait", "square", "landscape"] = options[0]
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def get_assets(self, options: list[Spec]) -> list[mp.ImageClip]:
|
def get_assets(self, options: list[Spec]) -> list[mp.ImageClip]:
|
||||||
max_width = self.ctx.width / 3 * 2
|
max_width = self.ctx.width / 3 * 2
|
||||||
clips = []
|
clips = []
|
||||||
@@ -46,7 +49,13 @@ class DallEAssetsEngine(BaseAssetsEngine):
|
|||||||
start = option["start"]
|
start = option["start"]
|
||||||
end = option["end"]
|
end = option["end"]
|
||||||
style = option["style"]
|
style = option["style"]
|
||||||
size = "1024x1024" if self.aspect_ratio == "square" else "1024x1792" if self.aspect_ratio == "portrait" else "1792x1024"
|
size = (
|
||||||
|
"1024x1024"
|
||||||
|
if self.aspect_ratio == "square"
|
||||||
|
else "1024x1792"
|
||||||
|
if self.aspect_ratio == "portrait"
|
||||||
|
else "1792x1024"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
response = openai.images.generate(
|
response = openai.images.generate(
|
||||||
model="dall-e-3",
|
model="dall-e-3",
|
||||||
@@ -54,11 +63,11 @@ class DallEAssetsEngine(BaseAssetsEngine):
|
|||||||
size=size,
|
size=size,
|
||||||
n=1,
|
n=1,
|
||||||
style=style,
|
style=style,
|
||||||
response_format="url"
|
response_format="url",
|
||||||
)
|
)
|
||||||
except openai.BadRequestError as e:
|
except openai.BadRequestError as e:
|
||||||
if e.code == "content_policy_violation":
|
if e.code == "content_policy_violation":
|
||||||
#we skip this prompt
|
# we skip this prompt
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
@@ -74,14 +83,18 @@ class DallEAssetsEngine(BaseAssetsEngine):
|
|||||||
if self.aspect_ratio == "portrait":
|
if self.aspect_ratio == "portrait":
|
||||||
img: mp.ImageClip = img.set_position(("center", "top"))
|
img: mp.ImageClip = img.set_position(("center", "top"))
|
||||||
elif self.aspect_ratio == "landscape":
|
elif self.aspect_ratio == "landscape":
|
||||||
img = img.set_position(("center", "center"))
|
img: mp.ImageClip = img.set_position(("center", "top"))
|
||||||
elif self.aspect_ratio == "square":
|
elif self.aspect_ratio == "square":
|
||||||
img = img.set_position(("center", "center"))
|
img: mp.ImageClip = img.set_position(("center", "top"))
|
||||||
clips.append(img)
|
clips.append(img)
|
||||||
return clips
|
return clips
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_options(cls):
|
def get_options(cls):
|
||||||
return [
|
return [
|
||||||
gr.Radio(["portrait", "square", "landscape"], label="Aspect Ratio", value="square")
|
gr.Radio(
|
||||||
]
|
["portrait", "square", "landscape"],
|
||||||
|
label="Aspect Ratio",
|
||||||
|
value="square",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user