2024-02-14 17:49:51 +01:00
|
|
|
import os
from collections.abc import Callable
|
|
|
|
|
import gradio as gr
|
2024-02-22 18:02:52 +01:00
|
|
|
import orjson
|
2024-02-15 17:49:06 +01:00
|
|
|
|
|
|
|
|
from src.engines import ENGINES, BaseEngine
|
|
|
|
|
from src.chore import GenerationContext
|
2024-02-14 17:49:51 +01:00
|
|
|
|
|
|
|
|
class GenerateUI:
    """Gradio front-end for the generation pipeline.

    Builds a tabbed app with a "Generate" tab (engine pickers, per-engine
    options, presets and the launch button) and a "Settings" tab (per-engine
    settings). The engine catalogue is read from ``ENGINES``: each value is a
    dict with a ``"classes"`` list of engine classes, a ``"multiple"`` flag
    (multi-select dropdown) and an optional ``"show_dropdown"`` flag.

    The "Generate" inputs are registered in a fixed flat order that
    ``repack_options`` and the preset helpers rely on: for every engine type
    (in ``ENGINES`` order), the engine dropdown followed by the
    ``num_options`` option components of each engine of that type.
    """

    # Where presets are stored, relative to the current working directory.
    PRESETS_PATH = "local/presets.json"

    def __init__(self):
        # Custom CSS: enlarge the generate ("🚀") button.
        self.css = """.generate_button {
    font-size: 5rem !important
}
"""

    def get_presets(self) -> dict:
        """Return the saved presets, keyed by preset name.

        Returns an empty dict when the presets file does not exist yet, so the
        UI can still be built on a fresh checkout.
        """
        try:
            # Binary mode: orjson writes UTF-8, so don't decode with the locale.
            with open(self.PRESETS_PATH, "rb") as f:
                return orjson.loads(f.read())
        except FileNotFoundError:
            return {}

    def _save_presets(self, presets: dict) -> None:
        """Write *presets* to ``PRESETS_PATH``, creating its directory if needed."""
        # Serialize BEFORE opening the file: opening with "wb" truncates it, so a
        # serialization error afterwards would wipe every saved preset.
        data = orjson.dumps(presets)
        directory = os.path.dirname(self.PRESETS_PATH)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(self.PRESETS_PATH, "wb") as f:
            f.write(data)

    def get_switcher_func(
        self, engine_names: list[str]
    ) -> Callable[[str | list[str] | None], list]:
        """Build a dropdown-change handler that toggles per-engine columns.

        Args:
            engine_names: Engine names, in the same order as the columns the
                handler's output is wired to.

        Returns:
            A function taking the dropdown value (a single name, a list of
            names, or None) and returning one ``gr.update(visible=...)`` per
            engine, visible only for the selected engines.
        """

        def switch(selected: str | list[str] | None) -> list:
            if selected is None:
                # A cleared dropdown may report None: show nothing.
                selected = []
            elif isinstance(selected, str):
                # Single-select dropdowns report a bare name.
                selected = [selected]
            return [gr.update(visible=name in selected) for name in engine_names]

        return switch

    def launch_ui(self):
        """Build the tabbed app ("Viral Factory", Soft theme) and launch it."""
        ui = gr.TabbedInterface(
            *self.get_interfaces(), "Viral Factory", gr.themes.Soft(), css=self.css
        )
        ui.launch()

    def get_interfaces(self) -> tuple[list[gr.Blocks], list[str]]:
        """
        Returns a tuple containing a list of gr.Blocks interfaces and a list of interface names.

        Returns:
            tuple[list[gr.Blocks], list[str]]: A tuple containing a list of gr.Blocks interfaces and a list of interface names.
        """
        return (
            [self.get_generate_interface(), self.get_settings_interface()],
            ["Generate", "Settings"],
        )

    def get_settings_interface(self) -> gr.Blocks:
        """Build the "Settings" tab: one sub-tab per engine type, with a
        heading and the settings components of each engine in that type."""
        with gr.Blocks() as interface:
            for engine_type, engine_group in ENGINES.items():
                with gr.Tab(engine_type):
                    for engine in engine_group["classes"]:
                        gr.Markdown(f"## {engine.name}")
                        engine.get_settings()
        return interface

    def get_generate_interface(self) -> gr.Blocks:
        """Build the "Generate" tab.

        Left column: one sub-tab per engine type with an engine dropdown and a
        column of options per engine (only the selected engines' columns are
        visible). Right column: the generate button, the preset picker and
        the output area.
        """
        with gr.Blocks() as interface:
            with gr.Row(equal_height=False):
                # Flat list of input components, order documented on the class.
                inputs = []
                with gr.Column(scale=2):
                    for engine_type, engine_group in ENGINES.items():
                        multiselect = engine_group["multiple"]
                        show_dropdown = engine_group.get("show_dropdown", True)
                        engines = engine_group["classes"]
                        with gr.Tab(engine_type):
                            engine_names = [engine.name for engine in engines]
                            engine_dropdown = gr.Dropdown(
                                choices=engine_names,
                                value=engine_names[0],
                                multiselect=multiselect,
                                label="Engine providers:"
                                if multiselect
                                else "Engine provider:",
                                visible=show_dropdown,
                            )
                            inputs.append(engine_dropdown)
                            engine_rows = []
                            for i, engine in enumerate(engines):
                                # Only the first engine (the dropdown's default) starts visible.
                                with gr.Column(visible=(i == 0)) as engine_row:
                                    gr.Markdown(value=f"## {engine.name}")
                                    engine_rows.append(engine_row)
                                    inputs.extend(engine.get_options())
                            engine_dropdown.change(
                                self.get_switcher_func(engine_names),
                                inputs=engine_dropdown,
                                outputs=engine_rows,
                            )

                with gr.Column():
                    button = gr.Button(
                        "🚀",
                        size="lg",
                        variant="primary",
                        elem_classes="generate_button",
                    )
                    with gr.Row():
                        presets = self.get_presets()
                        preset_dropdown = gr.Dropdown(
                            choices=list(presets.keys()),
                            label="Presets",
                            # Typing a new name lets the Load button save a new preset.
                            allow_custom_value=True,
                            value=None,
                        )
                        preset_button = gr.Button("Load")
                    preset_button.click(
                        self._load_or_save_preset,
                        inputs=[preset_dropdown, *inputs],
                        outputs=inputs,
                    )
                    # Created unrendered so the click handler can target it,
                    # then rendered below the presets.
                    output_gallery = gr.Markdown("aaa", render=False)
                    button.click(
                        self.run_generate_interface,
                        inputs=inputs,
                        outputs=output_gallery,
                    )
                    output_gallery.render()
        return interface

    def _load_or_save_preset(self, preset_name, *values) -> list:
        """Handle the preset "Load" button.

        If *preset_name* names a saved preset, it is applied to the inputs.
        Otherwise, the current input *values* are saved under that name.
        An empty/None name does nothing.

        Returns:
            One update per generate input (same order as *values*).
        """
        no_change = [gr.update() for _ in values]
        if not preset_name:
            # Nothing selected or typed: saving under a None key would fail.
            return no_change
        presets = self.get_presets()
        if preset_name in presets:
            return self._apply_preset(presets[preset_name])
        new_preset = {engine_type: {} for engine_type in ENGINES}
        for engine_type, engine, engine_values in self._iter_selected(values):
            new_preset[engine_type][engine.name] = engine_values
        presets[preset_name] = new_preset
        self._save_presets(presets)
        # Saving leaves the inputs untouched, but Gradio still needs one
        # update per output.
        return no_change

    def _apply_preset(self, preset: dict) -> list:
        """Turn a saved *preset* into one update per generate input.

        The preset is read as ``{engine_type: {engine_name: [option values]}}``;
        engines it mentions get selected and have their options set.
        """
        updates = []
        for engine_type, engine_group in ENGINES.items():
            engines = engine_group["classes"]
            saved = preset.get(engine_type, {})
            selected = [engine.name for engine in engines if engine.name in saved]
            if not selected:
                # The preset says nothing about this type: keep the current choice.
                updates.append(gr.update())
            elif engine_group["multiple"]:
                updates.append(gr.update(value=selected))
            else:
                # Single-select dropdowns take a bare name, not a list.
                updates.append(gr.update(value=selected[0]))
            for engine in engines:
                stored = list(saved.get(engine.name, []))[: engine.num_options]
                updates.extend(gr.update(value=value) for value in stored)
                # Pad when the preset holds fewer values than the engine has
                # options (or none at all for unselected engines).
                updates.extend(
                    gr.update() for _ in range(engine.num_options - len(stored))
                )
        return updates

    @staticmethod
    def _iter_selected(values):
        """Yield ``(engine_type, engine, engine_values)`` for each selected engine.

        *values* is the flat list of generate-input values (see the class
        docstring). Values of engines not selected in their type's dropdown
        are consumed and skipped.
        """
        remaining = list(values)
        for engine_type, engine_group in ENGINES.items():
            selected = remaining.pop(0)
            if selected is None:
                selected = []
            elif isinstance(selected, str):
                selected = [selected]
            for engine in engine_group["classes"]:
                engine_values = remaining[: engine.num_options]
                del remaining[: engine.num_options]
                if engine.name in selected:
                    yield engine_type, engine, engine_values

    def run_generate_interface(self, progress=gr.Progress(), *args) -> gr.update:
        """Run the generation pipeline for the current "Generate" inputs.

        Gradio injects ``progress`` because of its ``gr.Progress`` default;
        ``args`` are the generate-input values in registration order.

        Returns:
            An update whose value is the path of the produced ``final.mp4``.
        """
        progress(0, desc="Loading engines... 🚀")
        options = self.repack_options(*args)
        # GenerationContext takes one keyword per engine type, lower-cased.
        engine_args = {name.lower(): options[name] for name in ENGINES}
        ctx = GenerationContext(**engine_args, progress=progress)
        ctx.process()  # Here we go ! 🚀
        return gr.update(value=ctx.get_file_path("final.mp4"))

    def repack_options(self, *args) -> dict[str, list[BaseEngine]]:
        """
        Repacks the options provided as arguments into a dictionary based on the selected engine.

        Args:
            *args: Flat generate-input values: for each engine type, the
                dropdown value followed by every engine's options.

        Returns:
            dict: Maps every engine type to the list of its selected engines,
            each instantiated with its own slice of option values.
        """
        options = {engine_type: [] for engine_type in ENGINES}
        for engine_type, engine, engine_values in self._iter_selected(args):
            options[engine_type].append(engine(options=engine_values))
        return options
|
2024-02-14 17:49:51 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the Gradio app and start serving it.
    GenerateUI().launch_ui()
|