Implement the actual code

This commit is contained in:
2024-02-15 17:53:59 +01:00
parent bf03e0a766
commit a32f339981

View File

@@ -1,18 +1,36 @@
from .BaseScriptEngine import BaseScriptEngine
import gradio as gr import gradio as gr
import os
from .BaseScriptEngine import BaseScriptEngine
from ...utils.prompting import get_prompt
class ShowerThoughtsScriptEngine(BaseScriptEngine): class ShowerThoughtsScriptEngine(BaseScriptEngine):
name = "Shower Thoughts" name = "Shower Thoughts"
description = "Generate a Shower Thoughts script" description = "Generate a Shower Thoughts script"
num_options = 0 num_options = 1
def __init__(self, options: list[list | tuple | str | int | float | bool | None]): def __init__(self, options: list[list | tuple | str | int | float | bool | None]):
self.n_sentences = options[0]
super().__init__() super().__init__()
def generate(self, text: str, path: str) -> str: def generate(self) -> str:
pass sys_prompt, chat_prompt = get_prompt(
"shower_thoughts",
location=os.path.join(
os.path.dirname(os.path.abspath(__file__)), "prompts"
),
)
sys_prompt = sys_prompt.format(n_sentences=self.n_sentences)
chat_prompt = chat_prompt.format(n_sentences=self.n_sentences)
return self.ctx.powerfulllmengine.generate(
system_prompt=sys_prompt,
chat_prompt=chat_prompt,
max_tokens=20 * self.n_sentences,
temperature=1.3,
json_mode=False,
)
@classmethod @classmethod
def get_options(cls) -> list: def get_options(cls) -> list:
return [] return [gr.Number(label="Number of sentences", value=5, minimum=1)]