2024-02-23 09:50:43 +01:00
import os
2024-02-29 16:51:40 +01:00
from typing import Literal , TypedDict , List
2024-02-23 09:50:43 +01:00
2024-02-18 00:56:49 +01:00
import gradio as gr
import moviepy . editor as mp
2024-02-23 09:50:43 +01:00
import openai
2024-02-29 16:51:40 +01:00
from openai import OpenAI
2024-02-18 00:56:49 +01:00
import requests
2024-02-20 14:55:25 +01:00
from moviepy . video . fx . resize import resize
2024-02-18 00:56:49 +01:00
from . import BaseAssetsEngine
2024-02-20 14:55:25 +01:00
2024-02-18 00:56:49 +01:00
class Spec(TypedDict):
    """One image request handled by the DALL-E assets engine."""

    # Text prompt forwarded to the image model.
    prompt: str
    # Time at which the resulting clip starts.
    start: float
    # Time at which the resulting clip ends; the clip lasts ``end - start``.
    end: float
    # Rendering style passed through to the image API.
    style: Literal["vivid", "natural"]
2024-02-20 14:55:25 +01:00
2024-02-18 00:56:49 +01:00
class DallEAssetsEngine(BaseAssetsEngine):
    """Assets engine that turns prompts into image clips using DALL-E 3.

    For each spec the engine requests one image from the OpenAI images API,
    downloads it, and wraps it in a ``moviepy`` ``ImageClip`` that is timed,
    resized and positioned for compositing.
    """

    name = "DALL-E"
    description = "A powerful image generation model by OpenAI."
    spec_name = "dalle"
    spec_description = (
        "Use the dall-e 3 model to generate images from a detailed prompt."
    )
    specification = {
        "prompt": (
            "A detailed prompt to generate the image from. "
            "Describe every subtle detail of the image you want to generate. "
            "[str]"
        ),
        "start": "The starting time of the video clip. [float]",
        "end": "The ending time of the video clip. [float]",
        "style": (
            "The style of the generated images. "
            "Must be one of vivid or natural. "
            "Vivid causes the model to lean towards generating hyper-real "
            "and dramatic images. "
            "Natural causes the model to produce more natural, "
            "less hyper-real looking images. [str]"
        ),
    }
    num_options = 1

    # Image size requested from the API for each aspect-ratio option. Any
    # other value falls back to the landscape size (see ``_DEFAULT_SIZE``).
    _SIZE_BY_ASPECT_RATIO = {
        "square": "1024x1024",
        "portrait": "1024x1792",
        "landscape": "1792x1024",
    }
    _DEFAULT_SIZE = "1792x1024"
    # Seconds to wait for the generated image download before giving up.
    _DOWNLOAD_TIMEOUT = 60

    def __init__(self, options: dict):
        """Store the aspect ratio and build the OpenAI client.

        Args:
            options: Engine options; the first entry (``options[0]``) is the
                aspect ratio.
                NOTE(review): annotated as ``dict`` but indexed positionally;
                presumably a sequence of the values of the components
                returned by ``get_options`` - confirm against the base class.

        Raises:
            ValueError: If no OpenAI API key has been stored.
        """
        self.aspect_ratio: Literal["portrait", "square", "landscape"] = options[0]
        api_key = self.retrieve_setting(identifier="openai_api_key")
        if not api_key:
            raise ValueError("OpenAI API key is not set.")
        self.client = OpenAI(api_key=api_key["api_key"])
        super().__init__()

    @staticmethod
    def _download_image_clip(url: str) -> mp.ImageClip:
        """Download the image at ``url`` and return it as an ``ImageClip``."""
        import tempfile

        response = requests.get(url, timeout=DallEAssetsEngine._DOWNLOAD_TIMEOUT)
        # Fail loudly instead of writing an HTTP error body to disk as a PNG.
        response.raise_for_status()

        # A unique temp file avoids clobbering a shared "temp.png" when several
        # generations run at once; ``finally`` removes it even on failure.
        fd, path = tempfile.mkstemp(suffix=".png")
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(response.content)
            return mp.ImageClip(path)
        finally:
            os.remove(path)

    def generate(self, options: list[Spec]) -> list[mp.ImageClip]:
        """Generate one image clip per spec.

        Specs whose prompt is rejected with a ``content_policy_violation``
        error are skipped, so the result may be shorter than ``options``.

        Args:
            options: Specs with ``prompt``, ``start``, ``end`` and ``style``.

        Returns:
            Clips lasting ``end - start``, starting at ``start``, resized to
            two thirds of ``self.ctx.width`` and placed at the top centre.

        Raises:
            openai.BadRequestError: For any bad-request error other than a
                content policy violation.
            requests.RequestException: If downloading the image fails.
        """
        max_width = self.ctx.width / 3 * 2
        # The aspect ratio is fixed per engine instance, so resolve it once.
        size = self._SIZE_BY_ASPECT_RATIO.get(self.aspect_ratio, self._DEFAULT_SIZE)

        clips = []
        for option in options:
            prompt = option["prompt"]
            start = option["start"]
            end = option["end"]
            style = option["style"]
            try:
                response = self.client.images.generate(
                    model="dall-e-3",
                    prompt=prompt,
                    size=size,
                    n=1,
                    style=style,
                    response_format="url",
                )
            except openai.BadRequestError as e:
                if e.code == "content_policy_violation":
                    # we skip this prompt
                    continue
                raise

            img = self._download_image_clip(response.data[0].url)
            img = img.set_duration(end - start)
            img = img.set_start(start)
            img = resize(img, width=max_width)
            # Every aspect ratio uses the same placement.
            img = img.set_position(("center", "top"))
            clips.append(img)
        return clips

    @classmethod
    def get_options(cls):
        """Return the Gradio components for this engine's options."""
        return [
            gr.Radio(
                ["portrait", "square", "landscape"],
                label="Aspect Ratio",
                value="square",
            )
        ]

    @classmethod
    def get_settings(cls):
        """Build the Gradio settings UI for storing the OpenAI API key."""
        stored: dict | list[dict] | None = cls.retrieve_setting(
            identifier="openai_api_key"
        )
        current_api_key = stored["api_key"] if stored else ""
        api_key_input = gr.Textbox(
            label="OpenAI API Key",
            type="password",
            value=current_api_key,
        )
        save = gr.Button("Save")

        def save_api_key(api_key: str):
            cls.store_setting(identifier="openai_api_key", data={"api_key": api_key})
            gr.Info("API key saved successfully.")
            return gr.update(value=api_key)

        # Declare the textbox as output so the returned update is applied to it
        # rather than being returned to a handler with no outputs.
        save.click(save_api_key, inputs=[api_key_input], outputs=[api_key_input])