diff --git a/ui/gradio_ui.py b/ui/gradio_ui.py index b07ec6a..a9770d5 100644 --- a/ui/gradio_ui.py +++ b/ui/gradio_ui.py @@ -48,9 +48,13 @@ class GenerateUI: if engine.name in preset.get(engine_type, {}).keys(): values[0].append(engine.name) values.extend( - gr.update(value=value) for value in preset[engine_type][engine.name]) + gr.update(value=value) + for value in preset[engine_type][engine.name] + ) else: - values.extend(gr.update() for _ in range(engine.num_options)) + values.extend( + gr.update() for _ in range(engine.num_options) + ) if not engines["multiple"]: if len(values[0]) > 0: values[0] = values[0][0] @@ -62,7 +66,11 @@ class GenerateUI: else: raise gr.Error(f"Preset {preset_name} does not exist.") gr.Info(f"Preset {preset_name} loaded successfully.") - return [gr.Dropdown(choices=list(current_presets.keys()), value=preset_name), *returnable] + return [ + gr.Dropdown(choices=list(current_presets.keys()), value=preset_name), + *returnable, + ] + return load_preset def get_save_preset_func(self): @@ -81,16 +89,22 @@ class GenerateUI: returnable.append(gr.update()) for engine in engines: if engine.name in engine_names: - new_preset[engine_type][engine.name] = poppable_inputs[:engine.num_options] - poppable_inputs = poppable_inputs[engine.num_options:] + new_preset[engine_type][engine.name] = poppable_inputs[ + : engine.num_options + ] + poppable_inputs = poppable_inputs[engine.num_options :] else: - poppable_inputs = poppable_inputs[engine.num_options:] + poppable_inputs = poppable_inputs[engine.num_options :] returnable.extend(gr.update() for _ in range(engine.num_options)) with open("local/presets.json", "wb") as f: current_presets[preset_name] = new_preset f.write(orjson.dumps(current_presets)) gr.Info(f"Preset {preset_name} saved successfully.") - return [gr.Dropdown(choices=list(current_presets.keys()), value=preset_name), *returnable] + return [ + gr.Dropdown(choices=list(current_presets.keys()), value=preset_name), + *returnable, + ] + return save_preset def get_delete_preset_func(self): @@ -103,10 +117,15 @@ class GenerateUI: with open("local/presets.json", "wb") as f: f.write(orjson.dumps(current_presets)) return gr.Dropdown(choices=list(current_presets.keys()), value=None) + return delete_preset + def get_ui(self): ui = gr.TabbedInterface( - *self.get_interfaces(), title="Viral Factory", theme=gr.themes.Soft(), css=self.css + *self.get_interfaces(), + title="Viral Factory", + theme=gr.themes.Soft(), + css=self.css, ) return ui @@ -157,9 +176,11 @@ class GenerateUI: choices=engine_names, value=engine_names[0], multiselect=multiselect, - label="Engine provider:" - if not multiselect - else "Engine providers:", + label=( + "Engine provider:" + if not multiselect + else "Engine providers:" + ), visible=show_dropdown, ) inputs.append(engine_dropdown) @@ -198,12 +219,19 @@ class GenerateUI: load_preset = self.get_load_preset_func() save_preset = self.get_save_preset_func() delete_preset = self.get_delete_preset_func() - load_preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], - outputs=[preset_dropdown, *inputs]) - save_preset_button.click(save_preset, inputs=[preset_dropdown, *inputs], - outputs=[preset_dropdown, *inputs]) - delete_preset_button.click(delete_preset, inputs=preset_dropdown, - outputs=preset_dropdown) + load_preset_button.click( + load_preset, + inputs=[preset_dropdown, *inputs], + outputs=[preset_dropdown, *inputs], + ) + save_preset_button.click( + save_preset, + inputs=[preset_dropdown, *inputs], + outputs=[preset_dropdown, *inputs], + ) + delete_preset_button.click( + delete_preset, inputs=preset_dropdown, outputs=preset_dropdown + ) output_title = gr.Markdown(visible=True, render=False) output_description = gr.Markdown(visible=True, render=False) output_video = gr.Video(visible=True, render=False) @@ -211,7 +239,12 @@ class GenerateUI: button.click( self.run_generate_interface, inputs=inputs, - outputs=[output_video, output_title, output_description, output_path], + outputs=[ + output_video, + output_title, + output_description, + output_path, + ], ) with gr.Row(): with gr.Column(): @@ -222,13 +255,18 @@ class GenerateUI: return interface - def run_generate_interface(self, progress=gr.Progress(track_tqdm=True), *args) -> list[gr.update]: + def run_generate_interface(self, progress=gr.Progress(), *args) -> list[gr.update]: progress(0, desc="Loading engines... 🚀") options = self.repack_options(*args) arguments = {name.lower(): options[name] for name in ENGINES.keys()} ctx = GenerationContext(**arguments, progress=progress) ctx.process() # Here we go ! 🚀 - return [gr.update(value=ctx.get_file_path("final.mp4"), visible=True), gr.update(value=ctx.title, visible=True), gr.update(value=ctx.description, visible=True), gr.update(value=ctx.dir)] + return [ + gr.update(value=ctx.get_file_path("final.mp4"), visible=True), + gr.update(value=ctx.title, visible=True), + gr.update(value=ctx.description, visible=True), + gr.update(value=ctx.dir), + ] def repack_options(self, *args) -> dict[str, list[BaseEngine]]: """ @@ -256,10 +294,10 @@ class GenerateUI: options[engine_type].append( engine(options=args[: engine.num_options]) ) - args = args[engine.num_options:] + args = args[engine.num_options :] else: # we don't care about this, it's not the selected engine, we throw it away - args = args[engine.num_options:] + args = args[engine.num_options :] return options