diff --git a/ui/gradio_ui.py b/ui/gradio_ui.py index 9826de3..329cb30 100644 --- a/ui/gradio_ui.py +++ b/ui/gradio_ui.py @@ -24,7 +24,7 @@ class GenerateUI: def launch_ui(self): ui = gr.TabbedInterface( - *self.get_interfaces(), "Viral Factory", "NoCrypt/miku", css=self.css + *self.get_interfaces(), "Viral Factory", gr.themes.Soft(), css=self.css ) ui.launch() @@ -43,10 +43,15 @@ class GenerateUI: inputs = [] with gr.Blocks() as col1: for engine_type, engines in ENGINES.items(): + multiselect = engines["multiple"] + engines = engines["classes"] with gr.Tab(engine_type) as engine_tab: engine_names = [engine.name for engine in engines] engine_dropdown = gr.Dropdown( - choices=engine_names, value=engine_names[0] + choices=engine_names, + value=engine_names[0], + multiselect=multiselect, + label="Engine provider:" ) inputs.append(engine_dropdown) engine_rows = [] @@ -92,10 +97,19 @@ class GenerateUI: options = {} args = list(args) for engine_type, engines in ENGINES.items(): - engine_name = args.pop(0) + engines = engines["classes"] + selected_engines = args.pop(0) + if isinstance(selected_engines, str): + selected_engines = [selected_engines] + options[engine_type] = [] + # for every selected engine for engine in engines: - if engine.name == engine_name: - options[engine_type] = engine(options=args[: engine.num_options]) + # if it correspods to the selected engine + if engine.name in selected_engines: + # we add it to the options + options[engine_type].append( + engine(options=args[: engine.num_options]) + ) args = args[engine.num_options :] else: # we don't care about this, it's not the selected engine, we throw it away