fix(makeprompt.py): add api_key parameter to the generate_response function call to fix missing api_key error

fix(openaicaller.py): fix import statement for num_tokens_from_messages function
refactor(openaicaller.py): remove unused api_key parameter from openai_caller class constructor
refactor(openaicaller.py): change generate_response and moderation methods to accept variable number of arguments
refactor(openaicaller.py): make the error_call parameter optional, defaulting to a no-op async function when not provided
refactor(openaicaller.py): remove unused api_key parameter from generate_response and moderation methods
refactor(openaicaller.py): remove unused api_key parameter from main function
This commit is contained in:
Paillat
2023-08-02 20:12:06 +02:00
parent 45dfafadd8
commit a7a0f5dac8
2 changed files with 29 additions and 12 deletions

View File

@@ -111,6 +111,7 @@ async def chatgpt_process(
) )
response = await caller.generate_response( response = await caller.generate_response(
error_call, error_call,
api_key=api_key,
model=model, model=model,
messages=msgs, messages=msgs,
functions=called_functions, functions=called_functions,

View File

@@ -32,7 +32,7 @@ from openai.error import (
AuthenticationError, AuthenticationError,
ServiceUnavailableError, ServiceUnavailableError,
) )
from src.utils.tokens import num_tokens_from_messages from utils.tokens import num_tokens_from_messages
class bcolors: class bcolors:
@@ -77,12 +77,19 @@ models_max_tokens = {
class openai_caller: class openai_caller:
def __init__(self, api_key=None) -> None: def __init__(self) -> None:
self.api_key = api_key pass
async def generate_response(self, error_call=None, **kwargs): # async def generate_response(self, error_call=None, **kwargs):
if error_call is None: async def generate_response(*args, **kwargs):
error_call = lambda x: 2 # do nothing self = args[0]
if len(args) > 1:
error_call = args[1]
else:
async def nothing(x):
return x
error_call = nothing
if kwargs.get("model", "") in chat_models: if kwargs.get("model", "") in chat_models:
return await self.chat_generate(error_call, **kwargs) return await self.chat_generate(error_call, **kwargs)
elif kwargs.get("engine", "") in text_models: elif kwargs.get("engine", "") in text_models:
@@ -99,16 +106,24 @@ class openai_caller:
f"{bcolors.BOLD}{bcolors.WARNING}Warning: Too many tokens. Removing first message.{bcolors.ENDC}" f"{bcolors.BOLD}{bcolors.WARNING}Warning: Too many tokens. Removing first message.{bcolors.ENDC}"
) )
tokens = await num_tokens_from_messages(kwargs["messages"], kwargs["model"]) tokens = await num_tokens_from_messages(kwargs["messages"], kwargs["model"])
kwargs["api_key"] = self.api_key if kwargs.get("api_key", None) == None:
raise ValueError("API key not set")
callable = lambda: openai_module.ChatCompletion.acreate(**kwargs) callable = lambda: openai_module.ChatCompletion.acreate(**kwargs)
response = await self.retryal_call(recall_func, callable) response = await self.retryal_call(recall_func, callable)
return response return response
async def moderation(self, recall_func=None, **kwargs): async def moderation(*args, **kwargs):
if recall_func is None: self = args[0]
recall_func = lambda x: 2 if len(args) > 1:
error_call = args[1]
else:
async def nothing(x):
return x
error_call = nothing
callable = lambda: openai_module.Moderation.acreate(**kwargs) callable = lambda: openai_module.Moderation.acreate(**kwargs)
response = await self.retryal_call(recall_func, callable) response = await self.retryal_call(error_call, callable)
return response return response
async def retryal_call(self, recall_func, callable): async def retryal_call(self, recall_func, callable):
@@ -154,7 +169,7 @@ class openai_caller:
print( print(
f"\n\n{bcolors.BOLD}{bcolors.FAIL}InvalidRequestError. Please check your request.{bcolors.ENDC}" f"\n\n{bcolors.BOLD}{bcolors.FAIL}InvalidRequestError. Please check your request.{bcolors.ENDC}"
) )
await recall_func() await recall_func("`InvalidRequestError. Please check your request.`")
raise e raise e
except AuthenticationError as e: except AuthenticationError as e:
print( print(
@@ -187,6 +202,7 @@ if __name__ == "__main__":
async def main(): async def main():
openai = openai_caller(api_key="sk-") openai = openai_caller(api_key="sk-")
response = await openai.generate_response( response = await openai.generate_response(
api_key="sk-",
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "ping"}], messages=[{"role": "user", "content": "ping"}],
max_tokens=5, max_tokens=5,