2024-02-23 09:50:43 +01:00
import os
from typing import Literal , TypedDict
2024-02-18 00:56:49 +01:00
import gradio as gr
import moviepy . editor as mp
2024-02-23 09:50:43 +01:00
import openai
2024-02-18 00:56:49 +01:00
import requests
2024-02-20 14:55:25 +01:00
from moviepy . video . fx . resize import resize
2024-02-18 00:56:49 +01:00
from . import BaseAssetsEngine
2024-02-20 14:55:25 +01:00
2024-02-18 00:56:49 +01:00
class Spec(TypedDict):
    """One image request handled by the DALL-E assets engine."""

    # Text prompt the image is generated from.
    prompt: str
    # Time at which the resulting clip starts.
    start: float
    # Time at which the resulting clip ends; the clip lasts ``end - start``.
    end: float
    # Rendering style forwarded to the image generation call.
    style: Literal["vivid", "natural"]
2024-02-20 14:55:25 +01:00
2024-02-18 00:56:49 +01:00
class DallEAssetsEngine(BaseAssetsEngine):
    """Assets engine that turns prompts into timed image clips via DALL-E 3.

    Each spec entry is sent to ``openai.images.generate``; the returned image
    URL is downloaded and wrapped in a ``moviepy`` ``ImageClip`` that is timed,
    resized and positioned for overlaying on the video.
    """

    name = "DALL-E"
    description = "A powerful image generation model by OpenAI."
    spec_name = "dalle"
    spec_description = (
        "Use the dall-e 3 model to generate images from a detailed prompt."
    )
    specification = {
        "prompt": "A detailed prompt to generate the image from. Describe every subtle detail of the image you want to generate. [str]",
        "start": "The starting time of the video clip. [float]",
        "end": "The ending time of the video clip. [float]",
        "style": "The style of the generated images. Must be one of vivid or natural. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. [str]",
    }

    num_options = 1

    # Image size requested for each aspect ratio; any other value falls back
    # to the landscape size, as the original conditional expression did.
    _SIZES = {
        "square": "1024x1024",
        "portrait": "1024x1792",
    }
    _DEFAULT_SIZE = "1792x1024"

    # Seconds to wait for the generated image download before giving up.
    _DOWNLOAD_TIMEOUT = 60

    def __init__(self, options: list):
        """Store the selected aspect ratio, then initialise the base engine.

        Args:
            options: Indexable collection of option values; only the first
                entry (the aspect ratio) is read. Presumably these are the
                values of the components returned by ``get_options`` —
                confirm against the caller.
        """
        self.aspect_ratio: Literal["portrait", "square", "landscape"] = options[0]
        super().__init__()

    def get_assets(self, options: list[Spec]) -> list[mp.ImageClip]:
        """Generate one image clip per spec entry.

        Args:
            options: Specs giving the prompt, start/end times and style of
                each image to generate.

        Returns:
            The generated clips, in input order. Prompts rejected by the
            content policy are skipped, so the result may be shorter than
            ``options``.

        Raises:
            openai.BadRequestError: For any bad request other than a content
                policy violation.
            requests.RequestException: If downloading a generated image fails
                or times out.
        """
        # Images are scaled to two thirds of the video width.
        max_width = self.ctx.width / 3 * 2
        # The size depends only on the engine configuration, so pick it once.
        size = self._SIZES.get(self.aspect_ratio, self._DEFAULT_SIZE)

        clips = []
        for option in options:
            prompt = option["prompt"]
            start = option["start"]
            end = option["end"]
            style = option["style"]
            try:
                response = openai.images.generate(
                    model="dall-e-3",
                    prompt=prompt,
                    size=size,
                    n=1,
                    style=style,
                    response_format="url",
                )
            except openai.BadRequestError as e:
                if e.code == "content_policy_violation":
                    # we skip this prompt
                    continue
                raise

            img = self._download_image_clip(response.data[0].url)
            img = img.set_duration(end - start)
            img = img.set_start(start)
            img = resize(img, width=max_width)
            # Every supported aspect ratio used the same placement.
            img = img.set_position(("center", "top"))
            clips.append(img)
        return clips

    def _download_image_clip(self, url: str) -> mp.ImageClip:
        """Download the image at ``url`` and load it as an ``ImageClip``."""
        # Imported locally so this block does not depend on a file-level import.
        import tempfile

        reply = requests.get(url, timeout=self._DOWNLOAD_TIMEOUT)
        # Fail here rather than writing an HTTP error body out as a PNG.
        reply.raise_for_status()

        # A unique temp file avoids collisions between concurrent runs that a
        # fixed "temp.png" in the working directory would cause.
        fd, path = tempfile.mkstemp(suffix=".png")
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(reply.content)
            # The file is deleted right after the clip is constructed, exactly
            # as before, so the clip must not need the file afterwards.
            return mp.ImageClip(path)
        finally:
            os.remove(path)

    @classmethod
    def get_options(cls):
        """Return the Gradio components used to configure this engine."""
        return [
            gr.Radio(
                ["portrait", "square", "landscape"],
                label="Aspect Ratio",
                value="square",
            )
        ]