Files
viralfactory/ui/gradio_ui.py
Paillat 57bcf0af8e fix(GenerationContext.py): fix typo in variable name powerfulllmengine to powerfulllmengine for better readability
feat(GenerationContext.py): add setup_dir method to create a directory for output files with a timestamp
feat(GenerationContext.py): call setup_dir method before generating script and synthesizing audio to ensure output directory exists
feat(prompts/fix_captions.yaml): add a new prompt file to provide instructions for fixing captions
fix(BaseTTSEngine.py): add force_duration method to adjust audio clip duration if it exceeds a specified duration
feat(CoquiTTSEngine.py): add options for forcing duration and specifying duration in the UI
feat(utils/prompting.py): add get_prompt function to load prompt files from a specified location
fix(gradio_ui.py): set equal_height=True for engine_rows to ensure consistent height for engine options
2024-02-15 12:27:13 +01:00

96 lines
3.7 KiB
Python

import os
from collections.abc import Callable

import gradio as gr

from src.engines import ENGINES
class GenerateUI:
    """Gradio front-end listing every engine in ``ENGINES`` with a generate button.

    For each engine type a tab is created holding a dropdown of engine names
    and one row of option components per engine; only the row of the selected
    engine is visible. Pressing the button forwards every component value to
    :meth:`repack_options`.
    """

    def __init__(self):
        # Custom CSS passed to the tabbed interface; it targets the
        # ``generate_button`` class assigned to the button below.
        self.css = """.generate_button {
font-size: 5rem !important
}
"""

    def get_switcher_func(self, engine_names: list[str]) -> Callable[[str], list[dict]]:
        """Build the dropdown callback that reveals only the selected engine's row.

        Args:
            engine_names: Engine names in the same order as the rows that the
                returned callback's updates will be applied to.

        Returns:
            A function taking the selected engine name and returning one
            ``gr.update`` per entry of ``engine_names``; the update is visible
            only where the name equals the selection.
        """

        def switch(selected: str) -> list[dict]:
            return [gr.update(visible=name == selected) for name in engine_names]

        return switch

    def launch_ui(self):
        """Assemble the tabbed interface from :meth:`get_interfaces` and launch it."""
        interfaces, tab_names = self.get_interfaces()
        ui = gr.TabbedInterface(
            interfaces,
            tab_names,
            title="Viral Automator",
            theme="NoCrypt/miku",
            css=self.css,
        )
        ui.launch()

    def get_interfaces(self) -> tuple[list[gr.Blocks], list[str]]:
        """Return the interfaces to show as tabs together with their tab names.

        Returns:
            A ``(interfaces, names)`` tuple; currently a single "Generate" tab.
        """
        return ([self.get_generate_interface()], ["Generate"])

    def get_generate_interface(self) -> gr.Blocks:
        """Build the "Generate" interface.

        The order in which components are appended to ``inputs`` is the order
        in which :meth:`repack_options` consumes their values: for every
        engine type, first the engine dropdown, then the option components of
        each engine in turn.

        Returns:
            The assembled ``gr.Blocks`` interface.
        """
        with gr.Blocks() as interface:
            with gr.Row(equal_height=False):
                inputs = []
                with gr.Blocks():
                    for engine_type, engines in ENGINES.items():
                        with gr.Tab(engine_type):
                            engine_names = [engine.name for engine in engines]
                            # An engine type without engines still gets its
                            # dropdown (so the input order stays intact) but
                            # has no default selection.
                            default_engine = engine_names[0] if engine_names else None
                            engine_dropdown = gr.Dropdown(
                                choices=engine_names, value=default_engine
                            )
                            inputs.append(engine_dropdown)
                            engine_rows = []
                            for i, engine in enumerate(engines):
                                # Only the first engine's row starts visible,
                                # matching the dropdown's default value.
                                with gr.Row(equal_height=True, visible=(i == 0)) as engine_row:
                                    engine_rows.append(engine_row)
                                    # NOTE(review): presumably get_options()
                                    # creates the components inside this row
                                    # and returns num_options of them — confirm
                                    # against the engine base class.
                                    inputs.extend(engine.get_options())
                            switcher = self.get_switcher_func(engine_names)
                            engine_dropdown.change(
                                switcher, inputs=engine_dropdown, outputs=engine_rows
                            )
                with gr.Blocks():
                    button = gr.Button(
                        "🚀",
                        size="lg",
                        variant="primary",
                        elem_classes="generate_button",
                    )
                    button.click(self.repack_options, inputs=inputs)
        return interface

    def repack_options(self, *args):
        """Group the flat component values by engine type and instantiate the selected engines.

        Args:
            *args: Component values in the order built by
                :meth:`get_generate_interface`: per engine type, the selected
                engine name followed by ``num_options`` values for every
                engine of that type.

        Returns:
            dict: Maps each engine type to an instance of its selected engine,
            constructed with that engine's slice of the values. Engine types
            whose selection matches no engine are absent. The dict is also
            printed.
        """
        options = {}
        remaining = list(args)
        for engine_type, engines in ENGINES.items():
            # Each group starts with the value of the engine dropdown.
            engine_name = remaining.pop(0)
            for engine in engines:
                # Every engine owns a slice of the values whether or not it is
                # selected; the slice of a non-selected engine is discarded.
                engine_args = remaining[: engine.num_options]
                remaining = remaining[engine.num_options :]
                if engine.name == engine_name:
                    options[engine_type] = engine(options=engine_args)
        print(options)
        return options
if __name__ == "__main__":
    # Run as a script: build the UI object and start serving it.
    GenerateUI().launch_ui()