mirror of
https://github.com/Paillat-dev/viralfactory.git
synced 2026-01-02 09:16:19 +00:00
Add BaseEngine and GenerationContext classes to gradio_ui.py
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from src.engines import ENGINES
|
|
||||||
|
from src.engines import ENGINES, BaseEngine
|
||||||
|
from src.chore import GenerationContext
|
||||||
|
|
||||||
|
|
||||||
class GenerateUI:
|
class GenerateUI:
|
||||||
@@ -22,10 +24,7 @@ class GenerateUI:
|
|||||||
|
|
||||||
def launch_ui(self):
|
def launch_ui(self):
|
||||||
ui = gr.TabbedInterface(
|
ui = gr.TabbedInterface(
|
||||||
*self.get_interfaces(),
|
*self.get_interfaces(), "Viral Automator", "NoCrypt/miku", css=self.css
|
||||||
"Viral Automator",
|
|
||||||
"NoCrypt/miku",
|
|
||||||
css=self.css
|
|
||||||
)
|
)
|
||||||
ui.launch()
|
ui.launch()
|
||||||
|
|
||||||
@@ -52,7 +51,9 @@ class GenerateUI:
|
|||||||
inputs.append(engine_dropdown)
|
inputs.append(engine_dropdown)
|
||||||
engine_rows = []
|
engine_rows = []
|
||||||
for i, engine in enumerate(engines):
|
for i, engine in enumerate(engines):
|
||||||
with gr.Row(equal_height=True, visible=(i == 0)) as engine_row:
|
with gr.Row(
|
||||||
|
equal_height=True, visible=(i == 0)
|
||||||
|
) as engine_row:
|
||||||
engine_rows.append(engine_row)
|
engine_rows.append(engine_row)
|
||||||
options = engine.get_options()
|
options = engine.get_options()
|
||||||
inputs.extend(options)
|
inputs.extend(options)
|
||||||
@@ -62,11 +63,23 @@ class GenerateUI:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with gr.Blocks() as col2:
|
with gr.Blocks() as col2:
|
||||||
button = gr.Button("🚀", size="lg", variant="primary", elem_classes="generate_button")
|
button = gr.Button(
|
||||||
button.click(self.repack_options, inputs=inputs)
|
"🚀",
|
||||||
|
scale=0.33,
|
||||||
|
size="lg",
|
||||||
|
variant="primary",
|
||||||
|
elem_classes="generate_button",
|
||||||
|
)
|
||||||
|
button.click(self.run_generate_interface, inputs=inputs)
|
||||||
return interface
|
return interface
|
||||||
|
|
||||||
def repack_options(self, *args):
|
def run_generate_interface(self, *args) -> None:
|
||||||
|
options = self.repack_options(*args)
|
||||||
|
arugments = {name.lower(): options[name] for name in ENGINES.keys()}
|
||||||
|
ctx = GenerationContext(**arugments)
|
||||||
|
ctx.process() # Here we go ! 🚀
|
||||||
|
|
||||||
|
def repack_options(self, *args) -> dict[BaseEngine]:
|
||||||
"""
|
"""
|
||||||
Repacks the options provided as arguments into a dictionary based on the selected engine.
|
Repacks the options provided as arguments into a dictionary based on the selected engine.
|
||||||
|
|
||||||
@@ -87,7 +100,7 @@ class GenerateUI:
|
|||||||
else:
|
else:
|
||||||
# we don't care about this, it's not the selected engine, we throw it away
|
# we don't care about this, it's not the selected engine, we throw it away
|
||||||
args = args[engine.num_options :]
|
args = args[engine.num_options :]
|
||||||
print(options)
|
return options
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user