From 306eefbc75e81091fd54e2508affdcd495a79af0 Mon Sep 17 00:00:00 2001 From: Alexis LEBEL Date: Sat, 1 Apr 2023 13:17:54 +0200 Subject: [PATCH] [MAKEPROMPT] FIxed errors and cleaned a bit more --- code/makeprompt.py | 261 ++++++++++++++++++++++++--------------------- 1 file changed, 138 insertions(+), 123 deletions(-) diff --git a/code/makeprompt.py b/code/makeprompt.py index d1ce3b4..4364a44 100644 --- a/code/makeprompt.py +++ b/code/makeprompt.py @@ -8,7 +8,6 @@ import openai import emoji import os - async def replace_mentions(content, bot): mentions = re.findall(r"<@!?\d+>", content) for mention in mentions: @@ -69,7 +68,7 @@ def get_guild_data(message): data = curs_data.fetchone() model = model[1] except: - model = "chatGPT" + model = "gpt-3.5-turbo" try: # [2] # get the premium status of the guild @@ -89,22 +88,23 @@ def get_guild_data(message): guild_data["model"] = model guild_data["premium"] = premium guild_data["images"] = images + + guild_data["images_limit_reached"] = False return guild_data -async def need_ignore_message(self, data_dict, message, guild_data, original_message, channels): + +async def need_ignore_message(bot, data_dict, message, guild_data, original_message, channels): ## ---- Message ignore conditions ---- ## - if message.author.bot: - return True if data_dict["api_key"] is None: - return True # if the api key is not set, return + return True # if the api key is not set, return if ( # if the message is not in a premium channel and - not str(message.channel.id) in channels - # if the message doesn't mention the bot and - and message.content.find("<@" + str(self.bot.user.id) + ">") == -1 - and original_message == None # if the message is not a reply to the bot and + not (str(message.channel.id) in channels + # if the message doesn't mention the bot and + and (message.content.find("<@" + str(bot.user.id) + ">") != -1 + or original_message)) # if the message is not a reply to the bot and # if the message is not in the default channel and str(message.channel.id) != str(data_dict["channel_id"]) ): @@ -135,33 +135,38 @@ async def need_ignore_message(self, data_dict, message, guild_data, original_mes ) return True return False - - -def get_data_dict(self, message): + + +async def get_data_dict(message): try: curs_data.execute( "SELECT * FROM data WHERE guild_id = ?", (message.guild.id,)) - except: - return - data = curs_data.fetchone() - # Create a dict with the data - data_dict = { - "channel_id": data[1], - "api_key": data[2], - "is_active": data[3], - "max_tokens": data[4], - "temperature": data[5], - "frequency_penalty": data[6], - "presence_penalty": data[7], - "uses_count_today": data[8], - "prompt_size": data[9], - "prompt_prefix": data[10], - "tts": data[11], - "pretend_to_be": data[12], - "pretend_enabled": data[13], - } - return data_dict - + data = curs_data.fetchone() + # Create a dict with the data + data_dict = { + "channel_id": data[1], + "api_key": data[2], + "is_active": data[3], + "max_tokens": data[4], + "temperature": data[5], + "frequency_penalty": data[6], + "presence_penalty": data[7], + "uses_count_today": data[8], + "prompt_size": data[9], + "prompt_prefix": data[10], + "tts": bool(data[11]), + "pretend_to_be": data[12], + "pretend_enabled": data[13], + } + return data_dict + except Exception as e: + # Send an error message + await message.channel.send( + "The bot is not configured yet. Please use `//setup` to configure it. \n" + + "If it still doesn't work, it might be a database error. \n ```" + e.__str__() + + "```", delete_after=60 + ) + async def chat_process(self, message): """This function processes the message and sends the prompt to the API @@ -169,40 +174,11 @@ async def chat_process(self, message): Args: message (str): Data of the message that was sent """ - - if(await need_ignore_message(self, message, guild_data, original_message, channels)): + if message.author.bot: return - data_dict = get_data_dict(message) - - ## ---- Message processing ---- ## - - if data is None: - data = [message.guild.id, 0, 0] - - data_dict["images_usage"] = data[1] - data_dict["images_enabled"] = data[2] - - images_limit_reached = False guild_data = get_guild_data(message) - - channels = [] - if message.guild.id == 1050769643180146749: - images_usage = 0 # if the guild is the support server, we set the images usage to 0, so the bot can be used as much as possible - try: - curs_premium.execute( - "SELECT * FROM channels WHERE guild_id = ?", (message.guild.id,)) - data = curs_premium.fetchone() - if guild_data["premium"]: - # for 5 times, we get c.fetchone()[1] to c.fetchone()[5] and we add it to the channels list, each time with try except - for i in range(1, 6): - # we use the i variable to get the channel id - try: - channels.append(str(data[i])) - except: - pass - except: - channels = [] + data_dict = await get_data_dict(message) try: original_message = await message.channel.fetch_message( @@ -215,6 +191,45 @@ async def chat_process(self, message): # if the message someone replied to is not from the bot, set original_message to None original_message = None + try: + # get the images setting in the database + curs_data.execute( + "SELECT * FROM images WHERE guild_id = ?", (message.guild.id,)) + images_data = curs_data.fetchone() + except: + images_data = None + + ## ---- Message processing ---- ## + + print(message) + + if not images_data: + images_data = [message.guild.id, 0, 0] + + data_dict["images_usage"] = 0 if message.guild.id == 1050769643180146749 else images_data[1] + data_dict["images_enabled"] = images_data[2] + + + channels = [] + try: + curs_premium.execute( + "SELECT * FROM channels WHERE guild_id = ?", (message.guild.id,)) + images_data = curs_premium.fetchone() + if guild_data["premium"]: + # for 5 times, we get c.fetchone()[1] to c.fetchone()[5] and we add it to the channels list, each time with try except + for i in range(1, 6): + # we use the i variable to get the channel id + try: + channels.append(str(images_data[i])) + except: + pass + except: + debug("No premium channels found for this guild") + + if (await need_ignore_message(self.bot, data_dict, message, guild_data, original_message, channels)): + return + print("prompt handler") + try: await message.channel.trigger_typing() # if the message is not in the owner's guild we update the usage count @@ -240,11 +255,13 @@ async def chat_process(self, message): messages.append(message) except Exception as e: debug("Error while getting message history", e) + print(e) # if the pretend to be feature is enabled, we add the pretend to be text to the prompt - pretend_to_be = f"In this conversation, the assistant pretends to be {pretend_to_be}" if data_dict["pretend_enabled"] else "" + pretend_to_be = f"In this conversation, the assistant pretends to be {pretend_to_be}" if data_dict[ + "pretend_enabled"] else "" prompt_prefix = "" if data_dict["prompt_prefix"] == None else data_dict["prompt_prefix"] - + # open the prompt file for the selected model with utf-8 encoding for emojis with open(f"./prompts/{guild_data['model']}.txt", "r", encoding="utf-8") as f: prompt = f.read() @@ -261,16 +278,48 @@ async def chat_process(self, message): f.close() prompt_handlers = { - "chatGPT": self.gpt_prompt, - "gpt-4": self.gpt_prompt, - "davinci": self.davinci_prompt, + "gpt-3.5-turbo": gpt_prompt, + "gpt-4": gpt_prompt, + "davinci": davinci_prompt, } - prompt_handlers[guild_data["model"]]( - messages, message, data_dict, prompt, guild_data + response = await prompt_handlers[guild_data["model"]]( + self.bot, messages, message, data_dict, prompt, guild_data ) + + if response != "": + emojis, string = await extract_emoji(response) + debug(f"Emojis: {emojis}") + if len(string) < 1996: + await message.channel.send(string, tts=data_dict["tts"]) + else: + # we send in an embed if the message is too long + embed = discord.Embed( + title="Botator response", + description=string, + color=discord.Color.brand_green(), + ) + await message.channel.send(embed=embed, tts=data_dict["tts"]) + for emoji in emojis: + # if the emoji is longer than 1 character, it's a custom emoji + try: + if len(emoji) > 1: + # if the emoji is a custom emoji, we need to fetch it + # the emoji is in the format id + debug(f"Emoji: {emoji}") + emoji = await message.guild.fetch_emoji(int(emoji)) + await message.add_reaction(emoji) + else: + debug(f"Emoji: {emoji}") + await message.add_reaction(emoji) + except: + pass + else: + await message.channel.send( + "The AI is not sure what to say (the response was empty)" + ) -async def check_moderate(self, api_key, message, msg): +async def check_moderate(api_key, message, msg): if await moderate(api_key=api_key, text=msg.content): embed = discord.Embed( title="Message flagged as inappropriate", @@ -285,7 +334,7 @@ async def check_moderate(self, api_key, message, msg): return False -async def check_easter_egg(self, message, msgs): +async def check_easter_egg(message, msgs): if message.content.lower().find("undude") != -1: msgs.append( { @@ -310,7 +359,7 @@ async def check_easter_egg(self, message, msgs): return msgs -async def gpt_prompt(self, messages, message, data_dict, prompt, guild_data): +async def gpt_prompt(bot, messages, message, data_dict, prompt, guild_data): msgs = [] # create the msgs list msgs.append( {"name": "System", "role": "user", "content": prompt} @@ -319,14 +368,14 @@ async def gpt_prompt(self, messages, message, data_dict, prompt, guild_data): for msg in messages: # for each message in the messages list content = msg.content # get the content of the message content = await replace_mentions( - content, self.bot + content, bot ) # replace the mentions in the message # if the message is flagged as inappropriate by the OpenAI API, we delete it, send a message and ignore it - if await self.check_moderate(data_dict["api_key"], message, msg): + if await check_moderate(data_dict["api_key"], message, msg): continue # ignore the message - content = await replace_mentions(content, self.bot) + content = await replace_mentions(content, bot) prompt += f"{msg.author.name}: {content}\n" - if msg.author.id == self.bot.user.id: + if msg.author.id == bot.user.id: role = "assistant" name = "assistant" else: @@ -350,13 +399,13 @@ async def gpt_prompt(self, messages, message, data_dict, prompt, guild_data): for attachment in msg.attachments: path = f"./../database/google-vision/results/{attachment.id}.txt" if images_usage >= 6 and guild_data["premium"] == 0: - images_limit_reached = True + guild_data["images_limit_reached"] = True elif images_usage >= 30 and guild_data["premium"] == 1: - images_limit_reached = True + guild_data["images_limit_reached"] = True if ( attachment.url.endswith((".png", ".jpg", ".jpeg", ".gif")) - and images_limit_reached == False - and os.path.exists(path) == False + and not guild_data["images_limit_reached"] + and not os.path.exists(path) ): images_usage += 1 analysis = await vision_processing.process(attachment) @@ -403,17 +452,15 @@ async def gpt_prompt(self, messages, message, data_dict, prompt, guild_data): msgs.append({"role": role, "content": f"{content}", "name": name}) # We check for the eastereggs :) - msgs = await self.check_easter_egg(message, msgs) + msgs = await check_easter_egg(message, msgs) - if model == "chatGPT": - model = "gpt-3.5-turbo" # if the model is chatGPT, we set the model to gpt-3.5-turbo response = "" should_break = True for x in range(10): try: openai.api_key = data_dict["api_key"] response = await openai.ChatCompletion.acreate( - model=model, + model=guild_data["model"], temperature=2, top_p=0.9, frequency_penalty=0, @@ -446,11 +493,13 @@ async def gpt_prompt(self, messages, message, data_dict, prompt, guild_data): await asyncio.sleep(15) await message.channel.trigger_typing() response = response.choices[0].message.content - if images_limit_reached == True: + + if guild_data["images_limit_reached"]: await message.channel.send( f"```diff\n-Warning: You have reached the image limit for this server. You can upgrade to premium to get more images recognized. More info in our server: https://discord.gg/sxjHtmqrbf```", delete_after=10, ) + return response async def davinci_prompt(self, messages, message, data_dict, prompt, guild_data): @@ -488,38 +537,4 @@ async def davinci_prompt(self, messages, message, data_dict, prompt, guild_data) return if response != None: break - if response != "": - if tts: - tts = True - else: - tts = False - emojis, string = await extract_emoji(response) - debug(f"Emojis: {emojis}") - if len(string) < 1996: - await message.channel.send(string, tts=tts) - else: - # we send in an embed if the message is too long - embed = discord.Embed( - title="Botator response", - description=string, - color=discord.Color.brand_green(), - ) - await message.channel.send(embed=embed, tts=tts) - for emoji in emojis: - # if the emoji is longer than 1 character, it's a custom emoji - try: - if len(emoji) > 1: - # if the emoji is a custom emoji, we need to fetch it - # the emoji is in the format id - debug(f"Emoji: {emoji}") - emoji = await message.guild.fetch_emoji(int(emoji)) - await message.add_reaction(emoji) - else: - debug(f"Emoji: {emoji}") - await message.add_reaction(emoji) - except: - pass - else: - await message.channel.send( - "The AI is not sure what to say (the response was empty)" - ) + return response \ No newline at end of file