[MAKEPROMPT] FIxed errors and cleaned a bit more

This commit is contained in:
Alexis LEBEL
2023-04-01 13:17:54 +02:00
parent dd14472447
commit 306eefbc75

View File

@@ -8,7 +8,6 @@ import openai
import emoji
import os
async def replace_mentions(content, bot):
mentions = re.findall(r"<@!?\d+>", content)
for mention in mentions:
@@ -69,7 +68,7 @@ def get_guild_data(message):
data = curs_data.fetchone()
model = model[1]
except:
model = "chatGPT"
model = "gpt-3.5-turbo"
try:
# [2] # get the premium status of the guild
@@ -90,21 +89,22 @@ def get_guild_data(message):
guild_data["premium"] = premium
guild_data["images"] = images
guild_data["images_limit_reached"] = False
return guild_data
async def need_ignore_message(self, data_dict, message, guild_data, original_message, channels):
async def need_ignore_message(bot, data_dict, message, guild_data, original_message, channels):
## ---- Message ignore conditions ---- ##
if message.author.bot:
return True
if data_dict["api_key"] is None:
return True # if the api key is not set, return
return True # if the api key is not set, return
if (
# if the message is not in a premium channel and
not str(message.channel.id) in channels
# if the message doesn't mention the bot and
and message.content.find("<@" + str(self.bot.user.id) + ">") == -1
and original_message == None # if the message is not a reply to the bot and
not (str(message.channel.id) in channels
# if the message doesn't mention the bot and
and (message.content.find("<@" + str(bot.user.id) + ">") != -1
or original_message)) # if the message is not a reply to the bot and
# if the message is not in the default channel
and str(message.channel.id) != str(data_dict["channel_id"])
):
@@ -137,30 +137,35 @@ async def need_ignore_message(self, data_dict, message, guild_data, original_mes
return False
def get_data_dict(self, message):
async def get_data_dict(message):
try:
curs_data.execute(
"SELECT * FROM data WHERE guild_id = ?", (message.guild.id,))
except:
return
data = curs_data.fetchone()
# Create a dict with the data
data_dict = {
"channel_id": data[1],
"api_key": data[2],
"is_active": data[3],
"max_tokens": data[4],
"temperature": data[5],
"frequency_penalty": data[6],
"presence_penalty": data[7],
"uses_count_today": data[8],
"prompt_size": data[9],
"prompt_prefix": data[10],
"tts": data[11],
"pretend_to_be": data[12],
"pretend_enabled": data[13],
}
return data_dict
data = curs_data.fetchone()
# Create a dict with the data
data_dict = {
"channel_id": data[1],
"api_key": data[2],
"is_active": data[3],
"max_tokens": data[4],
"temperature": data[5],
"frequency_penalty": data[6],
"presence_penalty": data[7],
"uses_count_today": data[8],
"prompt_size": data[9],
"prompt_prefix": data[10],
"tts": bool(data[11]),
"pretend_to_be": data[12],
"pretend_enabled": data[13],
}
return data_dict
except Exception as e:
# Send an error message
await message.channel.send(
"The bot is not configured yet. Please use `//setup` to configure it. \n" +
"If it still doesn't work, it might be a database error. \n ```" + e.__str__()
+ "```", delete_after=60
)
async def chat_process(self, message):
@@ -169,40 +174,11 @@ async def chat_process(self, message):
Args:
message (str): Data of the message that was sent
"""
if(await need_ignore_message(self, message, guild_data, original_message, channels)):
if message.author.bot:
return
data_dict = get_data_dict(message)
## ---- Message processing ---- ##
if data is None:
data = [message.guild.id, 0, 0]
data_dict["images_usage"] = data[1]
data_dict["images_enabled"] = data[2]
images_limit_reached = False
guild_data = get_guild_data(message)
channels = []
if message.guild.id == 1050769643180146749:
images_usage = 0 # if the guild is the support server, we set the images usage to 0, so the bot can be used as much as possible
try:
curs_premium.execute(
"SELECT * FROM channels WHERE guild_id = ?", (message.guild.id,))
data = curs_premium.fetchone()
if guild_data["premium"]:
# for 5 times, we get c.fetchone()[1] to c.fetchone()[5] and we add it to the channels list, each time with try except
for i in range(1, 6):
# we use the i variable to get the channel id
try:
channels.append(str(data[i]))
except:
pass
except:
channels = []
data_dict = await get_data_dict(message)
try:
original_message = await message.channel.fetch_message(
@@ -215,6 +191,45 @@ async def chat_process(self, message):
# if the message someone replied to is not from the bot, set original_message to None
original_message = None
try:
# get the images setting in the database
curs_data.execute(
"SELECT * FROM images WHERE guild_id = ?", (message.guild.id,))
images_data = curs_data.fetchone()
except:
images_data = None
## ---- Message processing ---- ##
print(message)
if not images_data:
images_data = [message.guild.id, 0, 0]
data_dict["images_usage"] = 0 if message.guild.id == 1050769643180146749 else images_data[1]
data_dict["images_enabled"] = images_data[2]
channels = []
try:
curs_premium.execute(
"SELECT * FROM channels WHERE guild_id = ?", (message.guild.id,))
images_data = curs_premium.fetchone()
if guild_data["premium"]:
# for 5 times, we get c.fetchone()[1] to c.fetchone()[5] and we add it to the channels list, each time with try except
for i in range(1, 6):
# we use the i variable to get the channel id
try:
channels.append(str(images_data[i]))
except:
pass
except:
debug("No premium channels found for this guild")
if (await need_ignore_message(self.bot, data_dict, message, guild_data, original_message, channels)):
return
print("prompt handler")
try:
await message.channel.trigger_typing()
# if the message is not in the owner's guild we update the usage count
@@ -240,9 +255,11 @@ async def chat_process(self, message):
messages.append(message)
except Exception as e:
debug("Error while getting message history", e)
print(e)
# if the pretend to be feature is enabled, we add the pretend to be text to the prompt
pretend_to_be = f"In this conversation, the assistant pretends to be {pretend_to_be}" if data_dict["pretend_enabled"] else ""
pretend_to_be = f"In this conversation, the assistant pretends to be {pretend_to_be}" if data_dict[
"pretend_enabled"] else ""
prompt_prefix = "" if data_dict["prompt_prefix"] == None else data_dict["prompt_prefix"]
# open the prompt file for the selected model with utf-8 encoding for emojis
@@ -261,16 +278,48 @@ async def chat_process(self, message):
f.close()
prompt_handlers = {
"chatGPT": self.gpt_prompt,
"gpt-4": self.gpt_prompt,
"davinci": self.davinci_prompt,
"gpt-3.5-turbo": gpt_prompt,
"gpt-4": gpt_prompt,
"davinci": davinci_prompt,
}
prompt_handlers[guild_data["model"]](
messages, message, data_dict, prompt, guild_data
response = await prompt_handlers[guild_data["model"]](
self.bot, messages, message, data_dict, prompt, guild_data
)
if response != "":
emojis, string = await extract_emoji(response)
debug(f"Emojis: {emojis}")
if len(string) < 1996:
await message.channel.send(string, tts=data_dict["tts"])
else:
# we send in an embed if the message is too long
embed = discord.Embed(
title="Botator response",
description=string,
color=discord.Color.brand_green(),
)
await message.channel.send(embed=embed, tts=data_dict["tts"])
for emoji in emojis:
# if the emoji is longer than 1 character, it's a custom emoji
try:
if len(emoji) > 1:
# if the emoji is a custom emoji, we need to fetch it
# the emoji is in the format id
debug(f"Emoji: {emoji}")
emoji = await message.guild.fetch_emoji(int(emoji))
await message.add_reaction(emoji)
else:
debug(f"Emoji: {emoji}")
await message.add_reaction(emoji)
except:
pass
else:
await message.channel.send(
"The AI is not sure what to say (the response was empty)"
)
async def check_moderate(self, api_key, message, msg):
async def check_moderate(api_key, message, msg):
if await moderate(api_key=api_key, text=msg.content):
embed = discord.Embed(
title="Message flagged as inappropriate",
@@ -285,7 +334,7 @@ async def check_moderate(self, api_key, message, msg):
return False
async def check_easter_egg(self, message, msgs):
async def check_easter_egg(message, msgs):
if message.content.lower().find("undude") != -1:
msgs.append(
{
@@ -310,7 +359,7 @@ async def check_easter_egg(self, message, msgs):
return msgs
async def gpt_prompt(self, messages, message, data_dict, prompt, guild_data):
async def gpt_prompt(bot, messages, message, data_dict, prompt, guild_data):
msgs = [] # create the msgs list
msgs.append(
{"name": "System", "role": "user", "content": prompt}
@@ -319,14 +368,14 @@ async def gpt_prompt(self, messages, message, data_dict, prompt, guild_data):
for msg in messages: # for each message in the messages list
content = msg.content # get the content of the message
content = await replace_mentions(
content, self.bot
content, bot
) # replace the mentions in the message
# if the message is flagged as inappropriate by the OpenAI API, we delete it, send a message and ignore it
if await self.check_moderate(data_dict["api_key"], message, msg):
if await check_moderate(data_dict["api_key"], message, msg):
continue # ignore the message
content = await replace_mentions(content, self.bot)
content = await replace_mentions(content, bot)
prompt += f"{msg.author.name}: {content}\n"
if msg.author.id == self.bot.user.id:
if msg.author.id == bot.user.id:
role = "assistant"
name = "assistant"
else:
@@ -350,13 +399,13 @@ async def gpt_prompt(self, messages, message, data_dict, prompt, guild_data):
for attachment in msg.attachments:
path = f"./../database/google-vision/results/{attachment.id}.txt"
if images_usage >= 6 and guild_data["premium"] == 0:
images_limit_reached = True
guild_data["images_limit_reached"] = True
elif images_usage >= 30 and guild_data["premium"] == 1:
images_limit_reached = True
guild_data["images_limit_reached"] = True
if (
attachment.url.endswith((".png", ".jpg", ".jpeg", ".gif"))
and images_limit_reached == False
and os.path.exists(path) == False
and not guild_data["images_limit_reached"]
and not os.path.exists(path)
):
images_usage += 1
analysis = await vision_processing.process(attachment)
@@ -403,17 +452,15 @@ async def gpt_prompt(self, messages, message, data_dict, prompt, guild_data):
msgs.append({"role": role, "content": f"{content}", "name": name})
# We check for the eastereggs :)
msgs = await self.check_easter_egg(message, msgs)
msgs = await check_easter_egg(message, msgs)
if model == "chatGPT":
model = "gpt-3.5-turbo" # if the model is chatGPT, we set the model to gpt-3.5-turbo
response = ""
should_break = True
for x in range(10):
try:
openai.api_key = data_dict["api_key"]
response = await openai.ChatCompletion.acreate(
model=model,
model=guild_data["model"],
temperature=2,
top_p=0.9,
frequency_penalty=0,
@@ -446,11 +493,13 @@ async def gpt_prompt(self, messages, message, data_dict, prompt, guild_data):
await asyncio.sleep(15)
await message.channel.trigger_typing()
response = response.choices[0].message.content
if images_limit_reached == True:
if guild_data["images_limit_reached"]:
await message.channel.send(
f"```diff\n-Warning: You have reached the image limit for this server. You can upgrade to premium to get more images recognized. More info in our server: https://discord.gg/sxjHtmqrbf```",
delete_after=10,
)
return response
async def davinci_prompt(self, messages, message, data_dict, prompt, guild_data):
@@ -488,38 +537,4 @@ async def davinci_prompt(self, messages, message, data_dict, prompt, guild_data)
return
if response != None:
break
if response != "":
if tts:
tts = True
else:
tts = False
emojis, string = await extract_emoji(response)
debug(f"Emojis: {emojis}")
if len(string) < 1996:
await message.channel.send(string, tts=tts)
else:
# we send in an embed if the message is too long
embed = discord.Embed(
title="Botator response",
description=string,
color=discord.Color.brand_green(),
)
await message.channel.send(embed=embed, tts=tts)
for emoji in emojis:
# if the emoji is longer than 1 character, it's a custom emoji
try:
if len(emoji) > 1:
# if the emoji is a custom emoji, we need to fetch it
# the emoji is in the format id
debug(f"Emoji: {emoji}")
emoji = await message.guild.fetch_emoji(int(emoji))
await message.add_reaction(emoji)
else:
debug(f"Emoji: {emoji}")
await message.add_reaction(emoji)
except:
pass
else:
await message.channel.send(
"The AI is not sure what to say (the response was empty)"
)
return response