🎨 Run linter

This commit is contained in:
2024-04-21 21:36:57 +02:00
parent ae437ce67d
commit 4cb395d279

View File

@@ -48,9 +48,13 @@ class GenerateUI:
if engine.name in preset.get(engine_type, {}).keys(): if engine.name in preset.get(engine_type, {}).keys():
values[0].append(engine.name) values[0].append(engine.name)
values.extend( values.extend(
gr.update(value=value) for value in preset[engine_type][engine.name]) gr.update(value=value)
for value in preset[engine_type][engine.name]
)
else: else:
values.extend(gr.update() for _ in range(engine.num_options)) values.extend(
gr.update() for _ in range(engine.num_options)
)
if not engines["multiple"]: if not engines["multiple"]:
if len(values[0]) > 0: if len(values[0]) > 0:
values[0] = values[0][0] values[0] = values[0][0]
@@ -62,7 +66,11 @@ class GenerateUI:
else: else:
raise gr.Error(f"Preset {preset_name} does not exist.") raise gr.Error(f"Preset {preset_name} does not exist.")
gr.Info(f"Preset {preset_name} loaded successfully.") gr.Info(f"Preset {preset_name} loaded successfully.")
return [gr.Dropdown(choices=list(current_presets.keys()), value=preset_name), *returnable] return [
gr.Dropdown(choices=list(current_presets.keys()), value=preset_name),
*returnable,
]
return load_preset return load_preset
def get_save_preset_func(self): def get_save_preset_func(self):
@@ -81,16 +89,22 @@ class GenerateUI:
returnable.append(gr.update()) returnable.append(gr.update())
for engine in engines: for engine in engines:
if engine.name in engine_names: if engine.name in engine_names:
new_preset[engine_type][engine.name] = poppable_inputs[:engine.num_options] new_preset[engine_type][engine.name] = poppable_inputs[
poppable_inputs = poppable_inputs[engine.num_options:] : engine.num_options
]
poppable_inputs = poppable_inputs[engine.num_options :]
else: else:
poppable_inputs = poppable_inputs[engine.num_options:] poppable_inputs = poppable_inputs[engine.num_options :]
returnable.extend(gr.update() for _ in range(engine.num_options)) returnable.extend(gr.update() for _ in range(engine.num_options))
with open("local/presets.json", "wb") as f: with open("local/presets.json", "wb") as f:
current_presets[preset_name] = new_preset current_presets[preset_name] = new_preset
f.write(orjson.dumps(current_presets)) f.write(orjson.dumps(current_presets))
gr.Info(f"Preset {preset_name} saved successfully.") gr.Info(f"Preset {preset_name} saved successfully.")
return [gr.Dropdown(choices=list(current_presets.keys()), value=preset_name), *returnable] return [
gr.Dropdown(choices=list(current_presets.keys()), value=preset_name),
*returnable,
]
return save_preset return save_preset
def get_delete_preset_func(self): def get_delete_preset_func(self):
@@ -103,10 +117,15 @@ class GenerateUI:
with open("local/presets.json", "wb") as f: with open("local/presets.json", "wb") as f:
f.write(orjson.dumps(current_presets)) f.write(orjson.dumps(current_presets))
return gr.Dropdown(choices=list(current_presets.keys()), value=None) return gr.Dropdown(choices=list(current_presets.keys()), value=None)
return delete_preset return delete_preset
def get_ui(self): def get_ui(self):
ui = gr.TabbedInterface( ui = gr.TabbedInterface(
*self.get_interfaces(), title="Viral Factory", theme=gr.themes.Soft(), css=self.css *self.get_interfaces(),
title="Viral Factory",
theme=gr.themes.Soft(),
css=self.css,
) )
return ui return ui
@@ -157,9 +176,11 @@ class GenerateUI:
choices=engine_names, choices=engine_names,
value=engine_names[0], value=engine_names[0],
multiselect=multiselect, multiselect=multiselect,
label="Engine provider:" label=(
"Engine provider:"
if not multiselect if not multiselect
else "Engine providers:", else "Engine providers:"
),
visible=show_dropdown, visible=show_dropdown,
) )
inputs.append(engine_dropdown) inputs.append(engine_dropdown)
@@ -198,12 +219,19 @@ class GenerateUI:
load_preset = self.get_load_preset_func() load_preset = self.get_load_preset_func()
save_preset = self.get_save_preset_func() save_preset = self.get_save_preset_func()
delete_preset = self.get_delete_preset_func() delete_preset = self.get_delete_preset_func()
load_preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], load_preset_button.click(
outputs=[preset_dropdown, *inputs]) load_preset,
save_preset_button.click(save_preset, inputs=[preset_dropdown, *inputs], inputs=[preset_dropdown, *inputs],
outputs=[preset_dropdown, *inputs]) outputs=[preset_dropdown, *inputs],
delete_preset_button.click(delete_preset, inputs=preset_dropdown, )
outputs=preset_dropdown) save_preset_button.click(
save_preset,
inputs=[preset_dropdown, *inputs],
outputs=[preset_dropdown, *inputs],
)
delete_preset_button.click(
delete_preset, inputs=preset_dropdown, outputs=preset_dropdown
)
output_title = gr.Markdown(visible=True, render=False) output_title = gr.Markdown(visible=True, render=False)
output_description = gr.Markdown(visible=True, render=False) output_description = gr.Markdown(visible=True, render=False)
output_video = gr.Video(visible=True, render=False) output_video = gr.Video(visible=True, render=False)
@@ -211,7 +239,12 @@ class GenerateUI:
button.click( button.click(
self.run_generate_interface, self.run_generate_interface,
inputs=inputs, inputs=inputs,
outputs=[output_video, output_title, output_description, output_path], outputs=[
output_video,
output_title,
output_description,
output_path,
],
) )
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@@ -222,13 +255,18 @@ class GenerateUI:
return interface return interface
def run_generate_interface(self, progress=gr.Progress(track_tqdm=True), *args) -> list[gr.update]: def run_generate_interface(self, progress=gr.Progress(), *args) -> list[gr.update]:
progress(0, desc="Loading engines... 🚀") progress(0, desc="Loading engines... 🚀")
options = self.repack_options(*args) options = self.repack_options(*args)
arguments = {name.lower(): options[name] for name in ENGINES.keys()} arguments = {name.lower(): options[name] for name in ENGINES.keys()}
ctx = GenerationContext(**arguments, progress=progress) ctx = GenerationContext(**arguments, progress=progress)
ctx.process() # Here we go ! 🚀 ctx.process() # Here we go ! 🚀
return [gr.update(value=ctx.get_file_path("final.mp4"), visible=True), gr.update(value=ctx.title, visible=True), gr.update(value=ctx.description, visible=True), gr.update(value=ctx.dir)] return [
gr.update(value=ctx.get_file_path("final.mp4"), visible=True),
gr.update(value=ctx.title, visible=True),
gr.update(value=ctx.description, visible=True),
gr.update(value=ctx.dir),
]
def repack_options(self, *args) -> dict[str, list[BaseEngine]]: def repack_options(self, *args) -> dict[str, list[BaseEngine]]:
""" """
@@ -256,10 +294,10 @@ class GenerateUI:
options[engine_type].append( options[engine_type].append(
engine(options=args[: engine.num_options]) engine(options=args[: engine.num_options])
) )
args = args[engine.num_options:] args = args[engine.num_options :]
else: else:
# we don't care about this, it's not the selected engine, we throw it away # we don't care about this, it's not the selected engine, we throw it away
args = args[engine.num_options:] args = args[engine.num_options :]
return options return options