Add presets functionality to GenerateUI class

This commit is contained in:
2024-02-22 18:02:52 +01:00
parent de7b7b1f1b
commit 0a2e4774d0

View File

@@ -1,16 +1,19 @@
import os import os
import gradio as gr import gradio as gr
import orjson
from src.engines import ENGINES, BaseEngine from src.engines import ENGINES, BaseEngine
from src.chore import GenerationContext from src.chore import GenerationContext
class GenerateUI: class GenerateUI:
def __init__(self): def __init__(self):
self.css = """.generate_button { self.css = """.generate_button {
font-size: 5rem !important font-size: 5rem !important
} }
""" """
def get_presets(self):
with open("local/presets.json", "r") as f:
return orjson.loads(f.read())
def get_switcher_func(self, engine_names: list[str]) -> list[gr.update]: def get_switcher_func(self, engine_names: list[str]) -> list[gr.update]:
def switch(selected: str | list[str]): def switch(selected: str | list[str]):
@@ -92,6 +95,54 @@ class GenerateUI:
variant="primary", variant="primary",
elem_classes="generate_button", elem_classes="generate_button",
) )
with gr.Row() as preset_row:
presets = self.get_presets()
preset_dropdown = gr.Dropdown(
choices=list(presets.keys()),
label="Presets",
allow_custom_value=True,
value=None
)
preset_button = gr.Button("Load")
def load_preset(preset_name, *inputs) -> list[gr.update]:
with open("local/presets.json", "r") as f:
presets = orjson.loads(f.read())
returnable = []
if preset_name in presets.keys():
# If the preset exists
preset = presets[preset_name]
for engine_type, engines in ENGINES.items():
engines = engines["classes"]
values = [[]]
for engine in engines:
if engine.name in preset.get(engine_type, {}).keys():
values[0].append(engine.name)
values.extend(gr.update(value=value) for value in preset[engine_type][engine.name])
else:
values.extend(gr.update() for _ in range(engine.num_options))
returnable.extend(values)
else:
poppable_inputs = list(inputs)
new_preset = {}
for engine_type, engines in ENGINES.items():
engines = engines["classes"]
new_preset[engine_type] = {}
engine_names = poppable_inputs.pop(0)
if isinstance(engine_names, str):
engine_names = [engine_names]
returnable.append(gr.update())
for engine in engines:
if engine.name in engine_names:
new_preset[engine_type][engine.name] = poppable_inputs[:engine.num_options]
poppable_inputs = poppable_inputs[engine.num_options:]
else:
poppable_inputs = poppable_inputs[engine.num_options:]
returnable.extend(gr.update() for _ in range(engine.num_options))
with open("local/presets.json", "wb") as f:
presets[preset_name] = new_preset
f.write(orjson.dumps(presets))
return returnable
preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], outputs=inputs)
output_gallery = gr.Markdown("aaa", render=False) output_gallery = gr.Markdown("aaa", render=False)
button.click( button.click(
self.run_generate_interface, self.run_generate_interface,