2023-07-15 12:20:38 +02:00
|
|
|
"""
|
|
|
|
|
This file provides a Python module that wraps the OpenAI API for making API calls.
|
|
|
|
|
|
|
|
|
|
The module includes:
|
|
|
|
|
|
|
|
|
|
- Functions for generating responses using chat-based models and handling API errors.
|
|
|
|
|
- Constants for chat and text models and their maximum token limits.
|
|
|
|
|
- Imports for required modules, including OpenAI and asyncio.
|
|
|
|
|
- A color formatting class, `bcolors`, for console output.
|
|
|
|
|
|
|
|
|
|
The main component is the `openai_caller` class with methods:
|
|
|
|
|
- `__init__(self)`: Initializes an instance of the class. The API key is not stored on the instance; it is supplied per call via the `api_key` keyword argument.

- `generate_response(self, ...)`: Asynchronously generates a response based on the provided arguments, dispatching to the chat backend.

- `chat_generate(self, recall_func, **kwargs)`: Asynchronously generates a chat-based response, handling token limits and API errors.

- `moderation(self, ...)`: Asynchronously runs a moderation request with the same retry handling.
|
|
|
|
|
|
|
|
|
|
The module assumes the presence of the `num_tokens_from_messages` function in a separate module, `src.utils.tokens`, used for token calculation.
|
|
|
|
|
|
|
|
|
|
Refer to function and method documentation for further details.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import openai as openai_module
|
|
|
|
|
import asyncio
|
|
|
|
|
|
2023-07-18 17:51:13 +02:00
|
|
|
from openai.error import (
|
|
|
|
|
APIError,
|
|
|
|
|
Timeout,
|
|
|
|
|
RateLimitError,
|
|
|
|
|
APIConnectionError,
|
|
|
|
|
InvalidRequestError,
|
|
|
|
|
AuthenticationError,
|
|
|
|
|
ServiceUnavailableError,
|
|
|
|
|
)
|
2023-08-02 20:16:55 +02:00
|
|
|
from src.utils.tokens import num_tokens_from_messages
|
2023-07-15 12:20:38 +02:00
|
|
|
|
2023-07-18 17:51:13 +02:00
|
|
|
|
2023-07-15 12:20:38 +02:00
|
|
|
class bcolors:
|
2023-07-18 17:51:13 +02:00
|
|
|
HEADER = "\033[95m"
|
|
|
|
|
OKBLUE = "\033[94m"
|
|
|
|
|
OKCYAN = "\033[96m"
|
|
|
|
|
OKGREEN = "\033[92m"
|
|
|
|
|
WARNING = "\033[93m"
|
|
|
|
|
FAIL = "\033[91m"
|
|
|
|
|
ENDC = "\033[0m"
|
|
|
|
|
BOLD = "\033[1m"
|
|
|
|
|
UNDERLINE = "\033[4m"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Models served through the chat-completions endpoint.
chat_models = [
    "gpt-4",
    "gpt-4-32k",
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-16k",
    "gpt-3.5-turbo-0613",
]

# Legacy models served through the text-completions endpoint.
text_models = [
    "text-davinci-003",
    "text-davinci-002",
    "text-curie-001",
    "text-babbage-001",
    "text-ada-001",
]

# Maximum context-window size, in tokens, for each supported model.
models_max_tokens = {
    "gpt-4": 8_192,
    "gpt-4-32k": 32_768,
    "gpt-3.5-turbo": 4_096,
    "gpt-3.5-turbo-0613": 4_096,
    "gpt-3.5-turbo-16k": 16_384,
    "text-davinci-003": 4_097,
    "text-davinci-002": 4_097,
    "text-curie-001": 2_049,
    "text-babbage-001": 2_049,
    "text-ada-001": 2_049,
}
|
|
|
|
|
|
2023-07-18 17:51:13 +02:00
|
|
|
|
2023-07-15 12:20:38 +02:00
|
|
|
class openai_caller:
    """Thin async wrapper around the OpenAI API with retry handling.

    The instance is stateless: the API key and all request parameters are
    supplied per call through keyword arguments.
    """

    def __init__(self) -> None:
        # Nothing to initialize — every call carries its own configuration.
        pass

    @staticmethod
    async def _noop_error_call(message):
        """Default user-notification callback: accepts and echoes a message."""
        return message

    async def generate_response(self, error_call=None, **kwargs):
        """Dispatch a generation request to the appropriate backend.

        Args:
            error_call: Optional async callable invoked with a human-readable
                message whenever a retriable API error occurs.
            **kwargs: Forwarded to the OpenAI API. ``model`` selects a chat
                model; ``engine`` a (legacy, unsupported) text model.

        Returns:
            The OpenAI API response.

        Raises:
            NotImplementedError: If a legacy text model is requested.
            ValueError: If neither ``model`` nor ``engine`` is recognized.
        """
        # BUGFIX: the original called .get(...) on the positional-args tuple
        # (`args.get("error_call", None)`), which raised AttributeError on
        # every invocation. An explicit optional parameter with a no-op
        # default replaces it, and remains call-compatible.
        if error_call is None:
            error_call = self._noop_error_call
        if kwargs.get("model", "") in chat_models:
            return await self.chat_generate(error_call, **kwargs)
        elif kwargs.get("engine", "") in text_models:
            raise NotImplementedError("Text models are not supported yet")
        else:
            raise ValueError("Model not found")

    async def chat_generate(self, recall_func, **kwargs):
        """Run a chat-completion request, trimming history to fit the model.

        Oldest messages are dropped one at a time until the conversation fits
        the model's context window, then the request is issued with retries.

        Args:
            recall_func: Async callable notified with a message on retriable
                API errors.
            **kwargs: Chat-completion parameters; must include ``messages``,
                ``model`` and ``api_key``.

        Returns:
            The chat-completion response.

        Raises:
            ValueError: If ``api_key`` is missing, or the conversation cannot
                be trimmed below the model's token limit.
        """
        tokens = await num_tokens_from_messages(kwargs["messages"], kwargs["model"])
        model_max_tokens = models_max_tokens[kwargs["model"]]
        while tokens > model_max_tokens:
            if not kwargs["messages"]:
                # Guard against the original's infinite loop: nothing left
                # to drop, yet the token count still exceeds the limit.
                raise ValueError("Messages exceed the model's token limit")
            kwargs["messages"] = kwargs["messages"][1:]
            print(
                f"{bcolors.BOLD}{bcolors.WARNING}Warning: Too many tokens. Removing first message.{bcolors.ENDC}"
            )
            tokens = await num_tokens_from_messages(kwargs["messages"], kwargs["model"])
        if kwargs.get("api_key") is None:
            raise ValueError("API key not set")
        # Wrap the HTTP call in a zero-arg factory so retryal_call can
        # re-issue it on transient failures. (Renamed from `callable`,
        # which shadowed the builtin.)
        request = lambda: openai_module.ChatCompletion.acreate(**kwargs)
        return await self.retryal_call(recall_func, request)

    async def moderation(self, error_call=None, **kwargs):
        """Run a moderation request with the standard retry handling.

        Args:
            error_call: Optional async callable notified on retriable errors.
            **kwargs: Moderation-endpoint parameters.

        Returns:
            The moderation response.
        """
        if error_call is None:
            error_call = self._noop_error_call
        request = lambda: openai_module.Moderation.acreate(**kwargs)
        return await self.retryal_call(error_call, request)

    async def retryal_call(self, recall_func, callable):
        """Await ``callable()`` with up to 10 attempts on transient errors.

        Retriable errors (APIError, Timeout, RateLimitError,
        ServiceUnavailableError) notify the caller via ``recall_func`` and
        back off 10 seconds. Fatal errors (connection, invalid request,
        authentication) are logged and re-raised immediately.

        NOTE: the parameter name ``callable`` shadows the builtin; it is kept
        for backward compatibility with existing keyword callers.

        Raises:
            TimeoutError: After 10 consecutive retriable failures.
        """
        for _attempt in range(10):
            try:
                return await callable()
            except APIError:
                print(
                    f"\n\n{bcolors.BOLD}{bcolors.WARNING}APIError. This is not your fault. Retrying...{bcolors.ENDC}"
                )
                await recall_func(
                    "`An APIError occurred. This is not your fault, it is OpenAI's fault. We apologize for the inconvenience. Retrying...`"
                )
                await asyncio.sleep(10)
            except Timeout:
                print(
                    f"\n\n{bcolors.BOLD}{bcolors.WARNING}The request timed out. Retrying...{bcolors.ENDC}"
                )
                await recall_func("`The request timed out. Retrying...`")
                await asyncio.sleep(10)
            except RateLimitError:
                print(
                    f"\n\n{bcolors.BOLD}{bcolors.WARNING}RateLimitError. You are being rate limited. Retrying...{bcolors.ENDC}"
                )
                await recall_func("`You are being rate limited. Retrying...`")
                await asyncio.sleep(10)
            except APIConnectionError:
                print(
                    f"\n\n{bcolors.BOLD}{bcolors.FAIL}APIConnectionError. There is an issue with your internet connection. Please check your connection.{bcolors.ENDC}"
                )
                raise
            except InvalidRequestError:
                print(
                    f"\n\n{bcolors.BOLD}{bcolors.FAIL}InvalidRequestError. Please check your request.{bcolors.ENDC}"
                )
                await recall_func("`InvalidRequestError. Please check your request.`")
                raise
            except AuthenticationError:
                print(
                    f"\n\n{bcolors.BOLD}{bcolors.FAIL}AuthenticationError. Please check your API key and if needed, also your organization ID.{bcolors.ENDC}"
                )
                await recall_func("`AuthenticationError. Please check your API key.`")
                raise
            except ServiceUnavailableError:
                print(
                    f"\n\n{bcolors.BOLD}{bcolors.WARNING}ServiceUnavailableError. The OpenAI API is not responding. Retrying...{bcolors.ENDC}"
                )
                await recall_func("`The OpenAI API is not responding. Retrying...`")
                await asyncio.sleep(10)
                # BUGFIX: removed the original's stray zero-argument
                # `await recall_func()` here, which raised TypeError on every
                # ServiceUnavailableError retry (the callback takes one arg).
        print(
            f"\n\n{bcolors.BOLD}{bcolors.FAIL}OpenAI API is not responding. Please try again later.{bcolors.ENDC}"
        )
        raise TimeoutError("OpenAI API is not responding. Please try again later.")
|
2023-07-18 17:51:13 +02:00
|
|
|
|
|
|
|
|
|
2023-07-16 17:11:24 +02:00
|
|
|
## testing — manual smoke test; requires a real API key and network access.
if __name__ == "__main__":

    async def main():
        # BUGFIX: openai_caller.__init__ takes no arguments — the original
        # `openai_caller(api_key="sk-")` raised TypeError. The key is passed
        # per request via the `api_key` kwarg below instead.
        caller = openai_caller()
        response = await caller.generate_response(
            api_key="sk-",
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "ping"}],
            max_tokens=5,
            temperature=0.7,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            stop=["\n", " Human:", " AI:"],
        )
        print(response)

    asyncio.run(main())
|