mirror of
https://github.com/Paillat-dev/Botator.git
synced 2026-01-02 09:16:19 +00:00
🐛 fix(main.py): handle on_application_command_error with proper error handling and response
✨ feat(main.py): add ChatProcess module for handling chat-related functionality 🔧 refactor(main.py): import necessary modules and update bot.add_cog calls 🔧 refactor(server.ts): change port variable case from lowercase port to uppercase PORT to improve semantics ✨ feat(server.ts): add support for process.env.PORT environment variable to be able to run app on a configurable port 🔧 refactor(cogs/__init__.py): import ChannelSetup cog ✨ feat(cogs/channelSetup.py): add ChannelSetup cog for setting up channels and server-wide settings 🔧 refactor(cogs/setup.py): import SlashCommandGroup and guild_only from discord module ✨ feat(cogs/setup.py): add setup_channel command for adding and removing channels ✨ feat(cogs/setup.py): add api command for setting API keys ✨ feat(cogs/setup.py): add premium command for setting guild to premium 🔧 refactor(cogs/settings.py): temporarily disable images command due to maintenance 🔧 refactor(config.py): remove unnecessary code related to moderation table ✨ feat(guild.py): add Guild class for managing guild-specific data and settings ✨ feat(SqlConnector.py): add SQLConnection and _sql classes for managing SQLite connections ✨ feat(variousclasses.py): add models, characters, and apis classes for autocomplete functionality in slash commands
This commit is contained in:
5
main.py
5
main.py
@@ -13,6 +13,7 @@ bot.add_cog(cogs.Help(bot))
|
||||
bot.add_cog(cogs.Chat(bot))
|
||||
bot.add_cog(cogs.ManageChat(bot))
|
||||
bot.add_cog(cogs.Moderation(bot))
|
||||
bot.add_cog(cogs.ChannelSetup(bot))
|
||||
|
||||
|
||||
# set the bot's watching status to watcing your messages to answer you
|
||||
@@ -36,9 +37,9 @@ async def on_guild_join(guild):
|
||||
|
||||
|
||||
@bot.event
|
||||
async def on_application_command_error(ctx, error):
|
||||
debug(error)
|
||||
async def on_application_command_error(ctx, error: discord.DiscordException):
|
||||
await ctx.respond(error, ephemeral=True)
|
||||
raise error
|
||||
|
||||
|
||||
bot.run(discord_token) # run the bot
|
||||
|
||||
86
src/ChatProcess.py
Normal file
86
src/ChatProcess.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import discord
|
||||
import datetime
|
||||
import json
|
||||
|
||||
from src.utils.misc import moderate, ModerationError, Hasher
|
||||
from src.utils.variousclasses import models, characters, apis
|
||||
from src.guild import Guild
|
||||
from src.utils.openaicaller import openai_caller
|
||||
from src.functionscalls import (
|
||||
call_function,
|
||||
functions,
|
||||
server_normal_channel_functions,
|
||||
FuntionCallError,
|
||||
)
|
||||
|
||||
|
||||
class Chat:
|
||||
def __init__(self, bot, message: discord.Message):
|
||||
self.bot = bot
|
||||
self.message: discord.Message = message
|
||||
self.guild = Guild(self.message.guild.id)
|
||||
self.author = message.author
|
||||
self.is_bots_thread = False
|
||||
|
||||
async def getSupplementaryData(self) -> None:
|
||||
"""
|
||||
This function gets various contextual data that will be needed later on
|
||||
- The original message (if the message is a reply to a previous message from the bot)
|
||||
- The channel the message was sent in (if the message was sent in a thread)
|
||||
"""
|
||||
if isinstance(self.message.channel, discord.Thread):
|
||||
if self.message.channel.owner_id == self.bot.user.id:
|
||||
self.is_bots_thread = True
|
||||
self.channelIdForSettings = self.message.channel.parent_id
|
||||
else:
|
||||
self.channelIdForSettings = self.message.channel.id
|
||||
|
||||
try:
|
||||
self.original_message = await self.message.channel.fetch_message(
|
||||
self.message.reference.message_id
|
||||
)
|
||||
except:
|
||||
self.original_message = None
|
||||
|
||||
if (
|
||||
self.original_message != None
|
||||
and self.original_message.author.id == self.bot.user.id
|
||||
):
|
||||
self.original_message = None
|
||||
|
||||
async def preExitCriteria(self) -> bool:
|
||||
"""
|
||||
Returns True if any of the exit criterias are met
|
||||
This checks if the guild has the needed settings for the bot to work
|
||||
"""
|
||||
returnCriterias = []
|
||||
returnCriterias.append(self.message.author.id == self.bot.user.id)
|
||||
returnCriterias.append(self.api_key == None)
|
||||
returnCriterias.append(self.is_active == 0)
|
||||
return any(returnCriterias)
|
||||
|
||||
async def postExitCriteria(self) -> bool:
|
||||
"""
|
||||
Returns True if any of the exit criterias are met (their opposite is met but there is a not in front of the any() function)
|
||||
This checks if the bot should actuallly respond to the message or if the message doesn't concern the bot
|
||||
"""
|
||||
returnCriterias = []
|
||||
returnCriterias.append(self.guild.sanitizedChannels.get(str(self.message.channel.id), None) != None)
|
||||
returnCriterias.append(
|
||||
self.message.content.find("<@" + str(self.bot.user.id) + ">") != -1
|
||||
)
|
||||
returnCriterias.append(self.original_message != None)
|
||||
returnCriterias.append(self.is_bots_thread)
|
||||
|
||||
return not any(returnCriterias)
|
||||
|
||||
async def getSettings(self):
|
||||
self.settings = self.guild.getChannelInfo(str(self.channelIdForSettings))
|
||||
self.model = self.settings["model"]
|
||||
self.character = self.settings["character"]
|
||||
self.openai_api_key = self.guild.api_keys.get("openai", None)
|
||||
|
||||
|
||||
@@ -4,3 +4,4 @@ from src.cogs.help import Help
|
||||
from src.cogs.chat import Chat
|
||||
from src.cogs.manage_chat import ManageChat
|
||||
from src.cogs.moderation import Moderation
|
||||
from src.cogs.channelSetup import ChannelSetup
|
||||
97
src/cogs/channelSetup.py
Normal file
97
src/cogs/channelSetup.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import discord
|
||||
import orjson
|
||||
|
||||
from discord import SlashCommandGroup
|
||||
from discord import default_permissions
|
||||
from discord.ext.commands import guild_only
|
||||
from discord.ext import commands
|
||||
from src.utils.variousclasses import models, characters, apis
|
||||
from src.guild import Guild
|
||||
|
||||
sampleDataFormatExample = {
|
||||
"guild_id": 1234567890,
|
||||
"premium": False,
|
||||
"premium_expiration": 0,
|
||||
}
|
||||
|
||||
class ChannelSetup(commands.Cog):
|
||||
def __init__(self, bot: discord.Bot):
|
||||
super().__init__()
|
||||
self.bot = bot
|
||||
|
||||
setup = SlashCommandGroup("setup", description="Setup commands for the bot, inlcuding channels, models, and more.")
|
||||
|
||||
setup_channel = setup.create_subgroup(name="channel", description="Setup, add, or remove channels for the bot to use.")
|
||||
|
||||
@setup_channel.command(name="add", description="Add a channel for the bot to use. Can also specify server-wide settings.")
|
||||
@discord.option(name="channel", description="The channel to setup. If not specified, will use the current channel.", type=discord.TextChannel, required=False)
|
||||
@discord.option(name="model", description="The model to use for this channel.", type=str, required=False, autocomplete=models.autocomplete)
|
||||
@discord.option(name="character", description="The character to use for this channel.", type=str, required=False, autocomplete=characters.autocomplete)
|
||||
@guild_only()
|
||||
async def channel(self, ctx: discord.ApplicationContext, channel: discord.TextChannel = None, model: str = models.default, character: str = characters.default):
|
||||
if channel is None:
|
||||
channel = ctx.channel
|
||||
guild = Guild(ctx.guild.id)
|
||||
guild.load()
|
||||
if not guild.premium:
|
||||
if len(guild.channels) >= 1 and guild.channels.get(str(channel.id), None) is None:
|
||||
await ctx.respond("`Warning: You are not a premium user, and can only have one channel setup. The settings will still be saved, but will not be used.`", ephemeral=True)
|
||||
if model != models.default:
|
||||
await ctx.respond("`Warning: You are not a premium user, and can only use the default model. The settings will still be saved, but will not be used.`", ephemeral=True)
|
||||
if character != characters.default:
|
||||
await ctx.respond("`Warning: You are not a premium user, and can only use the default character. The settings will still be saved, but will not be used.`", ephemeral=True)
|
||||
if guild.api_keys.get("openai", None) is None:
|
||||
await ctx.respond("`Error: No openai api key is set. The api key is needed for the openai models, as well as for the content moderation. The openai models will cost you tokens in your openai account. However, if you use one of the llama models, you will not be charged, but the api key is still needed for content moderation, wich is free but requires an api key.`", ephemeral=True)
|
||||
guild.addChannel(channel, models.matchingDict[model], characters.matchingDict[character])
|
||||
await ctx.respond(f"Set channel {channel.mention} with model `{model}` and character `{character}`.")
|
||||
|
||||
@setup_channel.command(name="remove", description="Remove a channel from the bot's usage.")
|
||||
@discord.option(name="channel", description="The channel to remove. If not specified, will use the current channel.", type=discord.TextChannel, required=False)
|
||||
@guild_only()
|
||||
async def remove(self, ctx: discord.ApplicationContext, channel: discord.TextChannel = None):
|
||||
if channel is None:
|
||||
channel = ctx.channel
|
||||
guild = Guild(ctx.guild.id)
|
||||
guild.load()
|
||||
if channel.id not in guild.channels:
|
||||
await ctx.respond("That channel is not setup.")
|
||||
return
|
||||
guild.delChannel(channel)
|
||||
await ctx.respond(f"Removed channel {channel.mention}.")
|
||||
|
||||
@setup_channel.command(name="list", description="List all channels that are setup.")
|
||||
@guild_only()
|
||||
async def list(self, ctx: discord.ApplicationContext):
|
||||
guild = Guild(ctx.guild.id)
|
||||
guild.load()
|
||||
if len(guild.channels) == 0:
|
||||
await ctx.respond("No channels are setup.")
|
||||
return
|
||||
embed = discord.Embed(title="Channels", description="All channels that are setup.", color=discord.Color.nitro_pink())
|
||||
channels = guild.sanitizedChannels
|
||||
for channel in channels:
|
||||
discochannel = await self.bot.fetch_channel(int(channel))
|
||||
model = models.reverseMatchingDict[channels[channel]["model"]]
|
||||
character = characters.reverseMatchingDict[channels[channel]["character"]]
|
||||
embed.add_field(name=f"{discochannel.mention}", value=f"Model: `{model}`\nCharacter: `{character}`", inline=False)
|
||||
await ctx.respond(embed=embed)
|
||||
|
||||
@setup.command(name="api", description="Set an API key for the bot to use.")
|
||||
@discord.option(name="api", description="The API to set. Currently only OpenAI is supported.", type=str, required=True, autocomplete=apis.autocomplete)
|
||||
@discord.option(name="key", description="The key to set.", type=str, required=True)
|
||||
@guild_only()
|
||||
async def api(self, ctx: discord.ApplicationContext, api: str, key: str):
|
||||
guild = Guild(ctx.guild.id)
|
||||
guild.load()
|
||||
guild.api_keys[apis.matchingDict[api]] = key
|
||||
guild.updateDbData()
|
||||
await ctx.respond(f"Set API key for {api} to `secret`.", ephemeral=True)
|
||||
|
||||
@setup.command(name="premium", description="Set the guild to premium.")
|
||||
async def premium(self, ctx: discord.ApplicationContext):
|
||||
guild = Guild(ctx.guild.id)
|
||||
guild.load()
|
||||
if not guild.premium:
|
||||
await ctx.respond("You can get your premium subscription at https://www.botator.dev/premium.", ephemeral=True)
|
||||
else:
|
||||
await ctx.respond("This guild is already premium.", ephemeral=True)
|
||||
@@ -266,6 +266,12 @@ class Settings(discord.Cog):
|
||||
)
|
||||
@default_permissions(administrator=True)
|
||||
async def images(self, ctx: discord.ApplicationContext, enable_disable: str):
|
||||
return await ctx.respond(
|
||||
"""
|
||||
Images recognition is under maintenance and will come back soon!
|
||||
"""
|
||||
)
|
||||
|
||||
try:
|
||||
curs_data.execute(
|
||||
"SELECT * FROM images WHERE guild_id = ?", (ctx_to_guid(ctx),)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import discord
|
||||
from discord import SlashCommandGroup
|
||||
from discord import default_permissions, guild_only
|
||||
from discord.ext import commands
|
||||
from src.config import (
|
||||
@@ -28,7 +29,7 @@ class Setup(discord.Cog):
|
||||
def __init__(self, bot: discord.Bot):
|
||||
super().__init__()
|
||||
self.bot = bot
|
||||
|
||||
"""
|
||||
@discord.slash_command(name="setup", description="Setup the bot")
|
||||
@discord.option(name="channel_id", description="The channel id", required=True)
|
||||
@discord.option(name="api_key", description="The api key", required=True)
|
||||
@@ -138,7 +139,7 @@ class Setup(discord.Cog):
|
||||
)
|
||||
con_data.commit()
|
||||
await ctx.respond("The api key has been added", ephemeral=True)
|
||||
|
||||
"""
|
||||
@discord.slash_command(
|
||||
name="delete", description="Delete the information about this server"
|
||||
)
|
||||
@@ -190,8 +191,10 @@ class Setup(discord.Cog):
|
||||
await ctx.respond("Disabled", ephemeral=True)
|
||||
|
||||
# create a command calles "add channel" that can only be used in premium servers
|
||||
|
||||
"""
|
||||
@discord.slash_command(
|
||||
name="add_channel",
|
||||
name="setup_channel",
|
||||
description="Add a channel to the list of channels. Premium only.",
|
||||
)
|
||||
@discord.option(
|
||||
@@ -257,7 +260,7 @@ class Setup(discord.Cog):
|
||||
await ctx.respond(f"Added channel **{channel.name}**", ephemeral=True)
|
||||
return
|
||||
await ctx.respond("You can only add 5 channels", ephemeral=True)
|
||||
|
||||
"""
|
||||
# create a command called "remove channel" that can only be used in premium servers
|
||||
@discord.slash_command(
|
||||
name="remove_channel",
|
||||
|
||||
@@ -53,33 +53,8 @@ curs_premium = con_premium.cursor()
|
||||
curs_data.execute(
|
||||
"""CREATE TABLE IF NOT EXISTS data (guild_id text, channel_id text, api_key text, is_active boolean, max_tokens integer, temperature real, frequency_penalty real, presence_penalty real, uses_count_today integer, prompt_size integer, prompt_prefix text, tts boolean, pretend_to_be text, pretend_enabled boolean)"""
|
||||
)
|
||||
# we delete the moderation table and create a new one, with all theese parameters as floats: TOXICITY: {result[0]}; SEVERE_TOXICITY: {result[1]}; IDENTITY ATTACK: {result[2]}; INSULT: {result[3]}; PROFANITY: {result[4]}; THREAT: {result[5]}; SEXUALLY EXPLICIT: {result[6]}; FLIRTATION: {result[7]}; OBSCENE: {result[8]}; SPAM: {result[9]}
|
||||
expected_columns = 14
|
||||
|
||||
# we delete the moderation table and create a new one
|
||||
curs_data.execute(
|
||||
"""CREATE TABLE IF NOT EXISTS moderation (guild_id text, logs_channel_id text, is_enabled boolean, mod_role_id text, toxicity real, severe_toxicity real, identity_attack real, insult real, profanity real, threat real, sexually_explicit real, flirtation real, obscene real, spam real)"""
|
||||
)
|
||||
|
||||
# This code returns the number of columns in the table "moderation" in the database "data.db".
|
||||
curs_data.execute("PRAGMA table_info(moderation)")
|
||||
result = curs_data.fetchall()
|
||||
actual_columns = len(result)
|
||||
|
||||
if actual_columns != expected_columns:
|
||||
# we add the new columns
|
||||
curs_data.execute("ALTER TABLE moderation ADD COLUMN toxicity real")
|
||||
curs_data.execute("ALTER TABLE moderation ADD COLUMN severe_toxicity real")
|
||||
curs_data.execute("ALTER TABLE moderation ADD COLUMN identity_attack real")
|
||||
curs_data.execute("ALTER TABLE moderation ADD COLUMN insult real")
|
||||
curs_data.execute("ALTER TABLE moderation ADD COLUMN profanity real")
|
||||
curs_data.execute("ALTER TABLE moderation ADD COLUMN threat real")
|
||||
curs_data.execute("ALTER TABLE moderation ADD COLUMN sexually_explicit real")
|
||||
curs_data.execute("ALTER TABLE moderation ADD COLUMN flirtation real")
|
||||
curs_data.execute("ALTER TABLE moderation ADD COLUMN obscene real")
|
||||
curs_data.execute("ALTER TABLE moderation ADD COLUMN spam real")
|
||||
else:
|
||||
print("Table already has the correct number of columns")
|
||||
con_data.execute('CREATE TABLE IF NOT EXISTS setup_data (guild_id text, guild_settings text)')
|
||||
|
||||
# This code creates the model table if it does not exist
|
||||
curs_data.execute(
|
||||
|
||||
102
src/guild.py
Normal file
102
src/guild.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import orjson
|
||||
import discord
|
||||
|
||||
from src.utils.SqlConnector import sql
|
||||
from datetime import datetime
|
||||
from src.utils.variousclasses import models, characters
|
||||
|
||||
class Guild:
|
||||
def __init__(self, id: int):
|
||||
self.id = id
|
||||
self.load()
|
||||
|
||||
def getDbData(self):
|
||||
with sql.mainDb as con:
|
||||
curs_data = con.cursor()
|
||||
curs_data.execute("SELECT * FROM setup_data WHERE guild_id = ?", (self.id,))
|
||||
data = curs_data.fetchone()
|
||||
self.isInDb = data is not None
|
||||
if not self.isInDb:
|
||||
self.updateDbData()
|
||||
with sql.mainDb as con:
|
||||
curs_data = con.cursor()
|
||||
curs_data.execute("SELECT * FROM setup_data WHERE guild_id = ?", (self.id,))
|
||||
data = curs_data.fetchone()
|
||||
data = orjson.loads(data[1])
|
||||
self.premium = data["premium"]
|
||||
self.channels = data["channels"]
|
||||
self.api_keys = data["api_keys"]
|
||||
if self.premium:
|
||||
self.premium_expiration = datetime.fromisoformat(data.get("premium_expiration", None))
|
||||
self.checkPremiumExpires()
|
||||
else:
|
||||
self.premium_expiration = None
|
||||
|
||||
def checkPremiumExpires(self):
|
||||
if self.premium_expiration is None:
|
||||
self.premium = False
|
||||
return
|
||||
if self.premium_expiration < datetime.now():
|
||||
self.premium = False
|
||||
self.premium_expiration = None
|
||||
self.updateDbData()
|
||||
|
||||
def updateDbData(self):
|
||||
if self.isInDb:
|
||||
data = {
|
||||
"guild_id": self.id,
|
||||
"premium": self.premium,
|
||||
"channels": self.channels,
|
||||
"api_keys": self.api_keys,
|
||||
"premium_expiration": self.premium_expiration.isoformat() if self.premium else None,
|
||||
}
|
||||
else:
|
||||
data = {
|
||||
"guild_id": self.id,
|
||||
"premium": False,
|
||||
"channels": {},
|
||||
"api_keys": {},
|
||||
"premium_expiration": None,
|
||||
}
|
||||
with sql.mainDb as con:
|
||||
curs_data = con.cursor()
|
||||
if self.isInDb:
|
||||
curs_data.execute("UPDATE setup_data SET guild_settings = ? WHERE guild_id = ?", (orjson.dumps(data), self.id))
|
||||
else:
|
||||
curs_data.execute("INSERT INTO setup_data (guild_id, guild_settings) VALUES (?, ?)", (self.id, orjson.dumps(data)))
|
||||
self.isInDb = True
|
||||
|
||||
def load(self):
|
||||
self.getDbData()
|
||||
|
||||
def addChannel(self, channel: discord.TextChannel, model: str, character: str):
|
||||
print(f"Adding channel {channel.id} to guild {self.id} with model {model} and character {character}")
|
||||
self.channels[str(channel.id)] = {
|
||||
"model": model,
|
||||
"character": character,
|
||||
}
|
||||
self.updateDbData()
|
||||
|
||||
def delChannel(self, channel: discord.TextChannel):
|
||||
del self.channels[str(channel.id)]
|
||||
self.updateDbData()
|
||||
|
||||
@property
|
||||
def sanitizedChannels(self) -> dict:
|
||||
if self.premium:
|
||||
return self.channels
|
||||
if len(self.channels) == 0:
|
||||
return {}
|
||||
return {
|
||||
list(self.channels.keys())[0]: {
|
||||
"model": models.matchingDict[models.default],
|
||||
"character": characters.matchingDict[characters.default],
|
||||
}
|
||||
}
|
||||
|
||||
def getChannelInfo(self, channel: str):
|
||||
return self.sanitizedChannels.get(channel, None)
|
||||
|
||||
def addApiKey(self, api: str, key: str):
|
||||
self.api_keys[api] = key
|
||||
self.updateDbData()
|
||||
26
src/utils/SqlConnector.py
Normal file
26
src/utils/SqlConnector.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from sqlite3 import connect
|
||||
from random import randint
|
||||
|
||||
|
||||
class SQLConnection:
|
||||
|
||||
def __init__(self,connection):
|
||||
self.connection = connection
|
||||
|
||||
def __enter__(self):
|
||||
return self.connection
|
||||
|
||||
def __exit__(self,exc_type,exc_val,exc_tb):
|
||||
self.connection.commit()
|
||||
self.connection.close()
|
||||
|
||||
|
||||
class _sql:
|
||||
|
||||
@property
|
||||
def mainDb(self):
|
||||
s = connect('./database/data.db')
|
||||
return SQLConnection(s)
|
||||
|
||||
|
||||
sql: _sql = _sql()
|
||||
37
src/utils/variousclasses.py
Normal file
37
src/utils/variousclasses.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from discord import AutocompleteContext
|
||||
|
||||
class models:
|
||||
matchingDict = {
|
||||
"chatGPT (default - free)": "gpt-3.5-turbo",
|
||||
"davinci (premium)": "text-davinci-003",
|
||||
"llama (premium)": "text-llama",
|
||||
"llama 2 (premium)": "text-llama-2",
|
||||
}
|
||||
reverseMatchingDict = {v: k for k, v in matchingDict.items()}
|
||||
default = list(matchingDict.keys())[0]
|
||||
openaimodels = ["gpt-3.5-turbo", "text-davinci-003"]
|
||||
@classmethod
|
||||
async def autocomplete(cls, ctx: AutocompleteContext) -> list[str]:
|
||||
modls = cls.matchingDict.keys()
|
||||
return [model for model in modls if model.find(ctx.value.lower()) != -1]
|
||||
|
||||
class characters:
|
||||
matchingDict = {
|
||||
"Botator (default - free)": "botator",
|
||||
"Aurora (premium)": "aurora",
|
||||
}
|
||||
reverseMatchingDict = {v: k for k, v in matchingDict.items()}
|
||||
default = list(matchingDict.keys())[0]
|
||||
@classmethod
|
||||
async def autocomplete(cls, ctx: AutocompleteContext) -> list[str]:
|
||||
chars = characters = cls.matchingDict.keys()
|
||||
return [character for character in chars if character.find(ctx.value.lower()) != -1]
|
||||
|
||||
class apis:
|
||||
matchingDict = {
|
||||
"OpenAI": "openai",
|
||||
}
|
||||
@classmethod
|
||||
async def autocomplete(cls, ctx: AutocompleteContext) -> list[str]:
|
||||
apiss = cls.matchingDict.keys()
|
||||
return [api for api in apiss if api.find(ctx.value.lower()) != -1]
|
||||
Reference in New Issue
Block a user