mirror of
https://github.com/Paillat-dev/Botator.git
synced 2026-01-02 09:16:19 +00:00
[MAKEPROMPT] FIxed errors and cleaned a bit more
This commit is contained in:
@@ -8,7 +8,6 @@ import openai
|
||||
import emoji
|
||||
import os
|
||||
|
||||
|
||||
async def replace_mentions(content, bot):
|
||||
mentions = re.findall(r"<@!?\d+>", content)
|
||||
for mention in mentions:
|
||||
@@ -69,7 +68,7 @@ def get_guild_data(message):
|
||||
data = curs_data.fetchone()
|
||||
model = model[1]
|
||||
except:
|
||||
model = "chatGPT"
|
||||
model = "gpt-3.5-turbo"
|
||||
|
||||
try:
|
||||
# [2] # get the premium status of the guild
|
||||
@@ -90,21 +89,22 @@ def get_guild_data(message):
|
||||
guild_data["premium"] = premium
|
||||
guild_data["images"] = images
|
||||
|
||||
guild_data["images_limit_reached"] = False
|
||||
|
||||
return guild_data
|
||||
|
||||
async def need_ignore_message(self, data_dict, message, guild_data, original_message, channels):
|
||||
|
||||
async def need_ignore_message(bot, data_dict, message, guild_data, original_message, channels):
|
||||
## ---- Message ignore conditions ---- ##
|
||||
if message.author.bot:
|
||||
return True
|
||||
if data_dict["api_key"] is None:
|
||||
return True # if the api key is not set, return
|
||||
|
||||
if (
|
||||
# if the message is not in a premium channel and
|
||||
not str(message.channel.id) in channels
|
||||
not (str(message.channel.id) in channels
|
||||
# if the message doesn't mention the bot and
|
||||
and message.content.find("<@" + str(self.bot.user.id) + ">") == -1
|
||||
and original_message == None # if the message is not a reply to the bot and
|
||||
and (message.content.find("<@" + str(bot.user.id) + ">") != -1
|
||||
or original_message)) # if the message is not a reply to the bot and
|
||||
# if the message is not in the default channel
|
||||
and str(message.channel.id) != str(data_dict["channel_id"])
|
||||
):
|
||||
@@ -137,12 +137,10 @@ async def need_ignore_message(self, data_dict, message, guild_data, original_mes
|
||||
return False
|
||||
|
||||
|
||||
def get_data_dict(self, message):
|
||||
async def get_data_dict(message):
|
||||
try:
|
||||
curs_data.execute(
|
||||
"SELECT * FROM data WHERE guild_id = ?", (message.guild.id,))
|
||||
except:
|
||||
return
|
||||
data = curs_data.fetchone()
|
||||
# Create a dict with the data
|
||||
data_dict = {
|
||||
@@ -156,11 +154,18 @@ def get_data_dict(self, message):
|
||||
"uses_count_today": data[8],
|
||||
"prompt_size": data[9],
|
||||
"prompt_prefix": data[10],
|
||||
"tts": data[11],
|
||||
"tts": bool(data[11]),
|
||||
"pretend_to_be": data[12],
|
||||
"pretend_enabled": data[13],
|
||||
}
|
||||
return data_dict
|
||||
except Exception as e:
|
||||
# Send an error message
|
||||
await message.channel.send(
|
||||
"The bot is not configured yet. Please use `//setup` to configure it. \n" +
|
||||
"If it still doesn't work, it might be a database error. \n ```" + e.__str__()
|
||||
+ "```", delete_after=60
|
||||
)
|
||||
|
||||
|
||||
async def chat_process(self, message):
|
||||
@@ -169,40 +174,11 @@ async def chat_process(self, message):
|
||||
Args:
|
||||
message (str): Data of the message that was sent
|
||||
"""
|
||||
|
||||
if(await need_ignore_message(self, message, guild_data, original_message, channels)):
|
||||
if message.author.bot:
|
||||
return
|
||||
data_dict = get_data_dict(message)
|
||||
|
||||
## ---- Message processing ---- ##
|
||||
|
||||
if data is None:
|
||||
data = [message.guild.id, 0, 0]
|
||||
|
||||
data_dict["images_usage"] = data[1]
|
||||
data_dict["images_enabled"] = data[2]
|
||||
|
||||
images_limit_reached = False
|
||||
|
||||
guild_data = get_guild_data(message)
|
||||
|
||||
channels = []
|
||||
if message.guild.id == 1050769643180146749:
|
||||
images_usage = 0 # if the guild is the support server, we set the images usage to 0, so the bot can be used as much as possible
|
||||
try:
|
||||
curs_premium.execute(
|
||||
"SELECT * FROM channels WHERE guild_id = ?", (message.guild.id,))
|
||||
data = curs_premium.fetchone()
|
||||
if guild_data["premium"]:
|
||||
# for 5 times, we get c.fetchone()[1] to c.fetchone()[5] and we add it to the channels list, each time with try except
|
||||
for i in range(1, 6):
|
||||
# we use the i variable to get the channel id
|
||||
try:
|
||||
channels.append(str(data[i]))
|
||||
except:
|
||||
pass
|
||||
except:
|
||||
channels = []
|
||||
data_dict = await get_data_dict(message)
|
||||
|
||||
try:
|
||||
original_message = await message.channel.fetch_message(
|
||||
@@ -215,6 +191,45 @@ async def chat_process(self, message):
|
||||
# if the message someone replied to is not from the bot, set original_message to None
|
||||
original_message = None
|
||||
|
||||
try:
|
||||
# get the images setting in the database
|
||||
curs_data.execute(
|
||||
"SELECT * FROM images WHERE guild_id = ?", (message.guild.id,))
|
||||
images_data = curs_data.fetchone()
|
||||
except:
|
||||
images_data = None
|
||||
|
||||
## ---- Message processing ---- ##
|
||||
|
||||
print(message)
|
||||
|
||||
if not images_data:
|
||||
images_data = [message.guild.id, 0, 0]
|
||||
|
||||
data_dict["images_usage"] = 0 if message.guild.id == 1050769643180146749 else images_data[1]
|
||||
data_dict["images_enabled"] = images_data[2]
|
||||
|
||||
|
||||
channels = []
|
||||
try:
|
||||
curs_premium.execute(
|
||||
"SELECT * FROM channels WHERE guild_id = ?", (message.guild.id,))
|
||||
images_data = curs_premium.fetchone()
|
||||
if guild_data["premium"]:
|
||||
# for 5 times, we get c.fetchone()[1] to c.fetchone()[5] and we add it to the channels list, each time with try except
|
||||
for i in range(1, 6):
|
||||
# we use the i variable to get the channel id
|
||||
try:
|
||||
channels.append(str(images_data[i]))
|
||||
except:
|
||||
pass
|
||||
except:
|
||||
debug("No premium channels found for this guild")
|
||||
|
||||
if (await need_ignore_message(self.bot, data_dict, message, guild_data, original_message, channels)):
|
||||
return
|
||||
print("prompt handler")
|
||||
|
||||
try:
|
||||
await message.channel.trigger_typing()
|
||||
# if the message is not in the owner's guild we update the usage count
|
||||
@@ -240,9 +255,11 @@ async def chat_process(self, message):
|
||||
messages.append(message)
|
||||
except Exception as e:
|
||||
debug("Error while getting message history", e)
|
||||
print(e)
|
||||
|
||||
# if the pretend to be feature is enabled, we add the pretend to be text to the prompt
|
||||
pretend_to_be = f"In this conversation, the assistant pretends to be {pretend_to_be}" if data_dict["pretend_enabled"] else ""
|
||||
pretend_to_be = f"In this conversation, the assistant pretends to be {pretend_to_be}" if data_dict[
|
||||
"pretend_enabled"] else ""
|
||||
prompt_prefix = "" if data_dict["prompt_prefix"] == None else data_dict["prompt_prefix"]
|
||||
|
||||
# open the prompt file for the selected model with utf-8 encoding for emojis
|
||||
@@ -261,16 +278,48 @@ async def chat_process(self, message):
|
||||
f.close()
|
||||
|
||||
prompt_handlers = {
|
||||
"chatGPT": self.gpt_prompt,
|
||||
"gpt-4": self.gpt_prompt,
|
||||
"davinci": self.davinci_prompt,
|
||||
"gpt-3.5-turbo": gpt_prompt,
|
||||
"gpt-4": gpt_prompt,
|
||||
"davinci": davinci_prompt,
|
||||
}
|
||||
prompt_handlers[guild_data["model"]](
|
||||
messages, message, data_dict, prompt, guild_data
|
||||
response = await prompt_handlers[guild_data["model"]](
|
||||
self.bot, messages, message, data_dict, prompt, guild_data
|
||||
)
|
||||
|
||||
if response != "":
|
||||
emojis, string = await extract_emoji(response)
|
||||
debug(f"Emojis: {emojis}")
|
||||
if len(string) < 1996:
|
||||
await message.channel.send(string, tts=data_dict["tts"])
|
||||
else:
|
||||
# we send in an embed if the message is too long
|
||||
embed = discord.Embed(
|
||||
title="Botator response",
|
||||
description=string,
|
||||
color=discord.Color.brand_green(),
|
||||
)
|
||||
await message.channel.send(embed=embed, tts=data_dict["tts"])
|
||||
for emoji in emojis:
|
||||
# if the emoji is longer than 1 character, it's a custom emoji
|
||||
try:
|
||||
if len(emoji) > 1:
|
||||
# if the emoji is a custom emoji, we need to fetch it
|
||||
# the emoji is in the format id
|
||||
debug(f"Emoji: {emoji}")
|
||||
emoji = await message.guild.fetch_emoji(int(emoji))
|
||||
await message.add_reaction(emoji)
|
||||
else:
|
||||
debug(f"Emoji: {emoji}")
|
||||
await message.add_reaction(emoji)
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
await message.channel.send(
|
||||
"The AI is not sure what to say (the response was empty)"
|
||||
)
|
||||
|
||||
|
||||
async def check_moderate(self, api_key, message, msg):
|
||||
async def check_moderate(api_key, message, msg):
|
||||
if await moderate(api_key=api_key, text=msg.content):
|
||||
embed = discord.Embed(
|
||||
title="Message flagged as inappropriate",
|
||||
@@ -285,7 +334,7 @@ async def check_moderate(self, api_key, message, msg):
|
||||
return False
|
||||
|
||||
|
||||
async def check_easter_egg(self, message, msgs):
|
||||
async def check_easter_egg(message, msgs):
|
||||
if message.content.lower().find("undude") != -1:
|
||||
msgs.append(
|
||||
{
|
||||
@@ -310,7 +359,7 @@ async def check_easter_egg(self, message, msgs):
|
||||
return msgs
|
||||
|
||||
|
||||
async def gpt_prompt(self, messages, message, data_dict, prompt, guild_data):
|
||||
async def gpt_prompt(bot, messages, message, data_dict, prompt, guild_data):
|
||||
msgs = [] # create the msgs list
|
||||
msgs.append(
|
||||
{"name": "System", "role": "user", "content": prompt}
|
||||
@@ -319,14 +368,14 @@ async def gpt_prompt(self, messages, message, data_dict, prompt, guild_data):
|
||||
for msg in messages: # for each message in the messages list
|
||||
content = msg.content # get the content of the message
|
||||
content = await replace_mentions(
|
||||
content, self.bot
|
||||
content, bot
|
||||
) # replace the mentions in the message
|
||||
# if the message is flagged as inappropriate by the OpenAI API, we delete it, send a message and ignore it
|
||||
if await self.check_moderate(data_dict["api_key"], message, msg):
|
||||
if await check_moderate(data_dict["api_key"], message, msg):
|
||||
continue # ignore the message
|
||||
content = await replace_mentions(content, self.bot)
|
||||
content = await replace_mentions(content, bot)
|
||||
prompt += f"{msg.author.name}: {content}\n"
|
||||
if msg.author.id == self.bot.user.id:
|
||||
if msg.author.id == bot.user.id:
|
||||
role = "assistant"
|
||||
name = "assistant"
|
||||
else:
|
||||
@@ -350,13 +399,13 @@ async def gpt_prompt(self, messages, message, data_dict, prompt, guild_data):
|
||||
for attachment in msg.attachments:
|
||||
path = f"./../database/google-vision/results/{attachment.id}.txt"
|
||||
if images_usage >= 6 and guild_data["premium"] == 0:
|
||||
images_limit_reached = True
|
||||
guild_data["images_limit_reached"] = True
|
||||
elif images_usage >= 30 and guild_data["premium"] == 1:
|
||||
images_limit_reached = True
|
||||
guild_data["images_limit_reached"] = True
|
||||
if (
|
||||
attachment.url.endswith((".png", ".jpg", ".jpeg", ".gif"))
|
||||
and images_limit_reached == False
|
||||
and os.path.exists(path) == False
|
||||
and not guild_data["images_limit_reached"]
|
||||
and not os.path.exists(path)
|
||||
):
|
||||
images_usage += 1
|
||||
analysis = await vision_processing.process(attachment)
|
||||
@@ -403,17 +452,15 @@ async def gpt_prompt(self, messages, message, data_dict, prompt, guild_data):
|
||||
msgs.append({"role": role, "content": f"{content}", "name": name})
|
||||
|
||||
# We check for the eastereggs :)
|
||||
msgs = await self.check_easter_egg(message, msgs)
|
||||
msgs = await check_easter_egg(message, msgs)
|
||||
|
||||
if model == "chatGPT":
|
||||
model = "gpt-3.5-turbo" # if the model is chatGPT, we set the model to gpt-3.5-turbo
|
||||
response = ""
|
||||
should_break = True
|
||||
for x in range(10):
|
||||
try:
|
||||
openai.api_key = data_dict["api_key"]
|
||||
response = await openai.ChatCompletion.acreate(
|
||||
model=model,
|
||||
model=guild_data["model"],
|
||||
temperature=2,
|
||||
top_p=0.9,
|
||||
frequency_penalty=0,
|
||||
@@ -446,11 +493,13 @@ async def gpt_prompt(self, messages, message, data_dict, prompt, guild_data):
|
||||
await asyncio.sleep(15)
|
||||
await message.channel.trigger_typing()
|
||||
response = response.choices[0].message.content
|
||||
if images_limit_reached == True:
|
||||
|
||||
if guild_data["images_limit_reached"]:
|
||||
await message.channel.send(
|
||||
f"```diff\n-Warning: You have reached the image limit for this server. You can upgrade to premium to get more images recognized. More info in our server: https://discord.gg/sxjHtmqrbf```",
|
||||
delete_after=10,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
async def davinci_prompt(self, messages, message, data_dict, prompt, guild_data):
|
||||
@@ -488,38 +537,4 @@ async def davinci_prompt(self, messages, message, data_dict, prompt, guild_data)
|
||||
return
|
||||
if response != None:
|
||||
break
|
||||
if response != "":
|
||||
if tts:
|
||||
tts = True
|
||||
else:
|
||||
tts = False
|
||||
emojis, string = await extract_emoji(response)
|
||||
debug(f"Emojis: {emojis}")
|
||||
if len(string) < 1996:
|
||||
await message.channel.send(string, tts=tts)
|
||||
else:
|
||||
# we send in an embed if the message is too long
|
||||
embed = discord.Embed(
|
||||
title="Botator response",
|
||||
description=string,
|
||||
color=discord.Color.brand_green(),
|
||||
)
|
||||
await message.channel.send(embed=embed, tts=tts)
|
||||
for emoji in emojis:
|
||||
# if the emoji is longer than 1 character, it's a custom emoji
|
||||
try:
|
||||
if len(emoji) > 1:
|
||||
# if the emoji is a custom emoji, we need to fetch it
|
||||
# the emoji is in the format id
|
||||
debug(f"Emoji: {emoji}")
|
||||
emoji = await message.guild.fetch_emoji(int(emoji))
|
||||
await message.add_reaction(emoji)
|
||||
else:
|
||||
debug(f"Emoji: {emoji}")
|
||||
await message.add_reaction(emoji)
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
await message.channel.send(
|
||||
"The AI is not sure what to say (the response was empty)"
|
||||
)
|
||||
return response
|
||||
Reference in New Issue
Block a user