From f6e4fa6bd3c00099299d84d426533bfcfa6cd387 Mon Sep 17 00:00:00 2001 From: Paillat Date: Tue, 20 Feb 2024 14:54:25 +0100 Subject: [PATCH] Refactor UI code to add settings interface --- ui/gradio_ui.py | 45 +++++++++++++++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/ui/gradio_ui.py b/ui/gradio_ui.py index 34c81f6..aad4da6 100644 --- a/ui/gradio_ui.py +++ b/ui/gradio_ui.py @@ -4,7 +4,6 @@ import gradio as gr from src.engines import ENGINES, BaseEngine from src.chore import GenerationContext - class GenerateUI: def __init__(self): self.css = """.generate_button { @@ -37,13 +36,25 @@ class GenerateUI: Returns: tuple[list[gr.Blocks], list[str]]: A tuple containing a list of gr.Blocks interfaces and a list of interface names. """ - return ([self.get_generate_interface()], ["Generate"]) + return ( + [self.get_generate_interface(), self.get_settings_interface()], + ["Generate", "Settings"], + ) + + def get_settings_interface(self) -> gr.Blocks: + with gr.Blocks() as interface: + for engine_type, engines in ENGINES.items(): + engines = engines["classes"] + with gr.Tab(engine_type) as engine_tab: + for engine in engines: + engine.get_settings() + return interface def get_generate_interface(self) -> gr.Blocks: with gr.Blocks() as interface: with gr.Row(equal_height=False) as row: inputs = [] - with gr.Blocks() as col1: + with gr.Column(scale=2) as col1: for engine_type, engines in ENGINES.items(): multiselect = engines["multiple"] show_dropdown = engines.get("show_dropdown", True) @@ -54,16 +65,20 @@ class GenerateUI: choices=engine_names, value=engine_names[0], multiselect=multiselect, - label="Engine provider:" if not multiselect else "Engine providers:", + label="Engine provider:" + if not multiselect + else "Engine providers:", visible=show_dropdown, ) inputs.append(engine_dropdown) engine_rows = [] for i, engine in enumerate(engines): - with gr.Group( - visible=(i == 0) - ) as engine_row: - gr.Markdown(value = " ", label=f"{engine.name}", show_label=True) + with gr.Group(visible=(i == 0)) as engine_row: + gr.Markdown( + value=" ", + label=f"{engine.name}", + show_label=True, + ) engine_rows.append(engine_row) options = engine.get_options() inputs.extend(options) @@ -72,23 +87,25 @@ class GenerateUI: switcher, inputs=engine_dropdown, outputs=engine_rows ) - with gr.Blocks() as col2: + with gr.Column() as col2: button = gr.Button( "🚀", - scale=0.33, size="lg", variant="primary", elem_classes="generate_button", ) - button.click(self.run_generate_interface, inputs=inputs) + output_gallery = gr.Markdown("aaa", render=False) + button.click(self.run_generate_interface, inputs=inputs, outputs=output_gallery) + output_gallery.render() return interface - def run_generate_interface(self, *args) -> None: + def run_generate_interface(self, progress=gr.Progress(), *args) -> gr.update: + progress(0, desc="Loading engines... 🚀") options = self.repack_options(*args) arugments = {name.lower(): options[name] for name in ENGINES.keys()} - ctx = GenerationContext(**arugments) + ctx = GenerationContext(**arugments, progress=progress) ctx.process() # Here we go ! 🚀 - + return gr.update(value=ctx.get_file_path("final.mp4")) def repack_options(self, *args) -> dict[BaseEngine]: """ Repacks the options provided as arguments into a dictionary based on the selected engine.