diff --git a/code/code.py b/code/code.py index d824ea5..3eb48aa 100644 --- a/code/code.py +++ b/code/code.py @@ -6,7 +6,6 @@ import logging # pip install logging import sqlite3 # pip install sqlite3 import asyncio # pip install asyncio import os # pip install os -import random # pip install random import re # pip install re import datetime # pip install datetime #set the debug mode to the maximum @@ -278,7 +277,7 @@ async def pretend(ctx, pretend_to_be: str): c.execute("UPDATE data SET pretend_to_be = ? WHERE guild_id = ?", (pretend_to_be, ctx.guild.id)) conn.commit() @bot.event -async def on_message(message): +async def on_message(message: discord.Message): #check if the message is from a bot if message.author.bot: return @@ -292,8 +291,15 @@ async def on_message(message): return #check if the message has been sent in the channel set in the database c.execute("SELECT channel_id FROM data WHERE guild_id = ?", (message.guild.id,)) + try : original_message = await message.channel.fetch_message(message.reference.message_id) + except : original_message = None + if original_message != None and original_message.author.id != bot.user.id: + original_message = None if str(message.channel.id) != str(c.fetchone()[0]): - if message.content.find("<@1046051875755134996>") != -1: + #check if the message is a mention or if the message replies to the bot + if original_message != None: + debug("wrong channel, but reply") + elif message.content.find("<@"+str(bot.user.id)+">") != -1: debug("wrong channel, but mention") else : debug("The message has been sent in the wrong channel") @@ -315,8 +321,14 @@ async def on_message(message): #get the advanced settings from the database c.execute("SELECT max_tokens, temperature, frequency_penalty, presence_penalty, prompt_size FROM data WHERE guild_id = ?", (message.guild.id,)) max_tokens, temperature, frequency_penalty, presence_penalty, prompt_size = c.fetchone() - messages = await message.channel.history(limit=prompt_size).flatten() - messages.reverse() + if original_message == None: + messages = await message.channel.history(limit=prompt_size).flatten() + messages.reverse() + else : + messages = await message.channel.history(limit=prompt_size, before=original_message).flatten() + messages.reverse() + messages.append(original_message) + messages.append(message) prompt = "" #get the channel id from the database c.execute("SELECT channel_id FROM data WHERE guild_id = ?", (message.guild.id,)) @@ -517,4 +529,5 @@ bot.loop.create_task(check_day_task()) # Replace the following with your bot's token with open("key.txt") as f: key = f.read() -bot.run(key) + +bot.run(key) \ No newline at end of file