Add presets feature

This commit is contained in:
2024-02-23 09:37:00 +01:00
parent d0d1f38e50
commit aa5c03b17e

View File

@@ -1,6 +1,7 @@
import os import os
import gradio as gr import gradio as gr
import orjson import orjson
import sys
from src.engines import ENGINES, BaseEngine from src.engines import ENGINES, BaseEngine
from src.chore import GenerationContext from src.chore import GenerationContext
@@ -15,8 +16,8 @@ class GenerateUI:
with open("local/presets.json", "r") as f: with open("local/presets.json", "r") as f:
return orjson.loads(f.read()) return orjson.loads(f.read())
def get_switcher_func(self, engine_names: list[str]) -> list[gr.update]: def get_switcher_func(self, engine_names: list[str]) -> callable:
def switch(selected: str | list[str]): def switch(selected: str | list[str]) -> list[gr.update]:
if isinstance(selected, str): if isinstance(selected, str):
selected = [selected] selected = [selected]
returnable = [] returnable = []
@@ -27,11 +28,15 @@ class GenerateUI:
return switch return switch
def launch_ui(self): def get_ui(self):
ui = gr.TabbedInterface( ui = gr.TabbedInterface(
*self.get_interfaces(), "Viral Factory", gr.themes.Soft(), css=self.css *self.get_interfaces(), "Viral Factory", gr.themes.Soft(), css=self.css
) )
ui.launch() return ui
def launch_ui(self):
self.ui = self.get_ui()
self.ui.launch()
def get_interfaces(self) -> tuple[list[gr.Blocks], list[str]]: def get_interfaces(self) -> tuple[list[gr.Blocks], list[str]]:
""" """
@@ -47,6 +52,12 @@ class GenerateUI:
def get_settings_interface(self) -> gr.Blocks: def get_settings_interface(self) -> gr.Blocks:
with gr.Blocks() as interface: with gr.Blocks() as interface:
reload_ui = gr.Button("Reload UI", variant="primary")
def reload():
self.ui.close()
sys.exit("Reload")
reload_ui.click(reload)
for engine_type, engines in ENGINES.items(): for engine_type, engines in ENGINES.items():
engines = engines["classes"] engines = engines["classes"]
with gr.Tab(engine_type) as engine_tab: with gr.Tab(engine_type) as engine_tab:
@@ -141,8 +152,8 @@ class GenerateUI:
with open("local/presets.json", "wb") as f: with open("local/presets.json", "wb") as f:
presets[preset_name] = new_preset presets[preset_name] = new_preset
f.write(orjson.dumps(presets)) f.write(orjson.dumps(presets))
return returnable return [gr.update(value=presets.keys()), *returnable]
preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], outputs=inputs) preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], outputs=[preset_dropdown,*inputs])
output_gallery = gr.Markdown("aaa", render=False) output_gallery = gr.Markdown("aaa", render=False)
button.click( button.click(
self.run_generate_interface, self.run_generate_interface,
@@ -191,8 +202,3 @@ class GenerateUI:
# we don't care about this, it's not the selected engine, we throw it away # we don't care about this, it's not the selected engine, we throw it away
args = args[engine.num_options :] args = args[engine.num_options :]
return options return options
if __name__ == "__main__":
ui_generator = GenerateUI()
ui_generator.launch_ui()