Files
viralfactory/src/engines/LLMEngine/OpenaiLLMEngine.py

66 lines
1.8 KiB
Python
Raw Normal View History

2024-02-15 17:53:46 +01:00
import openai
2024-02-14 17:49:51 +01:00
import gradio as gr
2024-02-15 17:53:46 +01:00
import orjson
from abc import ABC, abstractmethod
2024-02-14 17:49:51 +01:00
from .BaseLLMEngine import BaseLLMEngine
2024-02-15 17:53:46 +01:00
# OpenAI models known to support forced-JSON output, i.e. the
# `response_format={"type": "json_object"}` parameter of the chat API.
OPENAI_POSSIBLE_MODELS = [
    "gpt-3.5-turbo-0125",
    "gpt-4-turbo-preview",
]
2024-02-15 17:53:46 +01:00
class OpenaiLLMEngine(BaseLLMEngine):
    """LLM engine backed by the OpenAI chat-completions API."""

    # Number of Gradio option widgets produced by get_options().
    num_options = 1
    name = "OpenAI"
    description = "OpenAI language model engine."

    def __init__(self, options: list) -> None:
        # options[0] is the model id picked in the get_options() dropdown.
        self.model = options[0]
        super().__init__()

    def generate(
        self,
        system_prompt: str,
        chat_prompt: str,
        max_tokens: int = 512,
        temperature: float = 1.0,
        json_mode: bool = False,
        top_p: float = 1,
        frequency_penalty: float = 0,
        presence_penalty: float = 0,
    ) -> str | dict:
        """Run a single chat completion against the configured model.

        Args:
            system_prompt: Content of the system message.
            chat_prompt: Content of the user message.
            max_tokens: Upper bound on completion tokens.
            temperature: Sampling temperature.
            json_mode: When True, force the model to emit a JSON object
                (the model must support `response_format`; see
                OPENAI_POSSIBLE_MODELS) and return it parsed as a dict.
            top_p: Nucleus-sampling probability mass.
            frequency_penalty: OpenAI frequency penalty.
            presence_penalty: OpenAI presence penalty.

        Returns:
            The raw completion text, or the parsed JSON object when
            `json_mode` is True.
        """
        response = openai.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": chat_prompt},
            ],
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            # Use the public NOT_GIVEN sentinel (exported at the package
            # top level in openai>=1.0) rather than the private
            # openai._types module, which may move between SDK releases.
            response_format={"type": "json_object"}
            if json_mode
            else openai.NOT_GIVEN,
        )
        content = response.choices[0].message.content
        return orjson.loads(content) if json_mode else content

    @classmethod
    def get_options(cls) -> list:
        """Return the Gradio components configuring this engine: a
        single model-selection dropdown."""
        return [
            gr.Dropdown(
                label="Model",
                choices=OPENAI_POSSIBLE_MODELS,
                max_choices=1,
                value=OPENAI_POSSIBLE_MODELS[0],
            )
        ]