From aa5c03b17e7e589cb5c6d828921ce273d194b13e Mon Sep 17 00:00:00 2001 From: Paillat Date: Fri, 23 Feb 2024 09:37:00 +0100 Subject: [PATCH] Add presets feature --- ui/gradio_ui.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/ui/gradio_ui.py b/ui/gradio_ui.py index 39fc91e..d6276fb 100644 --- a/ui/gradio_ui.py +++ b/ui/gradio_ui.py @@ -1,6 +1,7 @@ import os import gradio as gr import orjson +import sys from src.engines import ENGINES, BaseEngine from src.chore import GenerationContext @@ -15,8 +16,8 @@ class GenerateUI: with open("local/presets.json", "r") as f: return orjson.loads(f.read()) - def get_switcher_func(self, engine_names: list[str]) -> list[gr.update]: - def switch(selected: str | list[str]): + def get_switcher_func(self, engine_names: list[str]) -> callable: + def switch(selected: str | list[str]) -> list[gr.update]: if isinstance(selected, str): selected = [selected] returnable = [] @@ -27,11 +28,15 @@ class GenerateUI: return switch - def launch_ui(self): + def get_ui(self): ui = gr.TabbedInterface( *self.get_interfaces(), "Viral Factory", gr.themes.Soft(), css=self.css ) - ui.launch() + return ui + + def launch_ui(self): + self.ui = self.get_ui() + self.ui.launch() def get_interfaces(self) -> tuple[list[gr.Blocks], list[str]]: """ @@ -47,6 +52,12 @@ class GenerateUI: def get_settings_interface(self) -> gr.Blocks: with gr.Blocks() as interface: + reload_ui = gr.Button("Reload UI", variant="primary") + def reload(): + self.ui.close() + sys.exit("Reload") + + reload_ui.click(reload) for engine_type, engines in ENGINES.items(): engines = engines["classes"] with gr.Tab(engine_type) as engine_tab: @@ -141,8 +152,8 @@ class GenerateUI: with open("local/presets.json", "wb") as f: presets[preset_name] = new_preset f.write(orjson.dumps(presets)) - return returnable - preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], outputs=inputs) + return [gr.update(value=presets.keys()), *returnable] + preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], outputs=[preset_dropdown,*inputs]) output_gallery = gr.Markdown("aaa", render=False) button.click( self.run_generate_interface, @@ -190,9 +201,4 @@ class GenerateUI: else: # we don't care about this, it's not the selected engine, we throw it away args = args[engine.num_options :] - return options - - -if __name__ == "__main__": - ui_generator = GenerateUI() - ui_generator.launch_ui() + return options \ No newline at end of file