Files
viralfactory/src/engines/LLMEngine/AnthropicLLMEngine.py

69 lines
2.2 KiB
Python
Raw Normal View History

2024-02-15 14:11:16 +01:00
import anthropic
import gradio as gr
import orjson
from .BaseLLMEngine import BaseLLMEngine
# Model identifiers offered in the engine's "Model" dropdown; the first entry
# is used as the dropdown's default value.
ANTHROPIC_POSSIBLE_MODELS = ["claude-2.1"]
2024-02-15 17:53:18 +01:00
class AnthropicLLMEngine(BaseLLMEngine):
    """LLM engine backed by Anthropic's legacy Text Completions API.

    No API key is stored in code: the ``anthropic.Anthropic`` client is
    created without an explicit key, so the SDK reads it from the
    ``ANTHROPIC_API_KEY`` environment variable.
    """

    num_options = 1  # number of UI components returned by get_options()

    name = "Anthropic"
    description = "Anthropic language model engine."

    def __init__(self, options: list) -> None:
        """Create the engine.

        Args:
            options: Values of the components from ``get_options``;
                ``options[0]`` is the selected model name.
        """
        self.model = options[0]

        # Never hard-code secrets. Without an ``api_key`` argument the SDK
        # falls back to the ANTHROPIC_API_KEY environment variable.
        self.client = anthropic.Anthropic()

        super().__init__()

    def generate(
        self,
        system_prompt: str,
        chat_prompt: str,
        max_tokens: int = 1024,
        temperature: float = 1.0,
        json_mode: bool = False,
        top_p: float = 1,
        frequency_penalty: float = 0,
        presence_penalty: float = 0,
    ) -> str | dict:
        """Generate a completion for ``chat_prompt`` under ``system_prompt``.

        Args:
            system_prompt: Instructions placed before the first Human turn.
            chat_prompt: The user message.
            max_tokens: Maximum number of tokens to sample.
            temperature: Sampling temperature.
            json_mode: If True, prefill the reply with ``{`` and parse the
                result as a JSON object.
            top_p: Nucleus-sampling probability mass.
            frequency_penalty: Accepted only for interface compatibility with
                other engines. The Anthropic Completions API has no such
                parameter, so it is ignored.
            presence_penalty: Same as ``frequency_penalty``; ignored.

        Returns:
            The completion text, or the parsed object when ``json_mode`` is set.

        Raises:
            orjson.JSONDecodeError: If ``json_mode`` is set and the reply
                (truncated after its last ``}``) is not valid JSON, including
                when the reply contains no ``}`` at all.
        """
        # Claude 2.1 text-completion layout: system prompt first, then one
        # Human turn, ending with the Assistant turn to be completed.
        prompt = (
            f"{system_prompt}"
            f"{anthropic.HUMAN_PROMPT} {chat_prompt}"
            f"{anthropic.AI_PROMPT}"
        )
        if json_mode:
            # Anthropic has no official JSON mode. Prefilling the Assistant
            # turn with "{" biases the reply towards a JSON object.
            prompt += " {"

        # Only parameters the Completions endpoint accepts are forwarded.
        # Passing frequency_penalty/presence_penalty would raise TypeError.
        response: anthropic.types.Completion = self.client.completions.create(
            max_tokens_to_sample=max_tokens,
            prompt=prompt,
            model=self.model,
            top_p=top_p,
            temperature=temperature,
        )
        content = response.completion
        if not json_mode:
            return content

        # The opening brace was part of the prompt, so it is missing from the
        # completion; restore it, then drop anything after the last "}".
        content = "{" + content
        content = content[: content.rfind("}") + 1]
        return orjson.loads(content)

    @classmethod
    def get_options(cls) -> list:
        """Return the Gradio components used to configure this engine."""
        return [
            gr.Dropdown(
                label="Model",
                choices=ANTHROPIC_POSSIBLE_MODELS,
                max_choices=1,
                value=ANTHROPIC_POSSIBLE_MODELS[0],
            )
        ]