mirror of
https://github.com/Paillat-dev/viralfactory.git
synced 2026-01-02 01:06:19 +00:00
Implement the actual code
This commit is contained in:
@@ -1,18 +1,36 @@
|
||||
from .BaseScriptEngine import BaseScriptEngine
|
||||
import gradio as gr
|
||||
import os
|
||||
|
||||
from .BaseScriptEngine import BaseScriptEngine
|
||||
from ...utils.prompting import get_prompt
|
||||
|
||||
|
||||
class ShowerThoughtsScriptEngine(BaseScriptEngine):
|
||||
name = "Shower Thoughts"
|
||||
description = "Generate a Shower Thoughts script"
|
||||
num_options = 0
|
||||
num_options = 1
|
||||
|
||||
def __init__(self, options: list[list | tuple | str | int | float | bool | None]):
|
||||
self.n_sentences = options[0]
|
||||
super().__init__()
|
||||
|
||||
def generate(self, text: str, path: str) -> str:
|
||||
pass
|
||||
def generate(self) -> str:
|
||||
sys_prompt, chat_prompt = get_prompt(
|
||||
"shower_thoughts",
|
||||
location=os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "prompts"
|
||||
),
|
||||
)
|
||||
sys_prompt = sys_prompt.format(n_sentences=self.n_sentences)
|
||||
chat_prompt = chat_prompt.format(n_sentences=self.n_sentences)
|
||||
return self.ctx.powerfulllmengine.generate(
|
||||
system_prompt=sys_prompt,
|
||||
chat_prompt=chat_prompt,
|
||||
max_tokens=20 * self.n_sentences,
|
||||
temperature=1.3,
|
||||
json_mode=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_options(cls) -> list:
|
||||
return []
|
||||
return [gr.Number(label="Number of sentences", value=5, minimum=1)]
|
||||
|
||||
Reference in New Issue
Block a user