viralfactory/src/engines/LLMEngine/OpenaiLLMEngine.py

import gradio as gr
import openai
import logging
from openai import OpenAI
import orjson
from .BaseLLMEngine import BaseLLMEngine

# These are the OpenAI models that support forcing JSON output (json_mode / response_format).
OPENAI_POSSIBLE_MODELS = [
    "gpt-3.5-turbo-0125",
    "gpt-4-turbo",
    "gpt-4o",
]


class OpenaiLLMEngine(BaseLLMEngine):
    num_options = 1
    name = "OpenAI"
    description = "OpenAI language model engine."

    def __init__(self, options: list) -> None:
        self.model = options[0]
        # The API key is stored via the settings UI (see get_settings below).
        api_key = self.retrieve_setting(identifier="openai_api_key")
        if not api_key:
            raise ValueError("OpenAI API key is not set.")
        self.client = OpenAI(api_key=api_key["api_key"])
        super().__init__()

    def generate(
        self,
        system_prompt: str,
        chat_prompt: str = "",
        messages: list = [],
        max_tokens: int = 512,
        temperature: float = 1.0,
        json_mode: bool = False,
        top_p: float = 1,
        frequency_penalty: float = 0,
        presence_penalty: float = 0,
    ) -> str | dict:
        """Run a chat completion; returns a dict when json_mode is True, else a str."""
        logging.info(
            f"Generating with OpenAI model {self.model} and system prompt: \n{system_prompt} and chat prompt: \n{chat_prompt[0:100]}..."
        )
        if chat_prompt:
            messages = [
                {"role": "user", "content": chat_prompt},
                *messages,
            ]
        # Convert generic "image" content blocks into OpenAI's "image_url" format.
        for i, message in enumerate(messages):
            if isinstance(message["content"], list):
                for j, content in enumerate(message["content"]):
                    if content["type"] == "image":
                        message["content"][j] = {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{content['source']['media_type']};base64,{content['source']['data']}",
                            },
                        }
            messages[i] = message
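        # Note (assumed shape): the incoming "image" blocks are expected to look like
        #     {"type": "image", "source": {"media_type": "image/png", "data": "<base64>"}}
        # and the loop above rewrites them into OpenAI's data-URI form:
        #     {"type": "image_url", "image_url": {"url": "data:image/png;base64,<base64>"}}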
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_prompt},
                *messages,
            ],
            max_tokens=int(max_tokens) if max_tokens else openai.NOT_GIVEN,
            temperature=temperature,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            response_format=(
                {"type": "json_object"} if json_mode else openai.NOT_GIVEN
            ),
        )
        return (
            response.choices[0].message.content
            if not json_mode
            else orjson.loads(response.choices[0].message.content)
        )

    @classmethod
    def get_options(cls) -> list:
        return [
            gr.Dropdown(
                label="Model",
                choices=OPENAI_POSSIBLE_MODELS,
                value=OPENAI_POSSIBLE_MODELS[0],
            )
        ]

    @classmethod
    def get_settings(cls):
        # Gradio settings panel: an API-key textbox plus a Save button that persists it.
        current_api_key = cls.retrieve_setting(identifier="openai_api_key")
        current_api_key = current_api_key["api_key"] if current_api_key else ""
        api_key_input = gr.Textbox(
            label="OpenAI API Key",
            type="password",
            value=current_api_key,
        )
        save = gr.Button("Save")

        def save_api_key(api_key: str):
            cls.store_setting(identifier="openai_api_key", data={"api_key": api_key})
            gr.Info("API key saved successfully.")
            return gr.update(value=api_key)

        save.click(save_api_key, inputs=[api_key_input])

    @property
    def supports_vision(self) -> bool:
        return self.model in ["gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4o"]
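

if __name__ == "__main__":
    # Illustrative sketch only, not part of the engine itself. It assumes an API key
    # has already been saved through the settings UI, so that
    # retrieve_setting(identifier="openai_api_key") returns one (BaseLLMEngine is
    # expected to provide retrieve_setting/store_setting). Because of the relative
    # import above, run it as a module rather than as a plain script.
    engine = OpenaiLLMEngine([OPENAI_POSSIBLE_MODELS[-1]])  # e.g. "gpt-4o"
    reply = engine.generate(
        system_prompt="You are a concise assistant. Answer with a JSON object "
        "containing 'title' and 'hook' fields.",
        chat_prompt="Pitch a 30-second video about deep-sea creatures.",
        json_mode=True,  # parsed with orjson, so `reply` is a dict here
    )
    print(reply)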