mirror of
https://github.com/Paillat-dev/viralfactory.git
synced 2026-01-02 09:16:19 +00:00
Add presets feature
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import orjson
|
import orjson
|
||||||
|
import sys
|
||||||
|
|
||||||
from src.engines import ENGINES, BaseEngine
|
from src.engines import ENGINES, BaseEngine
|
||||||
from src.chore import GenerationContext
|
from src.chore import GenerationContext
|
||||||
@@ -15,8 +16,8 @@ class GenerateUI:
|
|||||||
with open("local/presets.json", "r") as f:
|
with open("local/presets.json", "r") as f:
|
||||||
return orjson.loads(f.read())
|
return orjson.loads(f.read())
|
||||||
|
|
||||||
def get_switcher_func(self, engine_names: list[str]) -> list[gr.update]:
|
def get_switcher_func(self, engine_names: list[str]) -> callable:
|
||||||
def switch(selected: str | list[str]):
|
def switch(selected: str | list[str]) -> list[gr.update]:
|
||||||
if isinstance(selected, str):
|
if isinstance(selected, str):
|
||||||
selected = [selected]
|
selected = [selected]
|
||||||
returnable = []
|
returnable = []
|
||||||
@@ -27,11 +28,15 @@ class GenerateUI:
|
|||||||
|
|
||||||
return switch
|
return switch
|
||||||
|
|
||||||
def launch_ui(self):
|
def get_ui(self):
|
||||||
ui = gr.TabbedInterface(
|
ui = gr.TabbedInterface(
|
||||||
*self.get_interfaces(), "Viral Factory", gr.themes.Soft(), css=self.css
|
*self.get_interfaces(), "Viral Factory", gr.themes.Soft(), css=self.css
|
||||||
)
|
)
|
||||||
ui.launch()
|
return ui
|
||||||
|
|
||||||
|
def launch_ui(self):
|
||||||
|
self.ui = self.get_ui()
|
||||||
|
self.ui.launch()
|
||||||
|
|
||||||
def get_interfaces(self) -> tuple[list[gr.Blocks], list[str]]:
|
def get_interfaces(self) -> tuple[list[gr.Blocks], list[str]]:
|
||||||
"""
|
"""
|
||||||
@@ -47,6 +52,12 @@ class GenerateUI:
|
|||||||
|
|
||||||
def get_settings_interface(self) -> gr.Blocks:
|
def get_settings_interface(self) -> gr.Blocks:
|
||||||
with gr.Blocks() as interface:
|
with gr.Blocks() as interface:
|
||||||
|
reload_ui = gr.Button("Reload UI", variant="primary")
|
||||||
|
def reload():
|
||||||
|
self.ui.close()
|
||||||
|
sys.exit("Reload")
|
||||||
|
|
||||||
|
reload_ui.click(reload)
|
||||||
for engine_type, engines in ENGINES.items():
|
for engine_type, engines in ENGINES.items():
|
||||||
engines = engines["classes"]
|
engines = engines["classes"]
|
||||||
with gr.Tab(engine_type) as engine_tab:
|
with gr.Tab(engine_type) as engine_tab:
|
||||||
@@ -141,8 +152,8 @@ class GenerateUI:
|
|||||||
with open("local/presets.json", "wb") as f:
|
with open("local/presets.json", "wb") as f:
|
||||||
presets[preset_name] = new_preset
|
presets[preset_name] = new_preset
|
||||||
f.write(orjson.dumps(presets))
|
f.write(orjson.dumps(presets))
|
||||||
return returnable
|
return [gr.update(value=presets.keys()), *returnable]
|
||||||
preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], outputs=inputs)
|
preset_button.click(load_preset, inputs=[preset_dropdown, *inputs], outputs=[preset_dropdown,*inputs])
|
||||||
output_gallery = gr.Markdown("aaa", render=False)
|
output_gallery = gr.Markdown("aaa", render=False)
|
||||||
button.click(
|
button.click(
|
||||||
self.run_generate_interface,
|
self.run_generate_interface,
|
||||||
@@ -191,8 +202,3 @@ class GenerateUI:
|
|||||||
# we don't care about this, it's not the selected engine, we throw it away
|
# we don't care about this, it's not the selected engine, we throw it away
|
||||||
args = args[engine.num_options :]
|
args = args[engine.num_options :]
|
||||||
return options
|
return options
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
ui_generator = GenerateUI()
|
|
||||||
ui_generator.launch_ui()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user