mirror of
https://github.com/Paillat-dev/viralfactory.git
synced 2026-01-02 01:06:19 +00:00
Refactor UI code to add settings interface
This commit is contained in:
@@ -4,7 +4,6 @@ import gradio as gr
|
|||||||
from src.engines import ENGINES, BaseEngine
|
from src.engines import ENGINES, BaseEngine
|
||||||
from src.chore import GenerationContext
|
from src.chore import GenerationContext
|
||||||
|
|
||||||
|
|
||||||
class GenerateUI:
|
class GenerateUI:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.css = """.generate_button {
|
self.css = """.generate_button {
|
||||||
@@ -37,13 +36,25 @@ class GenerateUI:
|
|||||||
Returns:
|
Returns:
|
||||||
tuple[list[gr.Blocks], list[str]]: A tuple containing a list of gr.Blocks interfaces and a list of interface names.
|
tuple[list[gr.Blocks], list[str]]: A tuple containing a list of gr.Blocks interfaces and a list of interface names.
|
||||||
"""
|
"""
|
||||||
return ([self.get_generate_interface()], ["Generate"])
|
return (
|
||||||
|
[self.get_generate_interface(), self.get_settings_interface()],
|
||||||
|
["Generate", "Settings"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_settings_interface(self) -> gr.Blocks:
|
||||||
|
with gr.Blocks() as interface:
|
||||||
|
for engine_type, engines in ENGINES.items():
|
||||||
|
engines = engines["classes"]
|
||||||
|
with gr.Tab(engine_type) as engine_tab:
|
||||||
|
for engine in engines:
|
||||||
|
engine.get_settings()
|
||||||
|
return interface
|
||||||
|
|
||||||
def get_generate_interface(self) -> gr.Blocks:
|
def get_generate_interface(self) -> gr.Blocks:
|
||||||
with gr.Blocks() as interface:
|
with gr.Blocks() as interface:
|
||||||
with gr.Row(equal_height=False) as row:
|
with gr.Row(equal_height=False) as row:
|
||||||
inputs = []
|
inputs = []
|
||||||
with gr.Blocks() as col1:
|
with gr.Column(scale=2) as col1:
|
||||||
for engine_type, engines in ENGINES.items():
|
for engine_type, engines in ENGINES.items():
|
||||||
multiselect = engines["multiple"]
|
multiselect = engines["multiple"]
|
||||||
show_dropdown = engines.get("show_dropdown", True)
|
show_dropdown = engines.get("show_dropdown", True)
|
||||||
@@ -54,16 +65,20 @@ class GenerateUI:
|
|||||||
choices=engine_names,
|
choices=engine_names,
|
||||||
value=engine_names[0],
|
value=engine_names[0],
|
||||||
multiselect=multiselect,
|
multiselect=multiselect,
|
||||||
label="Engine provider:" if not multiselect else "Engine providers:",
|
label="Engine provider:"
|
||||||
|
if not multiselect
|
||||||
|
else "Engine providers:",
|
||||||
visible=show_dropdown,
|
visible=show_dropdown,
|
||||||
)
|
)
|
||||||
inputs.append(engine_dropdown)
|
inputs.append(engine_dropdown)
|
||||||
engine_rows = []
|
engine_rows = []
|
||||||
for i, engine in enumerate(engines):
|
for i, engine in enumerate(engines):
|
||||||
with gr.Group(
|
with gr.Group(visible=(i == 0)) as engine_row:
|
||||||
visible=(i == 0)
|
gr.Markdown(
|
||||||
) as engine_row:
|
value=" ",
|
||||||
gr.Markdown(value = " ", label=f"{engine.name}", show_label=True)
|
label=f"{engine.name}",
|
||||||
|
show_label=True,
|
||||||
|
)
|
||||||
engine_rows.append(engine_row)
|
engine_rows.append(engine_row)
|
||||||
options = engine.get_options()
|
options = engine.get_options()
|
||||||
inputs.extend(options)
|
inputs.extend(options)
|
||||||
@@ -72,23 +87,25 @@ class GenerateUI:
|
|||||||
switcher, inputs=engine_dropdown, outputs=engine_rows
|
switcher, inputs=engine_dropdown, outputs=engine_rows
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Blocks() as col2:
|
with gr.Column() as col2:
|
||||||
button = gr.Button(
|
button = gr.Button(
|
||||||
"🚀",
|
"🚀",
|
||||||
scale=0.33,
|
|
||||||
size="lg",
|
size="lg",
|
||||||
variant="primary",
|
variant="primary",
|
||||||
elem_classes="generate_button",
|
elem_classes="generate_button",
|
||||||
)
|
)
|
||||||
button.click(self.run_generate_interface, inputs=inputs)
|
output_gallery = gr.Markdown("aaa", render=False)
|
||||||
|
button.click(self.run_generate_interface, inputs=inputs, outputs=output_gallery)
|
||||||
|
output_gallery.render()
|
||||||
return interface
|
return interface
|
||||||
|
|
||||||
def run_generate_interface(self, *args) -> None:
|
def run_generate_interface(self, progress=gr.Progress(), *args) -> gr.update:
|
||||||
|
progress(0, desc="Loading engines... 🚀")
|
||||||
options = self.repack_options(*args)
|
options = self.repack_options(*args)
|
||||||
arugments = {name.lower(): options[name] for name in ENGINES.keys()}
|
arugments = {name.lower(): options[name] for name in ENGINES.keys()}
|
||||||
ctx = GenerationContext(**arugments)
|
ctx = GenerationContext(**arugments, progress=progress)
|
||||||
ctx.process() # Here we go ! 🚀
|
ctx.process() # Here we go ! 🚀
|
||||||
|
return gr.update(value=ctx.get_file_path("final.mp4"))
|
||||||
def repack_options(self, *args) -> dict[BaseEngine]:
|
def repack_options(self, *args) -> dict[BaseEngine]:
|
||||||
"""
|
"""
|
||||||
Repacks the options provided as arguments into a dictionary based on the selected engine.
|
Repacks the options provided as arguments into a dictionary based on the selected engine.
|
||||||
|
|||||||
Reference in New Issue
Block a user