Add BaseEngine and GenerationContext classes to gradio_ui.py

This commit is contained in:
2024-02-15 17:49:06 +01:00
parent a8da179269
commit 3338077c18

View File

@@ -1,6 +1,8 @@
import os import os
import gradio as gr import gradio as gr
from src.engines import ENGINES
from src.engines import ENGINES, BaseEngine
from src.chore import GenerationContext
class GenerateUI: class GenerateUI:
@@ -22,10 +24,7 @@ class GenerateUI:
def launch_ui(self): def launch_ui(self):
ui = gr.TabbedInterface( ui = gr.TabbedInterface(
*self.get_interfaces(), *self.get_interfaces(), "Viral Automator", "NoCrypt/miku", css=self.css
"Viral Automator",
"NoCrypt/miku",
css=self.css
) )
ui.launch() ui.launch()
@@ -52,7 +51,9 @@ class GenerateUI:
inputs.append(engine_dropdown) inputs.append(engine_dropdown)
engine_rows = [] engine_rows = []
for i, engine in enumerate(engines): for i, engine in enumerate(engines):
with gr.Row(equal_height=True, visible=(i == 0)) as engine_row: with gr.Row(
equal_height=True, visible=(i == 0)
) as engine_row:
engine_rows.append(engine_row) engine_rows.append(engine_row)
options = engine.get_options() options = engine.get_options()
inputs.extend(options) inputs.extend(options)
@@ -62,11 +63,23 @@ class GenerateUI:
) )
with gr.Blocks() as col2: with gr.Blocks() as col2:
button = gr.Button("🚀", size="lg", variant="primary", elem_classes="generate_button") button = gr.Button(
button.click(self.repack_options, inputs=inputs) "🚀",
scale=0.33,
size="lg",
variant="primary",
elem_classes="generate_button",
)
button.click(self.run_generate_interface, inputs=inputs)
return interface return interface
def repack_options(self, *args): def run_generate_interface(self, *args) -> None:
options = self.repack_options(*args)
arugments = {name.lower(): options[name] for name in ENGINES.keys()}
ctx = GenerationContext(**arugments)
ctx.process() # Here we go ! 🚀
def repack_options(self, *args) -> dict[BaseEngine]:
""" """
Repacks the options provided as arguments into a dictionary based on the selected engine. Repacks the options provided as arguments into a dictionary based on the selected engine.
@@ -87,7 +100,7 @@ class GenerateUI:
else: else:
# we don't care about this, it's not the selected engine, we throw it away # we don't care about this, it's not the selected engine, we throw it away
args = args[engine.num_options :] args = args[engine.num_options :]
print(options) return options
if __name__ == "__main__": if __name__ == "__main__":