mirror of
https://github.com/Paillat-dev/viralfactory.git
synced 2026-01-02 01:06:19 +00:00
Formatting and stuff
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
import gradio as gr
|
||||
import orjson
|
||||
import sys
|
||||
|
||||
from src.engines import ENGINES, BaseEngine
|
||||
from src.chore import GenerationContext
|
||||
@@ -31,7 +30,7 @@ class GenerateUI:
|
||||
|
||||
def get_ui(self):
|
||||
ui = gr.TabbedInterface(
|
||||
*self.get_interfaces(), "Viral Factory", gr.themes.Soft(), css=self.css
|
||||
*self.get_interfaces(), title="Viral Factory", theme=gr.themes.Soft(), css=self.css
|
||||
)
|
||||
return ui
|
||||
|
||||
@@ -116,12 +115,13 @@ class GenerateUI:
|
||||
value=None
|
||||
)
|
||||
preset_button = gr.Button("Load")
|
||||
|
||||
def load_preset(preset_name, *inputs) -> list[gr.update]:
|
||||
with open("local/presets.json", "r") as f:
|
||||
presets = orjson.loads(f.read())
|
||||
returnable = []
|
||||
if preset_name in presets.keys():
|
||||
# If the preset exists
|
||||
# If the preset exists
|
||||
preset = presets[preset_name]
|
||||
for engine_type, engines in ENGINES.items():
|
||||
engines = engines["classes"]
|
||||
@@ -129,7 +129,8 @@ class GenerateUI:
|
||||
for engine in engines:
|
||||
if engine.name in preset.get(engine_type, {}).keys():
|
||||
values[0].append(engine.name)
|
||||
values.extend(gr.update(value=value) for value in preset[engine_type][engine.name])
|
||||
values.extend(
|
||||
gr.update(value=value) for value in preset[engine_type][engine.name])
|
||||
else:
|
||||
values.extend(gr.update() for _ in range(engine.num_options))
|
||||
returnable.extend(values)
|
||||
@@ -154,7 +155,8 @@ class GenerateUI:
|
||||
presets[preset_name] = new_preset
|
||||
f.write(orjson.dumps(presets))
|
||||
return [gr.update(value=presets.keys()), *returnable]
|
||||
preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], outputs=[preset_dropdown,*inputs])
|
||||
preset_button.click(load_preset, inputs=[preset_dropdown, *inputs],
|
||||
outputs=[preset_dropdown, *inputs])
|
||||
output_gallery = gr.Markdown("aaa", render=False)
|
||||
button.click(
|
||||
self.run_generate_interface,
|
||||
@@ -172,7 +174,7 @@ class GenerateUI:
|
||||
ctx.process() # Here we go ! 🚀
|
||||
return gr.update(value=ctx.get_file_path("final.mp4"))
|
||||
|
||||
def repack_options(self, *args) -> dict[BaseEngine]:
|
||||
def repack_options(self, *args) -> dict[str, list[BaseEngine]]:
|
||||
"""
|
||||
Repacks the options provided as arguments into a dictionary based on the selected engine.
|
||||
|
||||
@@ -198,10 +200,10 @@ class GenerateUI:
|
||||
options[engine_type].append(
|
||||
engine(options=args[: engine.num_options])
|
||||
)
|
||||
args = args[engine.num_options :]
|
||||
args = args[engine.num_options:]
|
||||
else:
|
||||
# we don't care about this, it's not the selected engine, we throw it away
|
||||
args = args[engine.num_options :]
|
||||
args = args[engine.num_options:]
|
||||
return options
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user