From 8771247612fe3e47515dd57ff6c242fe933c8e78 Mon Sep 17 00:00:00 2001 From: Paillat Date: Mon, 21 Aug 2023 11:36:55 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(ChatProcess.py):=20remove=20?= =?UTF-8?q?unused=20imports=20and=20variables=20to=20improve=20code=20read?= =?UTF-8?q?ability=20and=20maintainability=20=F0=9F=90=9B=20fix(ChatProces?= =?UTF-8?q?s.py):=20fix=20logic=20error=20in=20the=20return=20criteria=20f?= =?UTF-8?q?or=20determining=20if=20the=20bot=20should=20respond=20to=20a?= =?UTF-8?q?=20message=20=F0=9F=90=9B=20fix(ChatProcess.py):=20fix=20typo?= =?UTF-8?q?=20in=20the=20'functions'=20variable=20name=20=F0=9F=90=9B=20fi?= =?UTF-8?q?x(ChatProcess.py):=20fix=20typo=20in=20the=20'functions'=20para?= =?UTF-8?q?meter=20name=20in=20the=20request=20function=20call=20?= =?UTF-8?q?=F0=9F=90=9B=20fix(ChatProcess.py):=20fix=20typo=20in=20the=20'?= =?UTF-8?q?functions'=20parameter=20name=20in=20the=20processFunctioncallR?= =?UTF-8?q?esponse=20function=20call=20=F0=9F=90=9B=20fix(ChatProcess.py):?= =?UTF-8?q?=20remove=20unnecessary=20print=20statement=20in=20the=20proces?= =?UTF-8?q?sMessage=20function=20=F0=9F=90=9B=20fix(prompts.py):=20remove?= =?UTF-8?q?=20unnecessary=20print=20statement=20in=20the=20createPrompt=20?= =?UTF-8?q?function=20=F0=9F=90=9B=20fix(channelSetup.py):=20fix=20logic?= =?UTF-8?q?=20error=20in=20the=20is=5Fowner=20function=20call=20?= =?UTF-8?q?=F0=9F=90=9B=20fix(moderation.py):=20remove=20unnecessary=20cod?= =?UTF-8?q?e=20for=20disabling=20moderation=20=F0=9F=90=9B=20fix(config.py?= =?UTF-8?q?):=20remove=20unnecessary=20code=20for=20creating=20tables=20in?= =?UTF-8?q?=20the=20database=20=F0=9F=90=9B=20fix(functionscalls.py):=20fi?= =?UTF-8?q?x=20type=20hint=20for=20the=20return=20value=20of=20the=20call?= =?UTF-8?q?=5Ffunction=20function=20=F0=9F=90=9B=20fix(guild.py):=20fix=20?= =?UTF-8?q?handling=20of=20serialized=20data=20in=20the=20load=20function?= =?UTF-8?q?=20=F0=9F=90=9B=20fix(SqlConnector.py):=20create=20setup=5Fdata?= =?UTF-8?q?=20table=20if=20it=20does=20not=20exist?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/ChatProcess.py | 49 ++++++++++++++++++++++++--------------- src/chatUtils/prompts.py | 1 - src/cogs/channelSetup.py | 2 +- src/cogs/manage_chat.py | 1 - src/cogs/moderation.py | 8 ------- src/config.py | 20 ---------------- src/functionscalls.py | 6 +++-- src/guild.py | 5 +++- src/utils/SqlConnector.py | 5 ++++ 9 files changed, 44 insertions(+), 53 deletions(-) diff --git a/src/ChatProcess.py b/src/ChatProcess.py index dd690c1..a6e9b9a 100644 --- a/src/ChatProcess.py +++ b/src/ChatProcess.py @@ -10,10 +10,7 @@ from src.utils.variousclasses import models from src.guild import Guild from src.chatUtils.Chat import fetch_messages_history from src.chatUtils.prompts import createPrompt -from src.functionscalls import ( - call_function, - server_normal_channel_functions, -) +from src.functionscalls import call_function, server_normal_channel_functions, functions from src.config import debug from src.chatUtils.requesters.request import request @@ -49,7 +46,7 @@ class Chat: if ( self.original_message != None - and self.original_message.author.id == self.bot.user.id + and self.original_message.author.id != self.bot.user.id ): self.original_message = None @@ -67,18 +64,29 @@ class Chat: Returns True if any of the exit criterias are met (their opposite is met but there is a not in front of the any() function) This checks if the bot should actuallly respond to the message or if the message doesn't concern the bot """ - returnCriterias = [] - returnCriterias.append(self.openai_api_key != None) - returnCriterias.append( - self.message.content.find("<@" + str(self.bot.user.id) + ">") != -1 + + serverwideReturnCriterias = [] + serverwideReturnCriterias.append(self.original_message != None) + serverwideReturnCriterias.append( + self.message.content.find(f"<@{self.bot.user.id}>") != -1 ) - returnCriterias.append(self.original_message != None) - returnCriterias.append(self.is_bots_thread) - returnCriterias.append( - self.guild.sanitizedChannels.get(str(self.channelIdForSettings), None) - != None + serverwideReturnCriterias.append(self.is_bots_thread) + + channelReturnCriterias = [] + channelReturnCriterias.append(self.channelIdForSettings != "serverwide") + channelReturnCriterias.append( + self.guild.getChannelInfo(self.channelIdForSettings) != None ) - return not any(returnCriterias) + + messageReturnCriterias = [] + messageReturnCriterias.append( + any(serverwideReturnCriterias) + and self.guild.getChannelInfo("serverwide") != None + ) + messageReturnCriterias.append(all(channelReturnCriterias)) + + returnCriterias: bool = not any(messageReturnCriterias) + return returnCriterias async def getSettings(self): self.settings = self.guild.getChannelInfo( @@ -129,11 +137,14 @@ class Chat: """ This function gets the response from the ai """ + funcs = functions + if isinstance(self.message.channel, discord.TextChannel): + funcs.extend(server_normal_channel_functions) self.response = await request( model=self.model, prompt=self.prompt, openai_api_key=self.openai_api_key, - funtcions=server_normal_channel_functions, + funtcions=funcs, ) async def processResponse(self): @@ -142,14 +153,15 @@ class Chat: function_call=self.response, api_key=self.openai_api_key, ) - if response != None: + if response[0] != None: await self.processFunctioncallResponse(response) async def processFunctioncallResponse(self, response): self.context.append( { "role": "function", - "content": response, + "content": response[0], + "name": response[1], } ) if self.depth < 3: @@ -166,7 +178,6 @@ class Chat: This function processes the message """ if await self.preExitCriteria(): - print("pre exit criteria") return await self.getSupplementaryData() await self.getSettings() diff --git a/src/chatUtils/prompts.py b/src/chatUtils/prompts.py index 854e36a..6140079 100644 --- a/src/chatUtils/prompts.py +++ b/src/chatUtils/prompts.py @@ -27,7 +27,6 @@ def createPrompt( """ Creates a prompt from the messages list """ - print(f"Creating prompt with type {modeltype}") if modeltype == "chat": prompt = createChatPrompt(messages, model, character) sysprompt = prompt[0]["content"] diff --git a/src/cogs/channelSetup.py b/src/cogs/channelSetup.py index 964bec6..d83d346 100644 --- a/src/cogs/channelSetup.py +++ b/src/cogs/channelSetup.py @@ -236,7 +236,7 @@ class ChannelSetup(commands.Cog): async def premium(self, ctx: discord.ApplicationContext): guild = Guild(ctx.guild.id) guild.load() - if self.bot.is_owner(ctx.author): + if await self.bot.is_owner(ctx.author): guild.premium = True # also set expiry date in 6 months isofromat guild.premium_expiration = datetime.datetime.now() + datetime.timedelta( diff --git a/src/cogs/manage_chat.py b/src/cogs/manage_chat.py index 64953e0..e9bf25e 100644 --- a/src/cogs/manage_chat.py +++ b/src/cogs/manage_chat.py @@ -1,7 +1,6 @@ import discord import re import os -from src.config import debug, curs_data class ManageChat(discord.Cog): diff --git a/src/cogs/moderation.py b/src/cogs/moderation.py index 62a1d40..3e6305d 100644 --- a/src/cogs/moderation.py +++ b/src/cogs/moderation.py @@ -2,7 +2,6 @@ import discord from discord import default_permissions from discord.ext import commands import os -from src.config import debug, curs_data, con_data import openai import requests @@ -91,13 +90,6 @@ class Moderation(discord.Cog): "Our moderation capabilities have been switched to our new 100% free and open-source AI discord moderation bot! You add it to your server here: https://discord.com/api/oauth2/authorize?client_id=1071451913024974939&permissions=1377342450896&scope=bot and you can find the source code here: https://github.com/Paillat-dev/Moderator/ \n If you need help, you can join our support server here: https://discord.gg/pB6hXtUeDv", ephemeral=True, ) - if enable == False: - curs_data.execute( - "DELETE FROM moderation WHERE guild_id = ?", (str(ctx.guild.id),) - ) - con_data.commit() - await ctx.respond("Moderation disabled!", ephemeral=True) - return @discord.slash_command( name="get_toxicity", description="Get the toxicity of a message" diff --git a/src/config.py b/src/config.py index 075bd69..7d32c8d 100644 --- a/src/config.py +++ b/src/config.py @@ -45,29 +45,9 @@ def mg_to_guid(mg): return mg.guild.id -con_data = sqlite3.connect("./database/data.db") -curs_data = con_data.cursor() con_premium = sqlite3.connect("./database/premium.db") curs_premium = con_premium.cursor() -curs_data.execute( - """CREATE TABLE IF NOT EXISTS data (guild_id text, channel_id text, api_key text, is_active boolean, max_tokens integer, temperature real, frequency_penalty real, presence_penalty real, uses_count_today integer, prompt_size integer, prompt_prefix text, tts boolean, pretend_to_be text, pretend_enabled boolean)""" -) - -con_data.execute( - "CREATE TABLE IF NOT EXISTS setup_data (guild_id text, guild_settings text)" -) - -# This code creates the model table if it does not exist -curs_data.execute( - """CREATE TABLE IF NOT EXISTS model (guild_id text, model_name text)""" -) - -# This code creates the images table if it does not exist -curs_data.execute( - """CREATE TABLE IF NOT EXISTS images (guild_id text, usage_count integer, is_enabled boolean)""" -) - # This code creates the data table if it does not exist curs_premium.execute( """CREATE TABLE IF NOT EXISTS data (user_id text, guild_id text, premium boolean)""" diff --git a/src/functionscalls.py b/src/functionscalls.py index be64abb..e51079a 100644 --- a/src/functionscalls.py +++ b/src/functionscalls.py @@ -341,7 +341,9 @@ async def evaluate_math( return f"Result to math eval of {evaluable}: ```\n{str(result)}```" -async def call_function(message: discord.Message, function_call, api_key): +async def call_function( + message: discord.Message, function_call, api_key +) -> list[None | str]: name = function_call.get("name", "") if name == "": raise FuntionCallError("No name provided") @@ -359,7 +361,7 @@ async def call_function(message: discord.Message, function_call, api_key): ): return "Query blocked by the moderation system. If the user asked for something edgy, please tell them in a funny way that you won't do it, but do not specify that it was blocked by the moderation system." returnable = await function(message, arguments) - return returnable + return [returnable, name] functions_matching = { diff --git a/src/guild.py b/src/guild.py index a76e5f2..d409956 100644 --- a/src/guild.py +++ b/src/guild.py @@ -25,7 +25,10 @@ class Guild: "SELECT * FROM setup_data WHERE guild_id = ?", (self.id,) ) data = curs_data.fetchone() - data = orjson.loads(data[1]) + if type(data[1]) == str and data[1].startswith("b'"): + data = orjson.loads(data[1][2:-1]) + else: + data = orjson.loads(data[1]) self.premium = data["premium"] self.channels = data["channels"] self.api_keys = data["api_keys"] diff --git a/src/utils/SqlConnector.py b/src/utils/SqlConnector.py index d9f2f31..ea69ab3 100644 --- a/src/utils/SqlConnector.py +++ b/src/utils/SqlConnector.py @@ -22,3 +22,8 @@ class _sql: sql: _sql = _sql() + +command = "CREATE TABLE IF NOT EXISTS setup_data (guild_id text, guild_settings text)" + +with sql.mainDb as db: + db.execute(command)