import os

import gradio as gr

from .BaseScriptEngine import BaseScriptEngine
from ...utils.prompting import get_prompt


class ShowerThoughtsScriptEngine(BaseScriptEngine):
    """Script engine that produces a "Shower Thoughts" style script via an LLM.

    Loads the ``shower_thoughts`` prompt pair from the package-local
    ``prompts`` directory, fills in the requested sentence count, and asks
    the powerful LLM engine on the shared context to write the script.
    """

    # Display name and description surfaced in the UI.
    name = "Shower Thoughts"
    description = "Generate a Shower Thoughts script"
    # Number of entries expected in the ``options`` list passed to __init__.
    num_options = 1

    def __init__(self, options: list[list | tuple | str | int | float | bool | None]):
        """Initialize the engine from UI option values.

        Args:
            options: Values collected from the components returned by
                :meth:`get_options`; ``options[0]`` is the sentence count.
        """
        # gr.Number (no precision set) returns a float such as 5.0; cast to
        # int so prompt formatting renders "5" rather than "5.0" and
        # max_tokens below is an integer.
        self.n_sentences = int(options[0])
        super().__init__()

    def generate(self) -> str:
        """Generate the script, store it on ``self.ctx.script``, and return it.

        Returns:
            The generated script text, stripped of surrounding whitespace.
        """
        sys_prompt, chat_prompt = get_prompt(
            "shower_thoughts",
            # Prompt templates live next to this module in ./prompts.
            location=os.path.join(
                os.path.dirname(os.path.abspath(__file__)), "prompts"
            ),
        )
        # Inject the requested sentence count into both prompt templates.
        sys_prompt = sys_prompt.format(n_sentences=self.n_sentences)
        chat_prompt = chat_prompt.format(n_sentences=self.n_sentences)
        # Budget ~20 tokens per sentence; temperature 1.3 favors the loose,
        # creative tone expected of shower thoughts.
        # NOTE(review): assumes self.ctx (with .powerfulllmengine) is set up
        # by BaseScriptEngine.__init__ — confirm against the base class.
        self.ctx.script = self.ctx.powerfulllmengine.generate(
            system_prompt=sys_prompt,
            chat_prompt=chat_prompt,
            max_tokens=20 * self.n_sentences,
            temperature=1.3,
            json_mode=False,
        ).strip()
        # Fix: the signature promises ``-> str`` but the original body
        # returned None implicitly; return the stored script as declared.
        return self.ctx.script

    @classmethod
    def get_options(cls) -> list:
        """Return the Gradio input components for this engine's options."""
        return [gr.Number(label="Number of sentences", value=5, minimum=1)]
|