Add presets functionality to GenerateUI class

This commit is contained in:
2024-02-22 18:02:52 +01:00
parent de7b7b1f1b
commit 0a2e4774d0

View File

@@ -1,16 +1,19 @@
import os
import gradio as gr
import orjson
from src.engines import ENGINES, BaseEngine
from src.chore import GenerationContext
class GenerateUI:
def __init__(self):
self.css = """.generate_button {
font-size: 5rem !important
}
"""
def get_presets(self):
with open("local/presets.json", "r") as f:
return orjson.loads(f.read())
def get_switcher_func(self, engine_names: list[str]) -> list[gr.update]:
def switch(selected: str | list[str]):
@@ -92,6 +95,54 @@ class GenerateUI:
variant="primary",
elem_classes="generate_button",
)
with gr.Row() as preset_row:
presets = self.get_presets()
preset_dropdown = gr.Dropdown(
choices=list(presets.keys()),
label="Presets",
allow_custom_value=True,
value=None
)
preset_button = gr.Button("Load")
def load_preset(preset_name, *inputs) -> list[gr.update]:
with open("local/presets.json", "r") as f:
presets = orjson.loads(f.read())
returnable = []
if preset_name in presets.keys():
# If the preset exists
preset = presets[preset_name]
for engine_type, engines in ENGINES.items():
engines = engines["classes"]
values = [[]]
for engine in engines:
if engine.name in preset.get(engine_type, {}).keys():
values[0].append(engine.name)
values.extend(gr.update(value=value) for value in preset[engine_type][engine.name])
else:
values.extend(gr.update() for _ in range(engine.num_options))
returnable.extend(values)
else:
poppable_inputs = list(inputs)
new_preset = {}
for engine_type, engines in ENGINES.items():
engines = engines["classes"]
new_preset[engine_type] = {}
engine_names = poppable_inputs.pop(0)
if isinstance(engine_names, str):
engine_names = [engine_names]
returnable.append(gr.update())
for engine in engines:
if engine.name in engine_names:
new_preset[engine_type][engine.name] = poppable_inputs[:engine.num_options]
poppable_inputs = poppable_inputs[engine.num_options:]
else:
poppable_inputs = poppable_inputs[engine.num_options:]
returnable.extend(gr.update() for _ in range(engine.num_options))
with open("local/presets.json", "wb") as f:
presets[preset_name] = new_preset
f.write(orjson.dumps(presets))
return returnable
preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], outputs=inputs)
output_gallery = gr.Markdown("aaa", render=False)
button.click(
self.run_generate_interface,