diff --git a/src/ChatProcess.py b/src/ChatProcess.py index dd690c1..a6e9b9a 100644 --- a/src/ChatProcess.py +++ b/src/ChatProcess.py @@ -10,10 +10,7 @@ from src.utils.variousclasses import models from src.guild import Guild from src.chatUtils.Chat import fetch_messages_history from src.chatUtils.prompts import createPrompt -from src.functionscalls import ( - call_function, - server_normal_channel_functions, -) +from src.functionscalls import call_function, server_normal_channel_functions, functions from src.config import debug from src.chatUtils.requesters.request import request @@ -49,7 +46,7 @@ class Chat: if ( self.original_message != None - and self.original_message.author.id == self.bot.user.id + and self.original_message.author.id != self.bot.user.id ): self.original_message = None @@ -67,18 +64,29 @@ class Chat: Returns True if any of the exit criterias are met (their opposite is met but there is a not in front of the any() function) This checks if the bot should actuallly respond to the message or if the message doesn't concern the bot """ - returnCriterias = [] - returnCriterias.append(self.openai_api_key != None) - returnCriterias.append( - self.message.content.find("<@" + str(self.bot.user.id) + ">") != -1 + + serverwideReturnCriterias = [] + serverwideReturnCriterias.append(self.original_message != None) + serverwideReturnCriterias.append( + self.message.content.find(f"<@{self.bot.user.id}>") != -1 ) - returnCriterias.append(self.original_message != None) - returnCriterias.append(self.is_bots_thread) - returnCriterias.append( - self.guild.sanitizedChannels.get(str(self.channelIdForSettings), None) - != None + serverwideReturnCriterias.append(self.is_bots_thread) + + channelReturnCriterias = [] + channelReturnCriterias.append(self.channelIdForSettings != "serverwide") + channelReturnCriterias.append( + self.guild.getChannelInfo(self.channelIdForSettings) != None ) - return not any(returnCriterias) + + messageReturnCriterias = [] + messageReturnCriterias.append( + any(serverwideReturnCriterias) + and self.guild.getChannelInfo("serverwide") != None + ) + messageReturnCriterias.append(all(channelReturnCriterias)) + + returnCriterias: bool = not any(messageReturnCriterias) + return returnCriterias async def getSettings(self): self.settings = self.guild.getChannelInfo( @@ -129,11 +137,14 @@ class Chat: """ This function gets the response from the ai """ + funcs = functions + if isinstance(self.message.channel, discord.TextChannel): + funcs.extend(server_normal_channel_functions) self.response = await request( model=self.model, prompt=self.prompt, openai_api_key=self.openai_api_key, - funtcions=server_normal_channel_functions, + funtcions=funcs, ) async def processResponse(self): @@ -142,14 +153,15 @@ class Chat: function_call=self.response, api_key=self.openai_api_key, ) - if response != None: + if response[0] != None: await self.processFunctioncallResponse(response) async def processFunctioncallResponse(self, response): self.context.append( { "role": "function", - "content": response, + "content": response[0], + "name": response[1], } ) if self.depth < 3: @@ -166,7 +178,6 @@ class Chat: This function processes the message """ if await self.preExitCriteria(): - print("pre exit criteria") return await self.getSupplementaryData() await self.getSettings() diff --git a/src/chatUtils/prompts.py b/src/chatUtils/prompts.py index 854e36a..6140079 100644 --- a/src/chatUtils/prompts.py +++ b/src/chatUtils/prompts.py @@ -27,7 +27,6 @@ def createPrompt( """ Creates a prompt from the messages list """ - print(f"Creating prompt with type {modeltype}") if modeltype == "chat": prompt = createChatPrompt(messages, model, character) sysprompt = prompt[0]["content"] diff --git a/src/cogs/channelSetup.py b/src/cogs/channelSetup.py index 964bec6..d83d346 100644 --- a/src/cogs/channelSetup.py +++ b/src/cogs/channelSetup.py @@ -236,7 +236,7 @@ class ChannelSetup(commands.Cog): async def premium(self, ctx: discord.ApplicationContext): guild = Guild(ctx.guild.id) guild.load() - if self.bot.is_owner(ctx.author): + if await self.bot.is_owner(ctx.author): guild.premium = True # also set expiry date in 6 months isofromat guild.premium_expiration = datetime.datetime.now() + datetime.timedelta( diff --git a/src/cogs/manage_chat.py b/src/cogs/manage_chat.py index 64953e0..e9bf25e 100644 --- a/src/cogs/manage_chat.py +++ b/src/cogs/manage_chat.py @@ -1,7 +1,6 @@ import discord import re import os -from src.config import debug, curs_data class ManageChat(discord.Cog): diff --git a/src/cogs/moderation.py b/src/cogs/moderation.py index 62a1d40..3e6305d 100644 --- a/src/cogs/moderation.py +++ b/src/cogs/moderation.py @@ -2,7 +2,6 @@ import discord from discord import default_permissions from discord.ext import commands import os -from src.config import debug, curs_data, con_data import openai import requests @@ -91,13 +90,6 @@ class Moderation(discord.Cog): "Our moderation capabilities have been switched to our new 100% free and open-source AI discord moderation bot! You add it to your server here: https://discord.com/api/oauth2/authorize?client_id=1071451913024974939&permissions=1377342450896&scope=bot and you can find the source code here: https://github.com/Paillat-dev/Moderator/ \n If you need help, you can join our support server here: https://discord.gg/pB6hXtUeDv", ephemeral=True, ) - if enable == False: - curs_data.execute( - "DELETE FROM moderation WHERE guild_id = ?", (str(ctx.guild.id),) - ) - con_data.commit() - await ctx.respond("Moderation disabled!", ephemeral=True) - return @discord.slash_command( name="get_toxicity", description="Get the toxicity of a message" diff --git a/src/config.py b/src/config.py index 075bd69..7d32c8d 100644 --- a/src/config.py +++ b/src/config.py @@ -45,29 +45,9 @@ def mg_to_guid(mg): return mg.guild.id -con_data = sqlite3.connect("./database/data.db") -curs_data = con_data.cursor() con_premium = sqlite3.connect("./database/premium.db") curs_premium = con_premium.cursor() -curs_data.execute( - """CREATE TABLE IF NOT EXISTS data (guild_id text, channel_id text, api_key text, is_active boolean, max_tokens integer, temperature real, frequency_penalty real, presence_penalty real, uses_count_today integer, prompt_size integer, prompt_prefix text, tts boolean, pretend_to_be text, pretend_enabled boolean)""" -) - -con_data.execute( - "CREATE TABLE IF NOT EXISTS setup_data (guild_id text, guild_settings text)" -) - -# This code creates the model table if it does not exist -curs_data.execute( - """CREATE TABLE IF NOT EXISTS model (guild_id text, model_name text)""" -) - -# This code creates the images table if it does not exist -curs_data.execute( - """CREATE TABLE IF NOT EXISTS images (guild_id text, usage_count integer, is_enabled boolean)""" -) - # This code creates the data table if it does not exist curs_premium.execute( """CREATE TABLE IF NOT EXISTS data (user_id text, guild_id text, premium boolean)""" diff --git a/src/functionscalls.py b/src/functionscalls.py index be64abb..e51079a 100644 --- a/src/functionscalls.py +++ b/src/functionscalls.py @@ -341,7 +341,9 @@ async def evaluate_math( return f"Result to math eval of {evaluable}: ```\n{str(result)}```" -async def call_function(message: discord.Message, function_call, api_key): +async def call_function( + message: discord.Message, function_call, api_key +) -> list[None | str]: name = function_call.get("name", "") if name == "": raise FuntionCallError("No name provided") @@ -359,7 +361,7 @@ async def call_function(message: discord.Message, function_call, api_key): ): return "Query blocked by the moderation system. If the user asked for something edgy, please tell them in a funny way that you won't do it, but do not specify that it was blocked by the moderation system." returnable = await function(message, arguments) - return returnable + return [returnable, name] functions_matching = { diff --git a/src/guild.py b/src/guild.py index a76e5f2..d409956 100644 --- a/src/guild.py +++ b/src/guild.py @@ -25,7 +25,10 @@ class Guild: "SELECT * FROM setup_data WHERE guild_id = ?", (self.id,) ) data = curs_data.fetchone() - data = orjson.loads(data[1]) + if type(data[1]) == str and data[1].startswith("b'"): + data = orjson.loads(data[1][2:-1]) + else: + data = orjson.loads(data[1]) self.premium = data["premium"] self.channels = data["channels"] self.api_keys = data["api_keys"] diff --git a/src/utils/SqlConnector.py b/src/utils/SqlConnector.py index d9f2f31..ea69ab3 100644 --- a/src/utils/SqlConnector.py +++ b/src/utils/SqlConnector.py @@ -22,3 +22,8 @@ class _sql: sql: _sql = _sql() + +command = "CREATE TABLE IF NOT EXISTS setup_data (guild_id text, guild_settings text)" + +with sql.mainDb as db: + db.execute(command)