mirror of
https://github.com/Paillat-dev/viralfactory.git
synced 2026-01-02 09:16:19 +00:00
Add presets functionality to GenerateUI class
This commit is contained in:
@@ -1,16 +1,19 @@
|
|||||||
import os
|
import os
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
import orjson
|
||||||
|
|
||||||
from src.engines import ENGINES, BaseEngine
|
from src.engines import ENGINES, BaseEngine
|
||||||
from src.chore import GenerationContext
|
from src.chore import GenerationContext
|
||||||
|
|
||||||
|
|
||||||
class GenerateUI:
|
class GenerateUI:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.css = """.generate_button {
|
self.css = """.generate_button {
|
||||||
font-size: 5rem !important
|
font-size: 5rem !important
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
def get_presets(self):
|
||||||
|
with open("local/presets.json", "r") as f:
|
||||||
|
return orjson.loads(f.read())
|
||||||
|
|
||||||
def get_switcher_func(self, engine_names: list[str]) -> list[gr.update]:
|
def get_switcher_func(self, engine_names: list[str]) -> list[gr.update]:
|
||||||
def switch(selected: str | list[str]):
|
def switch(selected: str | list[str]):
|
||||||
@@ -92,6 +95,54 @@ class GenerateUI:
|
|||||||
variant="primary",
|
variant="primary",
|
||||||
elem_classes="generate_button",
|
elem_classes="generate_button",
|
||||||
)
|
)
|
||||||
|
with gr.Row() as preset_row:
|
||||||
|
presets = self.get_presets()
|
||||||
|
preset_dropdown = gr.Dropdown(
|
||||||
|
choices=list(presets.keys()),
|
||||||
|
label="Presets",
|
||||||
|
allow_custom_value=True,
|
||||||
|
value=None
|
||||||
|
)
|
||||||
|
preset_button = gr.Button("Load")
|
||||||
|
def load_preset(preset_name, *inputs) -> list[gr.update]:
|
||||||
|
with open("local/presets.json", "r") as f:
|
||||||
|
presets = orjson.loads(f.read())
|
||||||
|
returnable = []
|
||||||
|
if preset_name in presets.keys():
|
||||||
|
# If the preset exists
|
||||||
|
preset = presets[preset_name]
|
||||||
|
for engine_type, engines in ENGINES.items():
|
||||||
|
engines = engines["classes"]
|
||||||
|
values = [[]]
|
||||||
|
for engine in engines:
|
||||||
|
if engine.name in preset.get(engine_type, {}).keys():
|
||||||
|
values[0].append(engine.name)
|
||||||
|
values.extend(gr.update(value=value) for value in preset[engine_type][engine.name])
|
||||||
|
else:
|
||||||
|
values.extend(gr.update() for _ in range(engine.num_options))
|
||||||
|
returnable.extend(values)
|
||||||
|
else:
|
||||||
|
poppable_inputs = list(inputs)
|
||||||
|
new_preset = {}
|
||||||
|
for engine_type, engines in ENGINES.items():
|
||||||
|
engines = engines["classes"]
|
||||||
|
new_preset[engine_type] = {}
|
||||||
|
engine_names = poppable_inputs.pop(0)
|
||||||
|
if isinstance(engine_names, str):
|
||||||
|
engine_names = [engine_names]
|
||||||
|
returnable.append(gr.update())
|
||||||
|
for engine in engines:
|
||||||
|
if engine.name in engine_names:
|
||||||
|
new_preset[engine_type][engine.name] = poppable_inputs[:engine.num_options]
|
||||||
|
poppable_inputs = poppable_inputs[engine.num_options:]
|
||||||
|
else:
|
||||||
|
poppable_inputs = poppable_inputs[engine.num_options:]
|
||||||
|
returnable.extend(gr.update() for _ in range(engine.num_options))
|
||||||
|
with open("local/presets.json", "wb") as f:
|
||||||
|
presets[preset_name] = new_preset
|
||||||
|
f.write(orjson.dumps(presets))
|
||||||
|
return returnable
|
||||||
|
preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], outputs=inputs)
|
||||||
output_gallery = gr.Markdown("aaa", render=False)
|
output_gallery = gr.Markdown("aaa", render=False)
|
||||||
button.click(
|
button.click(
|
||||||
self.run_generate_interface,
|
self.run_generate_interface,
|
||||||
|
|||||||
Reference in New Issue
Block a user