diff --git a/.gitignore b/.gitignore index 46c1ebb..65d700d 100644 --- a/.gitignore +++ b/.gitignore @@ -152,7 +152,7 @@ cython_debug/ #.idea/ output/ -local/* +local/ local/presets.json cookies.txt diff --git a/src/__init__.py b/src/__init__.py index ce56030..70b070d 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,3 +1,3 @@ +from . import chore from . import engines from . import utils -from . import chore diff --git a/src/chore/GenerationContext.py b/src/chore/GenerationContext.py index dd4e8f6..ae9c489 100644 --- a/src/chore/GenerationContext.py +++ b/src/chore/GenerationContext.py @@ -1,12 +1,10 @@ -import moviepy.editor as mp -import time import os -import gradio as gr - +import time from datetime import datetime +import moviepy.editor as mp + from .. import engines -from ..utils.prompting import get_prompt from ..models import Video, SessionLocal @@ -26,18 +24,18 @@ class GenerationContext: db.commit() def __init__( - self, - powerfulllmengine, - simplellmengine, - scriptengine, - ttsengine, - captioningengine, - assetsengine, - settingsengine, - backgroundengine, - metadataengine, - uploadengine, - progress, + self, + powerfulllmengine, + simplellmengine, + scriptengine, + ttsengine, + captioningengine, + assetsengine, + settingsengine, + backgroundengine, + metadataengine, + uploadengine, + progress, ) -> None: self.progress = progress diff --git a/src/engines/AssetsEngine/AssetsEngineSelector.py b/src/engines/AssetsEngine/AssetsEngineSelector.py index 350accc..6229248 100644 --- a/src/engines/AssetsEngine/AssetsEngineSelector.py +++ b/src/engines/AssetsEngine/AssetsEngineSelector.py @@ -1,7 +1,7 @@ import json -from ...utils.prompting import get_prompt from ...chore import GenerationContext +from ...utils.prompting import get_prompt class AssetsEngineSelector: diff --git a/src/engines/AssetsEngine/BaseAssetsEngine.py b/src/engines/AssetsEngine/BaseAssetsEngine.py index 9079c76..0e6e24d 100644 --- a/src/engines/AssetsEngine/BaseAssetsEngine.py +++ b/src/engines/AssetsEngine/BaseAssetsEngine.py @@ -1,7 +1,6 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod + from ..BaseEngine import BaseEngine -from typing import TypedDict -from moviepy.editor import ImageClip, VideoFileClip class BaseAssetsEngine(BaseEngine): diff --git a/src/engines/AssetsEngine/DallEAssetsEngine.py b/src/engines/AssetsEngine/DallEAssetsEngine.py index e5f3e03..1c70102 100644 --- a/src/engines/AssetsEngine/DallEAssetsEngine.py +++ b/src/engines/AssetsEngine/DallEAssetsEngine.py @@ -1,15 +1,12 @@ -import gradio as gr -import openai -import moviepy.editor as mp -import io -import base64 -import time -import requests import os - -from moviepy.video.fx.resize import resize from typing import Literal, TypedDict +import gradio as gr +import moviepy.editor as mp +import openai +import requests +from moviepy.video.fx.resize import resize + from . import BaseAssetsEngine diff --git a/src/engines/AssetsEngine/GoogleAssetsEngine.py b/src/engines/AssetsEngine/GoogleAssetsEngine.py index b48dbfe..ed0d2ac 100644 --- a/src/engines/AssetsEngine/GoogleAssetsEngine.py +++ b/src/engines/AssetsEngine/GoogleAssetsEngine.py @@ -1,15 +1,12 @@ +import os +import os +import shutil +from typing import TypedDict + import gradio as gr import moviepy.editor as mp -import io -import base64 -import time -import requests -import shutil -import os - from google_images_search import GoogleImagesSearch from moviepy.video.fx.resize import resize -from typing import Literal, TypedDict from . import BaseAssetsEngine diff --git a/src/engines/AssetsEngine/__init__.py b/src/engines/AssetsEngine/__init__.py index 60691bd..b994a2b 100644 --- a/src/engines/AssetsEngine/__init__.py +++ b/src/engines/AssetsEngine/__init__.py @@ -1,4 +1,4 @@ +from .AssetsEngineSelector import AssetsEngineSelector from .BaseAssetsEngine import BaseAssetsEngine from .DallEAssetsEngine import DallEAssetsEngine -from .AssetsEngineSelector import AssetsEngineSelector from .GoogleAssetsEngine import GoogleAssetsEngine diff --git a/src/engines/BackgroundEngine/BaseBackgroundEngine.py b/src/engines/BackgroundEngine/BaseBackgroundEngine.py index 68f9558..6ef6a74 100644 --- a/src/engines/BackgroundEngine/BaseBackgroundEngine.py +++ b/src/engines/BackgroundEngine/BaseBackgroundEngine.py @@ -1,7 +1,6 @@ -from abc import ABC, abstractmethod -from ..BaseEngine import BaseEngine +from abc import abstractmethod -from moviepy.editor import VideoClip +from ..BaseEngine import BaseEngine class BaseBackgroundEngine(BaseEngine): diff --git a/src/engines/BackgroundEngine/VideoBackgroundEngine.py b/src/engines/BackgroundEngine/VideoBackgroundEngine.py index 4b470dd..089eabb 100644 --- a/src/engines/BackgroundEngine/VideoBackgroundEngine.py +++ b/src/engines/BackgroundEngine/VideoBackgroundEngine.py @@ -1,12 +1,12 @@ import os -import shutil import random +import shutil import time + import gradio as gr import moviepy.editor as mp - -from moviepy.video.fx.resize import resize from moviepy.video.fx.crop import crop + from . import BaseBackgroundEngine diff --git a/src/engines/BaseEngine.py b/src/engines/BaseEngine.py index 0a6839e..b862d11 100644 --- a/src/engines/BaseEngine.py +++ b/src/engines/BaseEngine.py @@ -1,7 +1,6 @@ -import gradio as gr -import moviepy.editor as mp - from abc import ABC, abstractmethod + +import moviepy.editor as mp from sqlalchemy.future import select from ..chore import GenerationContext diff --git a/src/engines/CaptioningEngine/BaseCaptioningEngine.py b/src/engines/CaptioningEngine/BaseCaptioningEngine.py index bbadf88..be80b92 100644 --- a/src/engines/CaptioningEngine/BaseCaptioningEngine.py +++ b/src/engines/CaptioningEngine/BaseCaptioningEngine.py @@ -1,7 +1,6 @@ -from abc import ABC, abstractmethod -from ..BaseEngine import BaseEngine +from abc import abstractmethod -from moviepy.editor import TextClip +from ..BaseEngine import BaseEngine class BaseCaptioningEngine(BaseEngine): diff --git a/src/engines/CaptioningEngine/SimpleCaptioningEngine.py b/src/engines/CaptioningEngine/SimpleCaptioningEngine.py index c24a528..2d0cf65 100644 --- a/src/engines/CaptioningEngine/SimpleCaptioningEngine.py +++ b/src/engines/CaptioningEngine/SimpleCaptioningEngine.py @@ -1,6 +1,6 @@ import gradio as gr from moviepy.editor import TextClip -from PIL import ImageFont + from . import BaseCaptioningEngine diff --git a/src/engines/LLMEngine/AnthropicLLMEngine.py b/src/engines/LLMEngine/AnthropicLLMEngine.py index 3b5878d..7c888bd 100644 --- a/src/engines/LLMEngine/AnthropicLLMEngine.py +++ b/src/engines/LLMEngine/AnthropicLLMEngine.py @@ -22,15 +22,15 @@ class AnthropicLLMEngine(BaseLLMEngine): super().__init__() def generate( - self, - system_prompt: str, - chat_prompt: str, - max_tokens: int = 1024, - temperature: float = 1.0, - json_mode: bool = False, - top_p: float = 1, - frequency_penalty: float = 0, - presence_penalty: float = 0, + self, + system_prompt: str, + chat_prompt: str, + max_tokens: int = 1024, + temperature: float = 1.0, + json_mode: bool = False, + top_p: float = 1, + frequency_penalty: float = 0, + presence_penalty: float = 0, ) -> str | dict: prompt = f"""{anthropic.HUMAN_PROMPT} {system_prompt} {anthropic.HUMAN_PROMPT} {chat_prompt} {anthropic.AI_PROMPT}""" if json_mode: diff --git a/src/engines/LLMEngine/BaseLLMEngine.py b/src/engines/LLMEngine/BaseLLMEngine.py index b2aa5a7..850cae5 100644 --- a/src/engines/LLMEngine/BaseLLMEngine.py +++ b/src/engines/LLMEngine/BaseLLMEngine.py @@ -1,20 +1,19 @@ -from abc import ABC, abstractmethod -from ..BaseEngine import BaseEngine +from abc import abstractmethod -import openai +from ..BaseEngine import BaseEngine class BaseLLMEngine(BaseEngine): @abstractmethod def generate( - self, - system_prompt: str, - chat_prompt: str, - max_tokens: int, - temperature: float, - json_mode: bool, - top_p: float, - frequency_penalty: float, - presence_penalty: float, + self, + system_prompt: str, + chat_prompt: str, + max_tokens: int, + temperature: float, + json_mode: bool, + top_p: float, + frequency_penalty: float, + presence_penalty: float, ) -> str | dict: pass diff --git a/src/engines/LLMEngine/OpenaiLLMEngine.py b/src/engines/LLMEngine/OpenaiLLMEngine.py index 7e380c1..8751f9a 100644 --- a/src/engines/LLMEngine/OpenaiLLMEngine.py +++ b/src/engines/LLMEngine/OpenaiLLMEngine.py @@ -1,9 +1,7 @@ -import openai import gradio as gr +import openai import orjson -from abc import ABC, abstractmethod - from .BaseLLMEngine import BaseLLMEngine OPENAI_POSSIBLE_MODELS = [ # Theese shall be the openai models supporting force_json @@ -22,15 +20,15 @@ class OpenaiLLMEngine(BaseLLMEngine): super().__init__() def generate( - self, - system_prompt: str, - chat_prompt: str, - max_tokens: int = 512, - temperature: float = 1.0, - json_mode: bool = False, - top_p: float = 1, - frequency_penalty: float = 0, - presence_penalty: float = 0, + self, + system_prompt: str, + chat_prompt: str, + max_tokens: int = 512, + temperature: float = 1.0, + json_mode: bool = False, + top_p: float = 1, + frequency_penalty: float = 0, + presence_penalty: float = 0, ) -> str | dict: response = openai.chat.completions.create( model=self.model, diff --git a/src/engines/LLMEngine/__init__.py b/src/engines/LLMEngine/__init__.py index ba44606..65a0e79 100644 --- a/src/engines/LLMEngine/__init__.py +++ b/src/engines/LLMEngine/__init__.py @@ -1,3 +1,3 @@ +from .AnthropicLLMEngine import AnthropicLLMEngine from .BaseLLMEngine import BaseLLMEngine from .OpenaiLLMEngine import OpenaiLLMEngine -from .AnthropicLLMEngine import AnthropicLLMEngine diff --git a/src/engines/MetadataEngine/BaseMetadataEngine.py b/src/engines/MetadataEngine/BaseMetadataEngine.py index 5317909..af3bdf8 100644 --- a/src/engines/MetadataEngine/BaseMetadataEngine.py +++ b/src/engines/MetadataEngine/BaseMetadataEngine.py @@ -1,5 +1,4 @@ from abc import abstractmethod -from typing import TypedDict from .. import BaseEngine diff --git a/src/engines/ScriptEngine/BaseScriptEngine.py b/src/engines/ScriptEngine/BaseScriptEngine.py index eecece4..2309156 100644 --- a/src/engines/ScriptEngine/BaseScriptEngine.py +++ b/src/engines/ScriptEngine/BaseScriptEngine.py @@ -1,4 +1,5 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod + from ..BaseEngine import BaseEngine diff --git a/src/engines/ScriptEngine/CustomScriptEngine.py b/src/engines/ScriptEngine/CustomScriptEngine.py index 3660862..67ce0b0 100644 --- a/src/engines/ScriptEngine/CustomScriptEngine.py +++ b/src/engines/ScriptEngine/CustomScriptEngine.py @@ -1,6 +1,7 @@ -from .BaseScriptEngine import BaseScriptEngine import gradio as gr +from .BaseScriptEngine import BaseScriptEngine + class CustomScriptEngine(BaseScriptEngine): name = "Custom Script Engine" diff --git a/src/engines/ScriptEngine/ShowerThoughtsScriptEngine.py b/src/engines/ScriptEngine/ShowerThoughtsScriptEngine.py index 3fc4078..7fb9525 100644 --- a/src/engines/ScriptEngine/ShowerThoughtsScriptEngine.py +++ b/src/engines/ScriptEngine/ShowerThoughtsScriptEngine.py @@ -1,6 +1,7 @@ -import gradio as gr import os +import gradio as gr + from .BaseScriptEngine import BaseScriptEngine from ...utils.prompting import get_prompt diff --git a/src/engines/ScriptEngine/__init__.py b/src/engines/ScriptEngine/__init__.py index 6ad1b7c..0ad53b1 100644 --- a/src/engines/ScriptEngine/__init__.py +++ b/src/engines/ScriptEngine/__init__.py @@ -1,3 +1,3 @@ from .BaseScriptEngine import BaseScriptEngine -from .ShowerThoughtsScriptEngine import ShowerThoughtsScriptEngine from .CustomScriptEngine import CustomScriptEngine +from .ShowerThoughtsScriptEngine import ShowerThoughtsScriptEngine diff --git a/src/engines/SettingsEngine/SettingsEngine.py b/src/engines/SettingsEngine/SettingsEngine.py index 024e498..9346209 100644 --- a/src/engines/SettingsEngine/SettingsEngine.py +++ b/src/engines/SettingsEngine/SettingsEngine.py @@ -1,5 +1,5 @@ import gradio as gr -from abc import ABC, abstractmethod + from ..BaseEngine import BaseEngine diff --git a/src/engines/TTSEngine/BaseTTSEngine.py b/src/engines/TTSEngine/BaseTTSEngine.py index 39b9f5e..49fde68 100644 --- a/src/engines/TTSEngine/BaseTTSEngine.py +++ b/src/engines/TTSEngine/BaseTTSEngine.py @@ -1,9 +1,9 @@ +from abc import abstractmethod +from typing import TypedDict + import moviepy.editor as mp import whisper_timestamped as wt - -from typing import TypedDict from torch.cuda import is_available -from abc import ABC, abstractmethod from ..BaseEngine import BaseEngine diff --git a/src/engines/TTSEngine/CoquiTTSEngine.py b/src/engines/TTSEngine/CoquiTTSEngine.py index db499bf..2914ddb 100644 --- a/src/engines/TTSEngine/CoquiTTSEngine.py +++ b/src/engines/TTSEngine/CoquiTTSEngine.py @@ -1,13 +1,10 @@ -import gradio as gr - -from TTS.api import TTS import os +import gradio as gr import torch +from TTS.api import TTS -from .BaseTTSEngine import BaseTTSEngine, Word - -from ...utils.prompting import get_prompt +from .BaseTTSEngine import BaseTTSEngine class CoquiTTSEngine(BaseTTSEngine): diff --git a/src/engines/UploadEngine/TikTokUploadEngine.py b/src/engines/UploadEngine/TikTokUploadEngine.py index 135e161..bf956fa 100644 --- a/src/engines/UploadEngine/TikTokUploadEngine.py +++ b/src/engines/UploadEngine/TikTokUploadEngine.py @@ -1,5 +1,4 @@ import gradio as gr - from tiktok_uploader.upload import upload_video from .BaseUploadEngine import BaseUploadEngine diff --git a/src/engines/UploadEngine/YouTubeUploadEngine.py b/src/engines/UploadEngine/YouTubeUploadEngine.py index b773c7e..efa9cdd 100644 --- a/src/engines/UploadEngine/YouTubeUploadEngine.py +++ b/src/engines/UploadEngine/YouTubeUploadEngine.py @@ -1,11 +1,11 @@ import gradio as gr import orjson - from google_auth_oauthlib.flow import InstalledAppFlow from . import BaseUploadEngine from ...utils import youtube_uploading + class YouTubeUploadEngine(BaseUploadEngine): name = "YouTube" description = "Upload videos to YouTube" @@ -18,7 +18,7 @@ class YouTubeUploadEngine(BaseUploadEngine): self.credentials = self.retrieve_setting(type="youtube_client_secrets")[self.oauth["client_secret"]] self.hashtags = options[1] - + @classmethod def __oauth(cls, credentials): flow = InstalledAppFlow.from_client_config( @@ -45,10 +45,10 @@ class YouTubeUploadEngine(BaseUploadEngine): try: youtube_uploading.upload(self.oauth["credentials"], options) except Exception as e: - #this means we need to re-authenticate likely + # this means we need to re-authenticate likely # use self.__oauth to re-authenticate new_oauth = self.__oauth(self.credentials) - #also update the credentials in the settings + # also update the credentials in the settings current_oauths = self.retrieve_setting(type="oauth_credentials") or {} current_oauths[self.oauth_name] = { "client_secret": self.oauth["client_secret"], @@ -81,6 +81,7 @@ class YouTubeUploadEngine(BaseUploadEngine): label="Client Secret File", file_types=["json"], type="binary" ) submit_button = gr.Button("Save") + def save(binary, clien_secret_name): current_client_secrets = cls.retrieve_setting(type="youtube_client_secrets") or {} client_secret_json = orjson.loads(binary) @@ -99,6 +100,7 @@ class YouTubeUploadEngine(BaseUploadEngine): choosen_client_secret = gr.Dropdown(label="Login secret", choices=possible_client_secrets) name = gr.Textbox(label="Name", max_lines=1) login_button = gr.Button("Login", variant="primary") + def login(choosen_client_secret, name): choosen_secret_data = cls.retrieve_setting(type="youtube_client_secrets")[choosen_client_secret] new_oauth_entry = cls.__oauth(choosen_secret_data) @@ -112,4 +114,5 @@ class YouTubeUploadEngine(BaseUploadEngine): data=current_oauths, ) gr.Info(f"{name} saved successfully !") + login_button.click(login, inputs=[choosen_client_secret, name]) diff --git a/src/engines/UploadEngine/__init__.py b/src/engines/UploadEngine/__init__.py index 7c64349..b5fac3e 100644 --- a/src/engines/UploadEngine/__init__.py +++ b/src/engines/UploadEngine/__init__.py @@ -1,3 +1,3 @@ from .BaseUploadEngine import BaseUploadEngine from .TikTokUploadEngine import TikTokUploadEngine -from .YouTubeUploadEngine import YouTubeUploadEngine \ No newline at end of file +from .YouTubeUploadEngine import YouTubeUploadEngine diff --git a/src/engines/__init__.py b/src/engines/__init__.py index 725b566..777487b 100644 --- a/src/engines/__init__.py +++ b/src/engines/__init__.py @@ -1,15 +1,16 @@ from typing import TypedDict + +from . import AssetsEngine +from . import BackgroundEngine +from . import CaptioningEngine +from . import LLMEngine +from . import MetadataEngine +from . import ScriptEngine +from . import SettingsEngine +from . import TTSEngine +from . import UploadEngine from .BaseEngine import BaseEngine from .NoneEngine import NoneEngine -from . import TTSEngine -from . import ScriptEngine -from . import LLMEngine -from . import CaptioningEngine -from . import AssetsEngine -from . import SettingsEngine -from . import BackgroundEngine -from . import MetadataEngine -from . import UploadEngine class EngineDict(TypedDict): diff --git a/src/engines/engines.json b/src/engines/engines.json index 6d08a1c..7fed7be 100644 --- a/src/engines/engines.json +++ b/src/engines/engines.json @@ -1,6 +1,6 @@ { - "TTSEngine": [ - "CoquiTTSEngine", - "ElevenLabsTTSEngine" - ] + "TTSEngine": [ + "CoquiTTSEngine", + "ElevenLabsTTSEngine" + ] } \ No newline at end of file diff --git a/src/models/DatabaseManager.py b/src/models/DatabaseManager.py index f02b099..85d205d 100644 --- a/src/models/DatabaseManager.py +++ b/src/models/DatabaseManager.py @@ -1,7 +1,5 @@ -import os - from sqlalchemy import create_engine -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import sessionmaker from . import Base diff --git a/src/models/File.py b/src/models/File.py index ac004e4..7fd7d45 100644 --- a/src/models/File.py +++ b/src/models/File.py @@ -1,8 +1,8 @@ -from . import Base -from typing import Optional from sqlalchemy import String, Column, JSON, Integer from sqlalchemy.ext.mutable import MutableDict +from . import Base + class File(Base): __tablename__ = "files" diff --git a/src/models/Setting.py b/src/models/Setting.py index ef0012c..ba8bffd 100644 --- a/src/models/Setting.py +++ b/src/models/Setting.py @@ -1,8 +1,8 @@ -from . import Base -from typing import Optional from sqlalchemy import String, Column, JSON, Integer from sqlalchemy.ext.mutable import MutableDict +from . import Base + class Setting(Base): __tablename__ = "Settings" diff --git a/src/models/Video.py b/src/models/Video.py index 129be29..f4eedd1 100644 --- a/src/models/Video.py +++ b/src/models/Video.py @@ -1,8 +1,9 @@ -from . import Base -from typing import Optional +from datetime import datetime + from sqlalchemy import String, Column, JSON, Integer, DateTime from sqlalchemy.ext.mutable import MutableList -from datetime import datetime + +from . import Base class Video(Base): @@ -12,6 +13,6 @@ class Video(Base): title: str = Column(String, nullable=False) description: str = Column(String, nullable=False) script: str = Column(String, nullable=False) - timed_script: dict = Column(MutableList.as_mutable(JSON), nullable=False) + timed_script: dict = Column(MutableList.as_mutable(JSON), nullable=False) # type: ignore timestamp: datetime = Column(DateTime, nullable=False, default=datetime.now()) path: str = Column(String, nullable=False) diff --git a/src/utils/__init__.py b/src/utils/__init__.py index fbb149d..160d7fa 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,2 +1,2 @@ from . import prompting -from . import youtube_uploading \ No newline at end of file +from . import youtube_uploading diff --git a/src/utils/prompting.py b/src/utils/prompting.py index 72c9257..55e927c 100644 --- a/src/utils/prompting.py +++ b/src/utils/prompting.py @@ -1,7 +1,8 @@ -import yaml import os from typing import TypedDict +import yaml + class Prompt(TypedDict): system: str @@ -9,7 +10,7 @@ class Prompt(TypedDict): def get_prompt( - name, *, location: str = "src/chore/prompts", by_file_location: str = None + name, *, location: str = "src/chore/prompts", by_file_location: str = None ) -> tuple[str, str]: if by_file_location: path = os.path.join( diff --git a/src/utils/youtube_uploading.py b/src/utils/youtube_uploading.py index faeb1db..348605b 100644 --- a/src/utils/youtube_uploading.py +++ b/src/utils/youtube_uploading.py @@ -1,9 +1,9 @@ -from http import client -import httplib2 import random import time +from http import client import google.oauth2.credentials +import httplib2 from googleapiclient.discovery import build from googleapiclient.errors import HttpError from googleapiclient.http import MediaFileUpload @@ -34,6 +34,7 @@ API_VERSION = "v3" VALID_PRIVACY_STATUSES = ("public", "private", "unlisted") + def get_youtube(oauth_credentials: dict): oauth_credentials = google.oauth2.credentials.Credentials( token=oauth_credentials["token"], @@ -45,6 +46,7 @@ def get_youtube(oauth_credentials: dict): ) return build(API_SERVICE_NAME, API_VERSION, credentials=oauth_credentials) + def upload(oauth_credentials, options): youtube = get_youtube(oauth_credentials) tags = None @@ -106,7 +108,7 @@ def resumable_upload(request): if retry > MAX_RETRIES: exit("No longer attempting to retry.") - max_sleep = 2**retry + max_sleep = 2 ** retry sleep_seconds = random.random() * max_sleep print("Sleeping %f seconds and then retrying..." % sleep_seconds) time.sleep(sleep_seconds) @@ -116,4 +118,4 @@ def upload_thumbnail(video_id, path, oauth_credentials): youtube = get_youtube(oauth_credentials) youtube.thumbnails().set( # type: ignore videoId=video_id, media_body=path - ).execute() \ No newline at end of file + ).execute() diff --git a/ui/__init__.py b/ui/__init__.py index f8a2755..1ec2ce4 100644 --- a/ui/__init__.py +++ b/ui/__init__.py @@ -1,2 +1,2 @@ from .gradio_ui import GenerateUI -from .launcher import launch \ No newline at end of file +from .launcher import launch diff --git a/ui/gradio_ui.py b/ui/gradio_ui.py index d6276fb..b27e607 100644 --- a/ui/gradio_ui.py +++ b/ui/gradio_ui.py @@ -1,10 +1,12 @@ import os -import gradio as gr -import orjson import sys -from src.engines import ENGINES, BaseEngine +import gradio as gr +import orjson + from src.chore import GenerationContext +from src.engines import ENGINES, BaseEngine + class GenerateUI: def __init__(self): @@ -12,6 +14,7 @@ class GenerateUI: font-size: 5rem !important } """ + def get_presets(self): with open("local/presets.json", "r") as f: return orjson.loads(f.read()) @@ -53,6 +56,7 @@ class GenerateUI: def get_settings_interface(self) -> gr.Blocks: with gr.Blocks() as interface: reload_ui = gr.Button("Reload UI", variant="primary") + def reload(): self.ui.close() sys.exit("Reload") @@ -115,12 +119,13 @@ class GenerateUI: value=None ) preset_button = gr.Button("Load") + def load_preset(preset_name, *inputs) -> list[gr.update]: with open("local/presets.json", "r") as f: presets = orjson.loads(f.read()) returnable = [] if preset_name in presets.keys(): - # If the preset exists + # If the preset exists preset = presets[preset_name] for engine_type, engines in ENGINES.items(): engines = engines["classes"] @@ -128,7 +133,8 @@ class GenerateUI: for engine in engines: if engine.name in preset.get(engine_type, {}).keys(): values[0].append(engine.name) - values.extend(gr.update(value=value) for value in preset[engine_type][engine.name]) + values.extend( + gr.update(value=value) for value in preset[engine_type][engine.name]) else: values.extend(gr.update() for _ in range(engine.num_options)) returnable.extend(values) @@ -153,7 +159,8 @@ class GenerateUI: presets[preset_name] = new_preset f.write(orjson.dumps(presets)) return [gr.update(value=presets.keys()), *returnable] - preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], outputs=[preset_dropdown,*inputs]) + preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], + outputs=[preset_dropdown, *inputs]) output_gallery = gr.Markdown("aaa", render=False) button.click( self.run_generate_interface, @@ -197,8 +204,8 @@ class GenerateUI: options[engine_type].append( engine(options=args[: engine.num_options]) ) - args = args[engine.num_options :] + args = args[engine.num_options:] else: # we don't care about this, it's not the selected engine, we throw it away - args = args[engine.num_options :] - return options \ No newline at end of file + args = args[engine.num_options:] + return options diff --git a/ui/launcher.py b/ui/launcher.py index af7778a..cd02a10 100644 --- a/ui/launcher.py +++ b/ui/launcher.py @@ -1,4 +1,6 @@ from . import GenerateUI + + def launch(): ui_generator = GenerateUI() ui_generator.launch_ui()