From 0a2e4774d06adcdac0d04d830911744fc970c121 Mon Sep 17 00:00:00 2001 From: Paillat Date: Thu, 22 Feb 2024 18:02:52 +0100 Subject: [PATCH] Add presets functionality to GenerateUI class --- ui/gradio_ui.py | 53 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/ui/gradio_ui.py b/ui/gradio_ui.py index 8192055..39fc91e 100644 --- a/ui/gradio_ui.py +++ b/ui/gradio_ui.py @@ -1,16 +1,19 @@ import os import gradio as gr +import orjson from src.engines import ENGINES, BaseEngine from src.chore import GenerationContext - class GenerateUI: def __init__(self): self.css = """.generate_button { font-size: 5rem !important } """ + def get_presets(self): + with open("local/presets.json", "r") as f: + return orjson.loads(f.read()) def get_switcher_func(self, engine_names: list[str]) -> list[gr.update]: def switch(selected: str | list[str]): @@ -92,6 +95,54 @@ class GenerateUI: variant="primary", elem_classes="generate_button", ) + with gr.Row() as preset_row: + presets = self.get_presets() + preset_dropdown = gr.Dropdown( + choices=list(presets.keys()), + label="Presets", + allow_custom_value=True, + value=None + ) + preset_button = gr.Button("Load") + def load_preset(preset_name, *inputs) -> list[gr.update]: + with open("local/presets.json", "r") as f: + presets = orjson.loads(f.read()) + returnable = [] + if preset_name in presets.keys(): + # If the preset exists + preset = presets[preset_name] + for engine_type, engines in ENGINES.items(): + engines = engines["classes"] + values = [[]] + for engine in engines: + if engine.name in preset.get(engine_type, {}).keys(): + values[0].append(engine.name) + values.extend(gr.update(value=value) for value in preset[engine_type][engine.name]) + else: + values.extend(gr.update() for _ in range(engine.num_options)) + returnable.extend(values) + else: + poppable_inputs = list(inputs) + new_preset = {} + for engine_type, engines in ENGINES.items(): + engines = engines["classes"] + new_preset[engine_type] = {} + engine_names = poppable_inputs.pop(0) + if isinstance(engine_names, str): + engine_names = [engine_names] + returnable.append(gr.update()) + for engine in engines: + if engine.name in engine_names: + new_preset[engine_type][engine.name] = poppable_inputs[:engine.num_options] + poppable_inputs = poppable_inputs[engine.num_options:] + else: + poppable_inputs = poppable_inputs[engine.num_options:] + returnable.extend(gr.update() for _ in range(engine.num_options)) + with open("local/presets.json", "wb") as f: + presets[preset_name] = new_preset + f.write(orjson.dumps(presets)) + return returnable + preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], outputs=inputs) output_gallery = gr.Markdown("aaa", render=False) button.click( self.run_generate_interface,