Add presets feature

This commit is contained in:
2024-02-23 09:37:00 +01:00
parent d0d1f38e50
commit aa5c03b17e

View File

@@ -1,6 +1,7 @@
import os
import gradio as gr
import orjson
import sys
from src.engines import ENGINES, BaseEngine
from src.chore import GenerationContext
@@ -15,8 +16,8 @@ class GenerateUI:
with open("local/presets.json", "r") as f:
return orjson.loads(f.read())
def get_switcher_func(self, engine_names: list[str]) -> list[gr.update]:
def switch(selected: str | list[str]):
def get_switcher_func(self, engine_names: list[str]) -> callable:
def switch(selected: str | list[str]) -> list[gr.update]:
if isinstance(selected, str):
selected = [selected]
returnable = []
@@ -27,11 +28,15 @@ class GenerateUI:
return switch
def launch_ui(self):
def get_ui(self):
ui = gr.TabbedInterface(
*self.get_interfaces(), "Viral Factory", gr.themes.Soft(), css=self.css
)
ui.launch()
return ui
def launch_ui(self):
self.ui = self.get_ui()
self.ui.launch()
def get_interfaces(self) -> tuple[list[gr.Blocks], list[str]]:
"""
@@ -47,6 +52,12 @@ class GenerateUI:
def get_settings_interface(self) -> gr.Blocks:
with gr.Blocks() as interface:
reload_ui = gr.Button("Reload UI", variant="primary")
def reload():
self.ui.close()
sys.exit("Reload")
reload_ui.click(reload)
for engine_type, engines in ENGINES.items():
engines = engines["classes"]
with gr.Tab(engine_type) as engine_tab:
@@ -141,8 +152,8 @@ class GenerateUI:
with open("local/presets.json", "wb") as f:
presets[preset_name] = new_preset
f.write(orjson.dumps(presets))
return returnable
preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], outputs=inputs)
return [gr.update(value=presets.keys()), *returnable]
preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], outputs=[preset_dropdown,*inputs])
output_gallery = gr.Markdown("aaa", render=False)
button.click(
self.run_generate_interface,
@@ -191,8 +202,3 @@ class GenerateUI:
# we don't care about this, it's not the selected engine, we throw it away
args = args[engine.num_options :]
return options
if __name__ == "__main__":
ui_generator = GenerateUI()
ui_generator.launch_ui()