diff --git a/ui/gradio_ui.py b/ui/gradio_ui.py index 254ba5a..bd21ea1 100644 --- a/ui/gradio_ui.py +++ b/ui/gradio_ui.py @@ -30,7 +30,7 @@ class GenerateUI: return switch - def get_preset_func(self): + def get_load_preset_func(self): def load_preset(preset_name, *selected_inputs) -> list[gr.update]: with open("local/presets.json", "r") as f: current_presets = orjson.loads(f.read()) @@ -57,28 +57,46 @@ class GenerateUI: ... returnable.extend(values) else: - poppable_inputs = list(selected_inputs) - new_preset = {} - for engine_type, engines in ENGINES.items(): - engines = engines["classes"] - new_preset[engine_type] = {} - engine_names = poppable_inputs.pop(0) - if isinstance(engine_names, str): - engine_names = [engine_names] - returnable.append(gr.update()) - for engine in engines: - if engine.name in engine_names: - new_preset[engine_type][engine.name] = poppable_inputs[:engine.num_options] - poppable_inputs = poppable_inputs[engine.num_options:] - else: - poppable_inputs = poppable_inputs[engine.num_options:] - returnable.extend(gr.update() for _ in range(engine.num_options)) - with open("local/presets.json", "wb") as f: - current_presets[preset_name] = new_preset - f.write(orjson.dumps(current_presets)) + raise ValueError("Preset not found") return [gr.update(choices=list(current_presets.keys()), value=preset_name), *returnable] return load_preset + def get_save_preset_func(self): + def save_preset(preset_name, *selected_inputs) -> list[gr.update]: + with open("local/presets.json", "r") as f: + current_presets = orjson.loads(f.read()) + returnable = [] + poppable_inputs = list(selected_inputs) + new_preset = {} + for engine_type, engines in ENGINES.items(): + engines = engines["classes"] + new_preset[engine_type] = {} + engine_names = poppable_inputs.pop(0) + if isinstance(engine_names, str): + engine_names = [engine_names] + returnable.append(gr.update()) + for engine in engines: + if engine.name in engine_names: + new_preset[engine_type][engine.name] = poppable_inputs[:engine.num_options] + poppable_inputs = poppable_inputs[engine.num_options:] + else: + poppable_inputs = poppable_inputs[engine.num_options:] + returnable.extend(gr.update() for _ in range(engine.num_options)) + with open("local/presets.json", "wb") as f: + current_presets[preset_name] = new_preset + f.write(orjson.dumps(current_presets)) + return [gr.update(choices=list(current_presets.keys()), value=preset_name), *returnable] + return save_preset + + def get_delete_preset_func(self): + def delete_preset(preset_name) -> list[gr.update]: + with open("local/presets.json", "r") as f: + current_presets = orjson.loads(f.read()) + current_presets.pop(preset_name) + with open("local/presets.json", "wb") as f: + f.write(orjson.dumps(current_presets)) + return [gr.update(choices=list(current_presets.keys()), value=None)] + return delete_preset def get_ui(self): ui = gr.TabbedInterface( *self.get_interfaces(), title="Viral Factory", theme=gr.themes.Soft(), css=self.css @@ -157,36 +175,41 @@ class GenerateUI: variant="primary", elem_classes="generate_button", ) - with gr.Row() as preset_row: - presets = self.get_presets() - preset_dropdown = gr.Dropdown( - choices=list(presets.keys()), - label="Presets", - allow_custom_value=True, - value=None - ) - preset_button = gr.Button("Load/Save", size="sm") - gr.Markdown("Input a name to save a new preset, or select an existing one to load it.") + gr.Markdown(value="## Presets") + presets = self.get_presets() + preset_dropdown = gr.Dropdown( + choices=list(presets.keys()), + show_label=False, + label="dd", + allow_custom_value=True, + value=None, + ) + load_preset_button = gr.Button("📂", size="sm", variant="primary") + save_preset_button = gr.Button("💾", size="sm", variant="secondary") + delete_preset_button = gr.Button("🗑️", size="sm", variant="stop") - load_preset = self.get_preset_func() - preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], + load_preset = self.get_load_preset_func() + save_preset = self.get_save_preset_func() + delete_preset = self.get_delete_preset_func() + load_preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], outputs=[preset_dropdown, *inputs]) + save_preset_button.click(save_preset, inputs=[preset_dropdown, *inputs], + outputs=[preset_dropdown, *inputs]) + delete_preset_button.click(delete_preset, inputs=[preset_dropdown], + outputs=[preset_dropdown]) output_title = gr.Markdown(visible=True, render=False) output_description = gr.Markdown(visible=True, render=False) output_video = gr.Video(visible=True, render=False) - open_folder = gr.Button("📁", size="sm", variant="secondary", render=False) output_path = gr.State(value=None) button.click( self.run_generate_interface, inputs=inputs, - outputs=[output_video, output_title, output_description, output_path, open_folder], + outputs=[output_video, output_title, output_description, output_path], ) with gr.Row(): with gr.Column(): output_title.render() output_description.render() - open_folder.render() - open_folder.click(lambda x: os.system(f"open {os.path.abspath(x)}") if os.name == "posix" else os.system(f"explorer {os.path.abspath(x)}"), inputs=output_path) with gr.Column(): output_video.render() @@ -198,7 +221,7 @@ class GenerateUI: arguments = {name.lower(): options[name] for name in ENGINES.keys()} ctx = GenerationContext(**arguments, progress=progress) ctx.process() # Here we go ! 🚀 - return [gr.update(value=ctx.get_file_path("final.mp4"), visible=True), gr.update(value=ctx.title, visible=True), gr.update(value=ctx.description, visible=True), gr.update(value=ctx.dir), gr.update(visible=True)] + return [gr.update(value=ctx.get_file_path("final.mp4"), visible=True), gr.update(value=ctx.title, visible=True), gr.update(value=ctx.description, visible=True), gr.update(value=ctx.dir)] def repack_options(self, *args) -> dict[str, list[BaseEngine]]: """