fix(makeprompt.py): add api_key parameter to the generate_response function call to fix missing api_key error

fix(openaicaller.py): fix import statement for num_tokens_from_messages function
refactor(openaicaller.py): remove unused api_key parameter from openai_caller class constructor
refactor(openaicaller.py): change generate_response and moderation methods to accept variable number of arguments
refactor(openaicaller.py): change error_call parameter to be optional and provide a default function if not provided
refactor(openaicaller.py): remove unused api_key parameter from generate_response and moderation methods
refactor(openaicaller.py): remove unused api_key parameter from main function
This commit is contained in:
Paillat
2023-08-02 20:12:06 +02:00
parent 45dfafadd8
commit a7a0f5dac8
2 changed files with 29 additions and 12 deletions

View File

@@ -111,6 +111,7 @@ async def chatgpt_process(
)
response = await caller.generate_response(
error_call,
api_key=api_key,
model=model,
messages=msgs,
functions=called_functions,

View File

@@ -32,7 +32,7 @@ from openai.error import (
AuthenticationError,
ServiceUnavailableError,
)
from src.utils.tokens import num_tokens_from_messages
from utils.tokens import num_tokens_from_messages
class bcolors:
@@ -77,12 +77,19 @@ models_max_tokens = {
class openai_caller:
def __init__(self, api_key=None) -> None:
self.api_key = api_key
def __init__(self) -> None:
pass
async def generate_response(self, error_call=None, **kwargs):
if error_call is None:
error_call = lambda x: 2 # do nothing
# async def generate_response(self, error_call=None, **kwargs):
async def generate_response(*args, **kwargs):
self = args[0]
if len(args) > 1:
error_call = args[1]
else:
async def nothing(x):
return x
error_call = nothing
if kwargs.get("model", "") in chat_models:
return await self.chat_generate(error_call, **kwargs)
elif kwargs.get("engine", "") in text_models:
@@ -99,16 +106,24 @@ class openai_caller:
f"{bcolors.BOLD}{bcolors.WARNING}Warning: Too many tokens. Removing first message.{bcolors.ENDC}"
)
tokens = await num_tokens_from_messages(kwargs["messages"], kwargs["model"])
kwargs["api_key"] = self.api_key
if kwargs.get("api_key", None) == None:
raise ValueError("API key not set")
callable = lambda: openai_module.ChatCompletion.acreate(**kwargs)
response = await self.retryal_call(recall_func, callable)
return response
async def moderation(self, recall_func=None, **kwargs):
if recall_func is None:
recall_func = lambda x: 2
async def moderation(*args, **kwargs):
self = args[0]
if len(args) > 1:
error_call = args[1]
else:
async def nothing(x):
return x
error_call = nothing
callable = lambda: openai_module.Moderation.acreate(**kwargs)
response = await self.retryal_call(recall_func, callable)
response = await self.retryal_call(error_call, callable)
return response
async def retryal_call(self, recall_func, callable):
@@ -154,7 +169,7 @@ class openai_caller:
print(
f"\n\n{bcolors.BOLD}{bcolors.FAIL}InvalidRequestError. Please check your request.{bcolors.ENDC}"
)
await recall_func()
await recall_func("`InvalidRequestError. Please check your request.`")
raise e
except AuthenticationError as e:
print(
@@ -187,6 +202,7 @@ if __name__ == "__main__":
async def main():
openai = openai_caller(api_key="sk-")
response = await openai.generate_response(
api_key="sk-",
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "ping"}],
max_tokens=5,