Formatting

This commit is contained in:
2024-02-23 13:12:48 +01:00
parent 03dd641c70
commit 9f91132a0a
23 changed files with 39 additions and 63 deletions

View File

@@ -2,12 +2,10 @@ import os
import time import time
from datetime import datetime from datetime import datetime
import gradio as gr
import moviepy.editor as mp import moviepy.editor as mp
from .. import engines from .. import engines
from ..models import Video, SessionLocal from ..models import Video, SessionLocal
from ..utils.prompting import get_prompt
class GenerationContext: class GenerationContext:

View File

@@ -33,5 +33,5 @@ class AssetsEngineSelector:
assets_opts = [ assets_opts = [
asset["args"] for asset in assets if asset["engine"] == engine.name asset["args"] for asset in assets if asset["engine"] == engine.name
] ]
clips.extend(engine.get_assets(assets_opts)) clips.extend(engine.generate(assets_opts))
self.ctx.index_3.extend(clips) self.ctx.index_3.extend(clips)

View File

@@ -1,7 +1,4 @@
from abc import ABC, abstractmethod from abc import abstractmethod
from typing import TypedDict
from moviepy.editor import ImageClip, VideoFileClip
from ..BaseEngine import BaseEngine from ..BaseEngine import BaseEngine
@@ -21,5 +18,5 @@ class BaseAssetsEngine(BaseEngine):
spec_description: str spec_description: str
@abstractmethod @abstractmethod
def get_assets(self, options: list) -> list: def generate(self, options: list) -> list:
... ...

View File

@@ -1,7 +1,4 @@
import base64
import io
import os import os
import time
from typing import Literal, TypedDict from typing import Literal, TypedDict
import gradio as gr import gradio as gr
@@ -41,7 +38,7 @@ class DallEAssetsEngine(BaseAssetsEngine):
super().__init__() super().__init__()
def get_assets(self, options: list[Spec]) -> list[mp.ImageClip]: def generate(self, options: list[Spec]) -> list[mp.ImageClip]:
max_width = self.ctx.width / 3 * 2 max_width = self.ctx.width / 3 * 2
clips = [] clips = []
for option in options: for option in options:
@@ -49,7 +46,7 @@ class DallEAssetsEngine(BaseAssetsEngine):
start = option["start"] start = option["start"]
end = option["end"] end = option["end"]
style = option["style"] style = option["style"]
size = ( size: Literal["1024x1024", "1024x1792", "1792x1024"] = (
"1024x1024" "1024x1024"
if self.aspect_ratio == "square" if self.aspect_ratio == "square"
else "1024x1792" else "1024x1792"
@@ -71,9 +68,9 @@ class DallEAssetsEngine(BaseAssetsEngine):
continue continue
else: else:
raise raise
img = requests.get(response.data[0].url) img_bytes = requests.get(response.data[0].url)
with open("temp.png", "wb") as f: with open("temp.png", "wb") as f:
f.write(img.content) f.write(img_bytes.content)
img = mp.ImageClip("temp.png") img = mp.ImageClip("temp.png")
os.remove("temp.png") os.remove("temp.png")

View File

@@ -1,13 +1,9 @@
import base64
import io
import os import os
import shutil import shutil
import time from typing import TypedDict
from typing import Literal, TypedDict
import gradio as gr import gradio as gr
import moviepy.editor as mp import moviepy.editor as mp
import requests
from google_images_search import GoogleImagesSearch from google_images_search import GoogleImagesSearch
from moviepy.video.fx.resize import resize from moviepy.video.fx.resize import resize
@@ -41,7 +37,7 @@ class GoogleAssetsEngine(BaseAssetsEngine):
self.google = GoogleImagesSearch(api_key, project_cx) self.google = GoogleImagesSearch(api_key, project_cx)
super().__init__() super().__init__()
def get_assets(self, options: list[Spec]) -> list[mp.ImageClip]: def generate(self, options: list[Spec]) -> list[mp.ImageClip]:
max_width = self.ctx.width / 3 * 2 max_width = self.ctx.width / 3 * 2
clips = [] clips = []
for option in options: for option in options:

View File

@@ -1,6 +1,4 @@
from abc import ABC, abstractmethod from abc import abstractmethod
from moviepy.editor import VideoClip
from ..BaseEngine import BaseEngine from ..BaseEngine import BaseEngine

View File

@@ -6,7 +6,6 @@ import time
import gradio as gr import gradio as gr
import moviepy.editor as mp import moviepy.editor as mp
from moviepy.video.fx.crop import crop from moviepy.video.fx.crop import crop
from moviepy.video.fx.resize import resize
from . import BaseBackgroundEngine from . import BaseBackgroundEngine
@@ -60,7 +59,7 @@ class VideoBackgroundEngine(BaseBackgroundEngine):
) )
@classmethod @classmethod
def get_settings(cls) -> list: def get_settings(cls):
def add_file(fp: str, name: str, credits: str): def add_file(fp: str, name: str, credits: str):
if name == "": if name == "":
raise ValueError("Name cannot be empty.") raise ValueError("Name cannot be empty.")

View File

@@ -1,6 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import gradio as gr
import moviepy.editor as mp import moviepy.editor as mp
from sqlalchemy.future import select from sqlalchemy.future import select
@@ -8,6 +7,7 @@ from ..chore import GenerationContext
from ..models import SessionLocal, File, Setting from ..models import SessionLocal, File, Setting
# noinspection PyTypeChecker
class BaseEngine(ABC): class BaseEngine(ABC):
num_options: int num_options: int
name: str name: str
@@ -33,6 +33,7 @@ class BaseEngine(ABC):
def get_assets(cls, *, type: str = None, by_id: int = None) -> list[File] | File | None: def get_assets(cls, *, type: str = None, by_id: int = None) -> list[File] | File | None:
with SessionLocal() as db: with SessionLocal() as db:
if type: if type:
# noinspection PyTypeChecker
return ( return (
db.execute( db.execute(
select(File).filter( select(File).filter(
@@ -43,6 +44,7 @@ class BaseEngine(ABC):
.all() .all()
) )
elif by_id: elif by_id:
# noinspection PyTypeChecker
return ( return (
db.execute( db.execute(
select(File).filter( select(File).filter(
@@ -53,6 +55,7 @@ class BaseEngine(ABC):
.first() .first()
) )
else: else:
# noinspection PyTypeChecker
return ( return (
db.execute(select(File).filter(File.provider == cls.name)) db.execute(select(File).filter(File.provider == cls.name))
.scalars() .scalars()
@@ -69,6 +72,7 @@ class BaseEngine(ABC):
@classmethod @classmethod
def remove_asset(cls, *, path: str): def remove_asset(cls, *, path: str):
with SessionLocal() as db: with SessionLocal() as db:
# noinspection PyTypeChecker
db.execute(select(File).filter(File.path == path)).delete() db.execute(select(File).filter(File.path == path)).delete()
db.commit() db.commit()
@@ -77,6 +81,7 @@ class BaseEngine(ABC):
def store_setting(cls, *, type: str = None, data: dict): def store_setting(cls, *, type: str = None, data: dict):
with SessionLocal() as db: with SessionLocal() as db:
# check if setting exists # check if setting exists
# noinspection PyTypeChecker
setting = db.execute( setting = db.execute(
select(Setting).filter( select(Setting).filter(
Setting.provider == cls.name, Setting.type == type Setting.provider == cls.name, Setting.type == type
@@ -112,6 +117,7 @@ class BaseEngine(ABC):
if not identifier and type: if not identifier and type:
identifier = type identifier = type
if identifier: if identifier:
# noinspection PyTypeChecker
result = db.execute( result = db.execute(
select(Setting).filter( select(Setting).filter(
Setting.provider == cls.name, Setting.type == identifier Setting.provider == cls.name, Setting.type == identifier
@@ -122,6 +128,7 @@ class BaseEngine(ABC):
return result.data return result.data
return None return None
else: else:
# noinspection PyTypeChecker
return [ return [
s.data s.data
for s in db.execute( for s in db.execute(
@@ -145,12 +152,14 @@ class BaseEngine(ABC):
if not identifier and type: if not identifier and type:
identifier = type identifier = type
if identifier: if identifier:
# noinspection PyTypeChecker
db.execute( db.execute(
select(Setting).filter( select(Setting).filter(
Setting.provider == cls.name, Setting.type == identifier Setting.provider == cls.name, Setting.type == identifier
) )
).delete() ).delete()
else: else:
# noinspection PyTypeChecker
db.execute( db.execute(
select(Setting).filter(Setting.provider == cls.name) select(Setting).filter(Setting.provider == cls.name)
).delete() ).delete()

View File

@@ -1,6 +1,4 @@
from abc import ABC, abstractmethod from abc import abstractmethod
from moviepy.editor import TextClip
from ..BaseEngine import BaseEngine from ..BaseEngine import BaseEngine

View File

@@ -1,5 +1,4 @@
import gradio as gr import gradio as gr
from PIL import ImageFont
from moviepy.editor import TextClip from moviepy.editor import TextClip
from . import BaseCaptioningEngine from . import BaseCaptioningEngine

View File

@@ -34,8 +34,9 @@ class AnthropicLLMEngine(BaseLLMEngine):
) -> str | dict: ) -> str | dict:
prompt = f"""{anthropic.HUMAN_PROMPT} {system_prompt} {anthropic.HUMAN_PROMPT} {chat_prompt} {anthropic.AI_PROMPT}""" prompt = f"""{anthropic.HUMAN_PROMPT} {system_prompt} {anthropic.HUMAN_PROMPT} {chat_prompt} {anthropic.AI_PROMPT}"""
if json_mode: if json_mode:
# anthopic does not officially support JSON mode, but we can bias the output towards a JSON-like format # anthropic does not officially support JSON mode, but we can bias the output towards a JSON-like format
prompt += " {" prompt += " {"
# noinspection PyArgumentList
response: anthropic.types.Completion = self.client.completions.create( response: anthropic.types.Completion = self.client.completions.create(
max_tokens_to_sample=max_tokens, max_tokens_to_sample=max_tokens,
prompt=prompt, prompt=prompt,

View File

@@ -1,6 +1,4 @@
from abc import ABC, abstractmethod from abc import abstractmethod
import openai
from ..BaseEngine import BaseEngine from ..BaseEngine import BaseEngine
@@ -11,11 +9,11 @@ class BaseLLMEngine(BaseEngine):
self, self,
system_prompt: str, system_prompt: str,
chat_prompt: str, chat_prompt: str,
max_tokens: int, max_tokens: int = 512,
temperature: float, temperature: float = 1.0,
json_mode: bool, json_mode: bool = False,
top_p: float, top_p: float = 1,
frequency_penalty: float, frequency_penalty: float = 0,
presence_penalty: float, presence_penalty: float = 0,
) -> str | dict: ) -> str | dict:
pass pass

View File

@@ -1,5 +1,3 @@
from abc import ABC, abstractmethod
import gradio as gr import gradio as gr
import openai import openai
import orjson import orjson

View File

@@ -1,4 +1,4 @@
from abc import ABC, abstractmethod from abc import abstractmethod
from ..BaseEngine import BaseEngine from ..BaseEngine import BaseEngine

View File

@@ -12,8 +12,8 @@ class CustomScriptEngine(BaseScriptEngine):
self.script = options[0] self.script = options[0]
super().__init__() super().__init__()
def generate(self, *args, **kwargs) -> str: def generate(self, *args, **kwargs):
self.ctx.script = self.script.strip().copy() self.ctx.script = self.script.strip()
@classmethod @classmethod
def get_options(cls) -> list: def get_options(cls) -> list:

View File

@@ -1,5 +1,3 @@
from abc import ABC, abstractmethod
import gradio as gr import gradio as gr
from ..BaseEngine import BaseEngine from ..BaseEngine import BaseEngine

View File

@@ -1,4 +1,4 @@
from abc import ABC, abstractmethod from abc import abstractmethod
from typing import TypedDict from typing import TypedDict
import moviepy.editor as mp import moviepy.editor as mp

View File

@@ -4,8 +4,7 @@ import gradio as gr
import torch import torch
from TTS.api import TTS from TTS.api import TTS
from .BaseTTSEngine import BaseTTSEngine, Word from .BaseTTSEngine import BaseTTSEngine
from ...utils.prompting import get_prompt
class CoquiTTSEngine(BaseTTSEngine): class CoquiTTSEngine(BaseTTSEngine):

View File

@@ -1,5 +1,3 @@
from typing import TypedDict
from . import AssetsEngine from . import AssetsEngine
from . import BackgroundEngine from . import BackgroundEngine
from . import CaptioningEngine from . import CaptioningEngine

View File

@@ -1,7 +1,5 @@
import os
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import sessionmaker
from . import Base from . import Base

View File

@@ -1,5 +1,3 @@
from typing import Optional
from sqlalchemy import String, Column, JSON, Integer from sqlalchemy import String, Column, JSON, Integer
from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.ext.mutable import MutableDict
@@ -13,4 +11,4 @@ class File(Base):
provider: str = Column(String, nullable=False) provider: str = Column(String, nullable=False)
type: str = Column(String, nullable=True) type: str = Column(String, nullable=True)
path: str = Column(String, nullable=False) path: str = Column(String, nullable=False)
data: dict = Column(MutableDict.as_mutable(JSON), nullable=False, default={}) data: dict = Column(MutableDict.as_mutable(JSON), nullable=False, default={}) # type: ignore

View File

@@ -1,5 +1,3 @@
from typing import Optional
from sqlalchemy import String, Column, JSON, Integer from sqlalchemy import String, Column, JSON, Integer
from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.ext.mutable import MutableDict
@@ -12,4 +10,4 @@ class Setting(Base):
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
provider: str = Column(String, nullable=False) provider: str = Column(String, nullable=False)
type: str = Column(String, nullable=True) type: str = Column(String, nullable=True)
data: dict = Column(MutableDict.as_mutable(JSON), nullable=False, default={}) data: dict = Column(MutableDict.as_mutable(JSON), nullable=False, default={}) # type: ignore

View File

@@ -1,5 +1,4 @@
from datetime import datetime from datetime import datetime
from typing import Optional
from sqlalchemy import String, Column, JSON, Integer, DateTime from sqlalchemy import String, Column, JSON, Integer, DateTime
from sqlalchemy.ext.mutable import MutableList from sqlalchemy.ext.mutable import MutableList