Files
viralfactory/src/engines/ScriptEngine/ShowerThoughtsScriptEngine.py

37 lines
1.2 KiB
Python
Raw Normal View History

2024-02-14 17:49:51 +01:00
import gradio as gr
2024-02-15 17:53:59 +01:00
import os
from .BaseScriptEngine import BaseScriptEngine
from ...utils.prompting import get_prompt
2024-02-14 17:49:51 +01:00
class ShowerThoughtsScriptEngine(BaseScriptEngine):
    """Script engine that asks the LLM for a "Shower Thoughts" script.

    The engine exposes a single UI option, the number of sentences to
    generate. That count is interpolated into the ``shower_thoughts`` prompts
    and also sizes the token budget of the LLM call.
    """

    name = "Shower Thoughts"
    description = "Generate a Shower Thoughts script"
    # Number of components returned by ``get_options``; presumably the
    # framework uses it to slice the combined options list — verify.
    num_options = 1

    # Token budget granted to the LLM per requested sentence.
    _TOKENS_PER_SENTENCE = 20
    # Directory holding this engine's prompt files, resolved once at import.
    _PROMPTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "prompts")

    def __init__(self, options: list[list | tuple | str | int | float | bool | None]):
        """Create the engine from the values of the ``get_options`` components.

        Args:
            options: UI values in the order returned by ``get_options``;
                ``options[0]`` is the requested number of sentences.

        Raises:
            ValueError: if the sentence count is missing, not a finite number,
                or less than 1.
        """
        self.n_sentences = self._parse_n_sentences(options[0])
        super().__init__()

    @staticmethod
    def _parse_n_sentences(value) -> int:
        """Coerce a UI value to a positive integer sentence count."""
        # gr.Number yields a float (e.g. 5.0) unless precision=0, and None when
        # the field is cleared. Normalize so the prompts say "5" rather than
        # "5.0" and max_tokens is an int.
        try:
            count = int(round(float(value)))
        except (TypeError, ValueError, OverflowError) as err:
            raise ValueError(
                f"Number of sentences must be a finite number, got {value!r}"
            ) from err
        if count < 1:
            raise ValueError(f"Number of sentences must be at least 1, got {count}")
        return count

    def generate(self) -> str:
        """Generate a Shower Thoughts script of ``self.n_sentences`` sentences.

        Loads the ``shower_thoughts`` system/chat prompt pair from the
        ``prompts`` directory next to this file and fills in ``n_sentences``.
        It then delegates to ``self.ctx.powerfulllmengine.generate``.

        Returns:
            The generated script text.
        """
        sys_prompt, chat_prompt = get_prompt("shower_thoughts", location=self._PROMPTS_DIR)
        sys_prompt = sys_prompt.format(n_sentences=self.n_sentences)
        chat_prompt = chat_prompt.format(n_sentences=self.n_sentences)
        # self.ctx is not assigned in this class; presumably it is provided by
        # BaseScriptEngine or the framework — verify.
        return self.ctx.powerfulllmengine.generate(
            system_prompt=sys_prompt,
            chat_prompt=chat_prompt,
            max_tokens=self._TOKENS_PER_SENTENCE * self.n_sentences,
            # Deliberately high temperature, presumably for more varied output.
            temperature=1.3,
            json_mode=False,
        )

    @classmethod
    def get_options(cls) -> list:
        """Return the Gradio input components that configure this engine.

        ``precision=0`` makes Gradio round the entered value and deliver an
        ``int`` instead of a float.
        """
        return [gr.Number(label="Number of sentences", value=5, minimum=1, precision=0)]