mirror of
https://github.com/Paillat-dev/viralfactory.git
synced 2026-01-02 01:06:19 +00:00
Fix a commit error
This commit is contained in:
@@ -1,43 +1,65 @@
|
|||||||
import anthropic
|
import openai
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
import orjson
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from .BaseLLMEngine import BaseLLMEngine
|
from .BaseLLMEngine import BaseLLMEngine
|
||||||
|
|
||||||
# OpenAI models that support forcing JSON output via
# response_format={"type": "json_object"} in chat.completions.create.
OPENAI_POSSIBLE_MODELS = [
    "gpt-3.5-turbo-0125",
    "gpt-4-turbo-preview",
    # Add more models as needed
]
||||||
class OpenaiLLMEngine(BaseLLMEngine):
    """OpenAI chat-completions backend for the LLM engine interface.

    Exposes a single UI option (the model dropdown) and implements
    ``generate`` on top of ``openai.chat.completions.create``.
    """

    # Number of gr components returned by get_options(); consumed positionally
    # by __init__ via the ``options`` list.
    num_options = 1
    name = "OpenAI"
    description = "OpenAI language model engine."

    def __init__(self, options: list) -> None:
        """Initialize the engine from the UI option values.

        Args:
            options: Values of the components from ``get_options``, in order;
                ``options[0]`` is the model id chosen in the dropdown.
        """
        self.model = options[0]
        super().__init__()

    def generate(
        self,
        system_prompt: str,
        chat_prompt: str,
        max_tokens: int = 512,
        temperature: float = 1.0,
        json_mode: bool = False,
        top_p: float = 1,
        frequency_penalty: float = 0,
        presence_penalty: float = 0,
    ) -> str | dict:
        """Run one chat completion and return the assistant reply.

        Args:
            system_prompt: Content for the system message.
            chat_prompt: Content for the user message.
            max_tokens: Completion token cap.
            temperature: Sampling temperature.
            json_mode: When True, force JSON output (the model must be one
                that supports ``response_format`` — see OPENAI_POSSIBLE_MODELS)
                and parse the reply with orjson.
            top_p: Nucleus-sampling cutoff.
            frequency_penalty: Repetition penalty on token frequency.
            presence_penalty: Repetition penalty on token presence.

        Returns:
            The raw reply text, or the parsed JSON object when ``json_mode``.
        """
        response = openai.chat.completions.create(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": chat_prompt},
            ],
            model=self.model,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            # NOT_GIVEN (rather than None) omits the field entirely so models
            # without JSON-mode support are not sent an unsupported parameter.
            response_format={"type": "json_object"}
            if json_mode
            else openai._types.NOT_GIVEN,
        )
        content = response.choices[0].message.content
        return orjson.loads(content) if json_mode else content

    @classmethod
    def get_options(cls) -> list:
        """Return the Gradio components for configuring this engine.

        Returns:
            A list with one ``gr.Dropdown`` for picking the model; its value
            becomes ``options[0]`` in ``__init__``.
        """
        return [
            gr.Dropdown(
                label="Model",
                choices=OPENAI_POSSIBLE_MODELS,
                max_choices=1,
                value=OPENAI_POSSIBLE_MODELS[0],
            )
        ]
|
|||||||
Reference in New Issue
Block a user