diff --git a/ui/gradio_ui.py b/ui/gradio_ui.py index 342d5e6..f92aa58 100644 --- a/ui/gradio_ui.py +++ b/ui/gradio_ui.py @@ -28,6 +28,55 @@ class GenerateUI: return switch + def get_preset_func(self): + def load_preset(preset_name, *selected_inputs) -> list[gr.update]: + with open("local/presets.json", "r") as f: + current_presets = orjson.loads(f.read()) + returnable = [] + if preset_name in current_presets.keys(): + # If the preset exists + preset = current_presets[preset_name] + for engine_type, engines in ENGINES.items(): + engines_classes = engines["classes"] + values = [[]] + for engine in engines_classes: + if engine.name in preset.get(engine_type, {}).keys(): + values[0].append(engine.name) + values.extend( + gr.update(value=value) for value in preset[engine_type][engine.name]) + else: + values.extend(gr.update() for _ in range(engine.num_options)) + if not engines["multiple"]: + if len(values[0]) > 0: + values[0] = values[0][0] + else: + values[0] = None + else: + ... + returnable.extend(values) + else: + poppable_inputs = list(selected_inputs) + new_preset = {} + for engine_type, engines in ENGINES.items(): + engines = engines["classes"] + new_preset[engine_type] = {} + engine_names = poppable_inputs.pop(0) + if isinstance(engine_names, str): + engine_names = [engine_names] + returnable.append(gr.update()) + for engine in engines: + if engine.name in engine_names: + new_preset[engine_type][engine.name] = poppable_inputs[:engine.num_options] + poppable_inputs = poppable_inputs[engine.num_options:] + else: + poppable_inputs = poppable_inputs[engine.num_options:] + returnable.extend(gr.update() for _ in range(engine.num_options)) + with open("local/presets.json", "wb") as f: + current_presets[preset_name] = new_preset + f.write(orjson.dumps(current_presets)) + return [gr.update(choices=list(current_presets.keys()), value=preset_name), *returnable] + return load_preset + def get_ui(self): ui = gr.TabbedInterface( *self.get_interfaces(), title="Viral Factory", theme=gr.themes.Soft(), css=self.css @@ -114,47 +163,10 @@ class GenerateUI: allow_custom_value=True, value=None ) - preset_button = gr.Button("Load") + preset_button = gr.Button("Load/Save", size="sm") + gr.Markdown("Input a name to save a new preset, or select an existing one to load it.") - def load_preset(preset_name, *inputs) -> list[gr.update]: - with open("local/presets.json", "r") as f: - presets = orjson.loads(f.read()) - returnable = [] - if preset_name in presets.keys(): - # If the preset exists - preset = presets[preset_name] - for engine_type, engines in ENGINES.items(): - engines = engines["classes"] - values = [[]] - for engine in engines: - if engine.name in preset.get(engine_type, {}).keys(): - values[0].append(engine.name) - values.extend( - gr.update(value=value) for value in preset[engine_type][engine.name]) - else: - values.extend(gr.update() for _ in range(engine.num_options)) - returnable.extend(values) - else: - poppable_inputs = list(inputs) - new_preset = {} - for engine_type, engines in ENGINES.items(): - engines = engines["classes"] - new_preset[engine_type] = {} - engine_names = poppable_inputs.pop(0) - if isinstance(engine_names, str): - engine_names = [engine_names] - returnable.append(gr.update()) - for engine in engines: - if engine.name in engine_names: - new_preset[engine_type][engine.name] = poppable_inputs[:engine.num_options] - poppable_inputs = poppable_inputs[engine.num_options:] - else: - poppable_inputs = poppable_inputs[engine.num_options:] - returnable.extend(gr.update() for _ in range(engine.num_options)) - with open("local/presets.json", "wb") as f: - presets[preset_name] = new_preset - f.write(orjson.dumps(presets)) - return [gr.update(value=presets.keys()), *returnable] + load_preset = self.get_preset_func() preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], outputs=[preset_dropdown, *inputs]) output_gallery = gr.Markdown("aaa", render=False) @@ -169,8 +181,8 @@ class GenerateUI: def run_generate_interface(self, progress=gr.Progress(), *args) -> gr.update: progress(0, desc="Loading engines... 🚀") options = self.repack_options(*args) - arugments = {name.lower(): options[name] for name in ENGINES.keys()} - ctx = GenerationContext(**arugments, progress=progress) + arguments = {name.lower(): options[name] for name in ENGINES.keys()} + ctx = GenerationContext(**arguments, progress=progress) ctx.process() # Here we go ! 🚀 return gr.update(value=ctx.get_file_path("final.mp4"))