Fix presets logic

This commit is contained in:
2024-02-23 12:27:04 +01:00
parent d7fa8945d1
commit ca520296d4

View File

@@ -28,6 +28,55 @@ class GenerateUI:
return switch
def get_preset_func(self):
def load_preset(preset_name, *selected_inputs) -> list[gr.update]:
with open("local/presets.json", "r") as f:
current_presets = orjson.loads(f.read())
returnable = []
if preset_name in current_presets.keys():
# If the preset exists
preset = current_presets[preset_name]
for engine_type, engines in ENGINES.items():
engines_classes = engines["classes"]
values = [[]]
for engine in engines_classes:
if engine.name in preset.get(engine_type, {}).keys():
values[0].append(engine.name)
values.extend(
gr.update(value=value) for value in preset[engine_type][engine.name])
else:
values.extend(gr.update() for _ in range(engine.num_options))
if not engines["multiple"]:
if len(values[0]) > 0:
values[0] = values[0][0]
else:
values[0] = None
else:
...
returnable.extend(values)
else:
poppable_inputs = list(selected_inputs)
new_preset = {}
for engine_type, engines in ENGINES.items():
engines = engines["classes"]
new_preset[engine_type] = {}
engine_names = poppable_inputs.pop(0)
if isinstance(engine_names, str):
engine_names = [engine_names]
returnable.append(gr.update())
for engine in engines:
if engine.name in engine_names:
new_preset[engine_type][engine.name] = poppable_inputs[:engine.num_options]
poppable_inputs = poppable_inputs[engine.num_options:]
else:
poppable_inputs = poppable_inputs[engine.num_options:]
returnable.extend(gr.update() for _ in range(engine.num_options))
with open("local/presets.json", "wb") as f:
current_presets[preset_name] = new_preset
f.write(orjson.dumps(current_presets))
return [gr.update(choices=list(current_presets.keys()), value=preset_name), *returnable]
return load_preset
def get_ui(self):
ui = gr.TabbedInterface(
*self.get_interfaces(), title="Viral Factory", theme=gr.themes.Soft(), css=self.css
@@ -114,47 +163,10 @@ class GenerateUI:
allow_custom_value=True,
value=None
)
preset_button = gr.Button("Load")
preset_button = gr.Button("Load/Save", size="sm")
gr.Markdown("Input a name to save a new preset, or select an existing one to load it.")
def load_preset(preset_name, *inputs) -> list[gr.update]:
with open("local/presets.json", "r") as f:
presets = orjson.loads(f.read())
returnable = []
if preset_name in presets.keys():
# If the preset exists
preset = presets[preset_name]
for engine_type, engines in ENGINES.items():
engines = engines["classes"]
values = [[]]
for engine in engines:
if engine.name in preset.get(engine_type, {}).keys():
values[0].append(engine.name)
values.extend(
gr.update(value=value) for value in preset[engine_type][engine.name])
else:
values.extend(gr.update() for _ in range(engine.num_options))
returnable.extend(values)
else:
poppable_inputs = list(inputs)
new_preset = {}
for engine_type, engines in ENGINES.items():
engines = engines["classes"]
new_preset[engine_type] = {}
engine_names = poppable_inputs.pop(0)
if isinstance(engine_names, str):
engine_names = [engine_names]
returnable.append(gr.update())
for engine in engines:
if engine.name in engine_names:
new_preset[engine_type][engine.name] = poppable_inputs[:engine.num_options]
poppable_inputs = poppable_inputs[engine.num_options:]
else:
poppable_inputs = poppable_inputs[engine.num_options:]
returnable.extend(gr.update() for _ in range(engine.num_options))
with open("local/presets.json", "wb") as f:
presets[preset_name] = new_preset
f.write(orjson.dumps(presets))
return [gr.update(value=presets.keys()), *returnable]
load_preset = self.get_preset_func()
preset_button.click(load_preset, inputs=[preset_dropdown, *inputs],
outputs=[preset_dropdown, *inputs])
output_gallery = gr.Markdown("aaa", render=False)
@@ -169,8 +181,8 @@ class GenerateUI:
def run_generate_interface(self, progress=gr.Progress(), *args) -> gr.update:
progress(0, desc="Loading engines... 🚀")
options = self.repack_options(*args)
arugments = {name.lower(): options[name] for name in ENGINES.keys()}
ctx = GenerationContext(**arugments, progress=progress)
arguments = {name.lower(): options[name] for name in ENGINES.keys()}
ctx = GenerationContext(**arguments, progress=progress)
ctx.process() # Here we go ! 🚀
return gr.update(value=ctx.get_file_path("final.mp4"))