Fix a commit error

This commit is contained in:
2024-02-15 17:53:46 +01:00
parent e8cf4bad50
commit bf03e0a766

View File

@@ -1,43 +1,65 @@
import anthropic import openai
import gradio as gr import gradio as gr
import orjson
from abc import ABC, abstractmethod
from .BaseLLMEngine import BaseLLMEngine from .BaseLLMEngine import BaseLLMEngine
# Assuming these are the models supported by Anthropics that you wish to include OPENAI_POSSIBLE_MODELS = [ # Theese shall be the openai models supporting force_json
ANTHROPIC_POSSIBLE_MODELS = [ "gpt-3.5-turbo-0125",
"claude-2.1", "gpt-4-turbo-preview",
# Add more models as needed
] ]
class AnthropicsLLMEngine(BaseLLMEngine):
class OpenaiLLMEngine(BaseLLMEngine):
num_options = 1 num_options = 1
name = "Anthropics" name = "OpenAI"
description = "Anthropics language model engine." description = "OpenAI language model engine."
def __init__(self, options: list) -> None: def __init__(self, options: list) -> None:
self.model = options[0] self.model = options[0]
self.client = anthropic.Anthropic(api_key="YourAnthropicAPIKeyHere") # Ensure API key is securely managed
super().__init__() super().__init__()
def generate(self, system_prompt: str, chat_prompt: str, max_tokens: int = 1024, temperature: float = 1.0, json_mode: bool = False, top_p: float = 1, frequency_penalty: float = 0, presence_penalty: float = 0) -> str | dict: def generate(
# Note: Adjust the parameters as per Anthropics API capabilities self,
message = self.client.messages.create( system_prompt: str,
max_tokens=max_tokens, chat_prompt: str,
max_tokens: int = 512,
temperature: float = 1.0,
json_mode: bool = False,
top_p: float = 1,
frequency_penalty: float = 0,
presence_penalty: float = 0,
) -> str | dict:
response = openai.chat.completions.create(
model=self.model,
messages=[ messages=[
{"role": "system", "content": system_prompt}, {"role": "system", "content": system_prompt},
{"role": "user", "content": chat_prompt}, {"role": "user", "content": chat_prompt},
], ],
model=self.model, max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
response_format={"type": "json_object"}
if json_mode
else openai._types.NOT_GIVEN,
)
return (
response.choices[0].message.content
if not json_mode
else orjson.loads(response.choices[0].message.content)
) )
return message.content
@classmethod @classmethod
def get_options(cls) -> list: def get_options(cls) -> list:
return [ return [
gr.Dropdown( gr.Dropdown(
label="Model", label="Model",
choices=ANTHROPIC_POSSIBLE_MODELS, choices=OPENAI_POSSIBLE_MODELS,
max_choices=1, max_choices=1,
value=ANTHROPIC_POSSIBLE_MODELS[0] value=OPENAI_POSSIBLE_MODELS[0],
) )
] ]