From 00fdc740e0bbeee7b83d72130263137440de16f8 Mon Sep 17 00:00:00 2001
From: Paillat
Date: Sun, 21 Apr 2024 21:43:17 +0200
Subject: [PATCH] :bug: fix(AnthropicLLMEngine): Use fix_busted_json to fix
 badly formatted json by claude

---
 src/engines/LLMEngine/AnthropicLLMEngine.py | 94 +++++++++++++--------
 1 file changed, 60 insertions(+), 34 deletions(-)

diff --git a/src/engines/LLMEngine/AnthropicLLMEngine.py b/src/engines/LLMEngine/AnthropicLLMEngine.py
index 548c2bc..981bb65 100644
--- a/src/engines/LLMEngine/AnthropicLLMEngine.py
+++ b/src/engines/LLMEngine/AnthropicLLMEngine.py
@@ -1,5 +1,6 @@
 import anthropic
 import gradio as gr
+import fix_busted_json
 import orjson
 
 from .BaseLLMEngine import BaseLLMEngine
@@ -18,45 +19,52 @@ class AnthropicLLMEngine(BaseLLMEngine):
 
     def __init__(self, options: list) -> None:
         self.model = options[0]
-        self.client = anthropic.Anthropic(
-            api_key="YourAnthropicAPIKeyHere"
-        )  # Ensure API key is securely managed
+        api_key = self.retrieve_setting(identifier="anthropic_api_key")["api_key"]
+        self.client = anthropic.Anthropic(api_key=api_key)
         super().__init__()
 
     def generate(
-        self,
-        system_prompt: str,
-        chat_prompt: str,
-        max_tokens: int = 1024,
-        temperature: float = 1.0,
-        json_mode: bool = False,
-        top_p: float = 1,
-        frequency_penalty: float = 0,
-        presence_penalty: float = 0,
+            self,
+            system_prompt: str,
+            chat_prompt: str,
+            max_tokens: int = 1024,
+            temperature: float = 1.0,
+            json_mode: bool = False,
+            top_p: float = 1,
+            frequency_penalty: float = 0,
+            presence_penalty: float = 0,
     ) -> str | dict:
-        prompt = f"""{anthropic.HUMAN_PROMPT} {system_prompt} {anthropic.HUMAN_PROMPT} {chat_prompt} {anthropic.AI_PROMPT}"""
-        if json_mode:
-            # anthropic does not officially support JSON mode, but we can bias the output towards a JSON-like format
-            prompt += " {"
-        # noinspection PyArgumentList
-        response: anthropic.types.Completion = self.client.completions.create(
-            max_tokens_to_sample=max_tokens,
-            prompt=prompt,
-            model=self.model,
-            top_p=top_p,
-            temperature=temperature,
-            frequency_penalty=frequency_penalty,
-        )
+        tries = 0
+        while tries < 2:
+            messages = [
+                {"role": "user", "content": chat_prompt},
+            ]
+            if json_mode:
+                # anthropic does not officially support JSON mode, but we can bias the output towards a JSON-like format
+                messages.append({"role": "assistant", "content": "{"})
+            response: anthropic.types.Message = self.client.messages.create(
+                max_tokens=max_tokens,
+                messages=messages,
+                model=self.model,
+                top_p=top_p,
+                temperature=temperature,
+                system=system_prompt,
+            )
 
-        content = response.completion
-        if json_mode:
-            # we add back the opening curly brace wich is not included in the response since it is in the prompt
-            content = "{" + content
-            # we remove everything after the last closing curly brace
-            content = content[: content.rfind("}") + 1]
-            return orjson.loads(content)
-        else:
-            return content
+            content = response.content[0].text
+            if json_mode:
+                content = "{" + content
+                # we remove everything after the last closing curly brace
+                content = content[: content.rfind("}") + 1]
+                content = content.replace("\n", "")
+                try:
+                    returnable = fix_busted_json.repair_json(content)
+                    returnable = orjson.loads(returnable)
+                    return returnable
+                except Exception as e:  # noqa - wait for the library to implement PEP 352 (https://peps.python.org/pep-0352/, Required Superclass for Exceptions)
+                    tries += 1
+            else:
+                return content
 
     @classmethod
     def get_options(cls) -> list:
@@ -68,3 +76,21 @@ class AnthropicLLMEngine(BaseLLMEngine):
                 value=ANTHROPIC_POSSIBLE_MODELS[0],
             )
         ]
+
+    @classmethod
+    def get_settings(cls):
+        current_api_key = cls.retrieve_setting(identifier="anthropic_api_key")
+        current_api_key = current_api_key["api_key"] if current_api_key else ""
+        api_key_input = gr.Textbox(
+            label="Anthropic API Key",
+            type="password",
+            value=current_api_key,
+        )
+        save = gr.Button("Save")
+
+        def save_api_key(api_key: str):
+            cls.store_setting(identifier="anthropic_api_key", data={"api_key": api_key})
+            gr.Info("API key saved successfully.")
+            return gr.update(value=api_key)
+
+        save.click(save_api_key, inputs=[api_key_input])