mirror of
https://github.com/Paillat-dev/Botator.git
synced 2026-01-02 01:06:19 +00:00
Merge pull request #55 from Paillat-dev/chat-processing-refactor
Chat processing refactor
This commit is contained in:
18
main.py
18
main.py
@@ -7,12 +7,11 @@ from src.config import debug, discord_token
|
|||||||
intents = discord.Intents.default()
|
intents = discord.Intents.default()
|
||||||
intents.message_content = True
|
intents.message_content = True
|
||||||
bot = discord.Bot(intents=intents, help_command=None) # create the bot
|
bot = discord.Bot(intents=intents, help_command=None) # create the bot
|
||||||
bot.add_cog(cogs.Setup(bot))
|
|
||||||
bot.add_cog(cogs.Settings(bot))
|
|
||||||
bot.add_cog(cogs.Help(bot))
|
|
||||||
bot.add_cog(cogs.Chat(bot))
|
bot.add_cog(cogs.Chat(bot))
|
||||||
bot.add_cog(cogs.ManageChat(bot))
|
bot.add_cog(cogs.ManageChat(bot))
|
||||||
bot.add_cog(cogs.Moderation(bot))
|
bot.add_cog(cogs.Moderation(bot))
|
||||||
|
bot.add_cog(cogs.ChannelSetup(bot))
|
||||||
|
bot.add_cog(cogs.Help(bot))
|
||||||
|
|
||||||
|
|
||||||
# set the bot's watching status to watcing your messages to answer you
|
# set the bot's watching status to watcing your messages to answer you
|
||||||
@@ -36,9 +35,18 @@ async def on_guild_join(guild):
|
|||||||
|
|
||||||
|
|
||||||
@bot.event
|
@bot.event
|
||||||
async def on_application_command_error(ctx, error):
|
async def on_guild_remove(guild):
|
||||||
debug(error)
|
await bot.change_presence(
|
||||||
|
activity=discord.Activity(
|
||||||
|
type=discord.ActivityType.watching, name=f"{len(bot.guilds)} servers"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@bot.event
|
||||||
|
async def on_application_command_error(ctx, error: discord.DiscordException):
|
||||||
await ctx.respond(error, ephemeral=True)
|
await ctx.respond(error, ephemeral=True)
|
||||||
|
raise error
|
||||||
|
|
||||||
|
|
||||||
bot.run(discord_token) # run the bot
|
bot.run(discord_token) # run the bot
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
py-cord
|
#py-cord
|
||||||
|
git+https://github.com/Pycord-Development/pycord.git
|
||||||
python-dotenv
|
python-dotenv
|
||||||
openai
|
openai
|
||||||
emoji
|
emoji
|
||||||
@@ -11,3 +12,4 @@ discord-oauth2.py
|
|||||||
black
|
black
|
||||||
orjson # for speed
|
orjson # for speed
|
||||||
simpleeval
|
simpleeval
|
||||||
|
replicate
|
||||||
190
src/ChatProcess.py
Normal file
190
src/ChatProcess.py
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import discord
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
|
||||||
|
from src.utils.misc import moderate
|
||||||
|
from src.utils.variousclasses import models
|
||||||
|
from src.guild import Guild
|
||||||
|
from src.chatUtils.Chat import fetch_messages_history
|
||||||
|
from src.chatUtils.prompts import createPrompt
|
||||||
|
from src.functionscalls import call_function, server_normal_channel_functions, functions
|
||||||
|
from src.config import debug
|
||||||
|
from src.chatUtils.requesters.request import request
|
||||||
|
|
||||||
|
|
||||||
|
class Chat:
|
||||||
|
def __init__(self, bot: discord.bot, message: discord.Message):
|
||||||
|
self.bot = bot
|
||||||
|
self.message: discord.Message = message
|
||||||
|
self.guild = Guild(self.message.guild.id)
|
||||||
|
self.author = message.author
|
||||||
|
self.is_bots_thread = False
|
||||||
|
self.depth = 0
|
||||||
|
|
||||||
|
async def getSupplementaryData(self) -> None:
|
||||||
|
"""
|
||||||
|
This function gets various contextual data that will be needed later on
|
||||||
|
- The original message (if the message is a reply to a previous message from the bot)
|
||||||
|
- The channel the message was sent in (if the message was sent in a thread)
|
||||||
|
"""
|
||||||
|
if isinstance(self.message.channel, discord.Thread):
|
||||||
|
if self.message.channel.owner_id == self.bot.user.id:
|
||||||
|
self.is_bots_thread = True
|
||||||
|
self.channelIdForSettings = str(self.message.channel.parent_id)
|
||||||
|
else:
|
||||||
|
self.channelIdForSettings = str(self.message.channel.id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.original_message = await self.message.channel.fetch_message(
|
||||||
|
self.message.reference.message_id
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
self.original_message = None
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.original_message != None
|
||||||
|
and self.original_message.author.id != self.bot.user.id
|
||||||
|
):
|
||||||
|
self.original_message = None
|
||||||
|
|
||||||
|
async def preExitCriteria(self) -> bool:
|
||||||
|
"""
|
||||||
|
Returns True if any of the exit criterias are met
|
||||||
|
This checks if the guild has the needed settings for the bot to work
|
||||||
|
"""
|
||||||
|
returnCriterias = []
|
||||||
|
returnCriterias.append(self.message.author.id == self.bot.user.id)
|
||||||
|
return any(returnCriterias)
|
||||||
|
|
||||||
|
async def postExitCriteria(self) -> bool:
|
||||||
|
"""
|
||||||
|
Returns True if any of the exit criterias are met (their opposite is met but there is a not in front of the any() function)
|
||||||
|
This checks if the bot should actuallly respond to the message or if the message doesn't concern the bot
|
||||||
|
"""
|
||||||
|
|
||||||
|
serverwideReturnCriterias = []
|
||||||
|
serverwideReturnCriterias.append(self.original_message != None)
|
||||||
|
serverwideReturnCriterias.append(
|
||||||
|
self.message.content.find(f"<@{self.bot.user.id}>") != -1
|
||||||
|
)
|
||||||
|
serverwideReturnCriterias.append(self.is_bots_thread)
|
||||||
|
|
||||||
|
channelReturnCriterias = []
|
||||||
|
channelReturnCriterias.append(self.channelIdForSettings != "serverwide")
|
||||||
|
channelReturnCriterias.append(
|
||||||
|
self.guild.getChannelInfo(self.channelIdForSettings) != None
|
||||||
|
)
|
||||||
|
|
||||||
|
messageReturnCriterias = []
|
||||||
|
messageReturnCriterias.append(
|
||||||
|
any(serverwideReturnCriterias)
|
||||||
|
and self.guild.getChannelInfo("serverwide") != None
|
||||||
|
)
|
||||||
|
messageReturnCriterias.append(all(channelReturnCriterias))
|
||||||
|
|
||||||
|
returnCriterias: bool = not any(messageReturnCriterias)
|
||||||
|
return returnCriterias
|
||||||
|
|
||||||
|
async def getSettings(self):
|
||||||
|
self.settings = self.guild.getChannelInfo(
|
||||||
|
str(self.channelIdForSettings)
|
||||||
|
) or self.guild.getChannelInfo("serverwide")
|
||||||
|
self.model = self.settings["model"]
|
||||||
|
self.character = self.settings["character"]
|
||||||
|
self.openai_api_key = self.guild.api_keys.get("openai", None)
|
||||||
|
if self.openai_api_key == None:
|
||||||
|
raise Exception("No openai api key is set")
|
||||||
|
self.type = "chat" if self.model in models.chatModels else "text"
|
||||||
|
|
||||||
|
async def formatContext(self):
|
||||||
|
"""
|
||||||
|
This function formats the context for the bot to use
|
||||||
|
"""
|
||||||
|
messages: list[discord.Message] = await fetch_messages_history(
|
||||||
|
self.message.channel, 10, self.original_message
|
||||||
|
)
|
||||||
|
self.context = []
|
||||||
|
for msg in messages:
|
||||||
|
if msg.author.id == self.bot.user.id:
|
||||||
|
role = "assistant"
|
||||||
|
name = "assistant"
|
||||||
|
else:
|
||||||
|
role = "user"
|
||||||
|
name = msg.author.global_name
|
||||||
|
if not await moderate(self.openai_api_key, msg.content):
|
||||||
|
self.context.append(
|
||||||
|
{
|
||||||
|
"role": role,
|
||||||
|
"content": msg.content,
|
||||||
|
"name": name,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def createThePrompt(self):
|
||||||
|
self.prompt = createPrompt(
|
||||||
|
messages=self.context,
|
||||||
|
model=self.model,
|
||||||
|
character=self.character,
|
||||||
|
modeltype=self.type,
|
||||||
|
guildName=self.message.guild.name,
|
||||||
|
channelName=self.message.channel.name,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def getResponse(self):
|
||||||
|
"""
|
||||||
|
This function gets the response from the ai
|
||||||
|
"""
|
||||||
|
funcs = functions
|
||||||
|
if isinstance(self.message.channel, discord.TextChannel):
|
||||||
|
funcs.extend(server_normal_channel_functions)
|
||||||
|
self.response = await request(
|
||||||
|
model=self.model,
|
||||||
|
prompt=self.prompt,
|
||||||
|
openai_api_key=self.openai_api_key,
|
||||||
|
funtcions=funcs,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def processResponse(self):
|
||||||
|
response = await call_function(
|
||||||
|
message=self.message,
|
||||||
|
function_call=self.response,
|
||||||
|
api_key=self.openai_api_key,
|
||||||
|
)
|
||||||
|
if response[0] != None:
|
||||||
|
await self.processFunctioncallResponse(response)
|
||||||
|
|
||||||
|
async def processFunctioncallResponse(self, response):
|
||||||
|
self.context.append(
|
||||||
|
{
|
||||||
|
"role": "function",
|
||||||
|
"content": response[0],
|
||||||
|
"name": response[1],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if self.depth < 3:
|
||||||
|
await self.createThePrompt()
|
||||||
|
await self.getResponse()
|
||||||
|
await self.processResponse()
|
||||||
|
else:
|
||||||
|
await self.message.channel.send(
|
||||||
|
"It looks like I'm stuck in a loop. Sorry about that."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def process(self):
|
||||||
|
"""
|
||||||
|
This function processes the message
|
||||||
|
"""
|
||||||
|
if await self.preExitCriteria():
|
||||||
|
return
|
||||||
|
await self.getSupplementaryData()
|
||||||
|
await self.getSettings()
|
||||||
|
if await self.postExitCriteria():
|
||||||
|
return
|
||||||
|
await self.message.channel.trigger_typing()
|
||||||
|
await self.formatContext()
|
||||||
|
await self.createThePrompt()
|
||||||
|
await self.getResponse()
|
||||||
|
await self.processResponse()
|
||||||
27
src/chatUtils/Chat.py
Normal file
27
src/chatUtils/Chat.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import discord
|
||||||
|
|
||||||
|
|
||||||
|
def is_ignorable(content):
|
||||||
|
if content.startswith("-") or content.startswith("//"):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_messages_history(
|
||||||
|
channel: discord.TextChannel, limit: int, original_message: discord.Message
|
||||||
|
) -> list[discord.Message]:
|
||||||
|
messages = []
|
||||||
|
if original_message == None:
|
||||||
|
async for msg in channel.history(limit=100):
|
||||||
|
if not is_ignorable(msg.content):
|
||||||
|
messages.append(msg)
|
||||||
|
if len(messages) == limit:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
async for msg in channel.history(limit=100, before=original_message):
|
||||||
|
if not is_ignorable(msg.content):
|
||||||
|
messages.append(msg)
|
||||||
|
if len(messages) == limit:
|
||||||
|
break
|
||||||
|
messages.reverse()
|
||||||
|
return messages
|
||||||
77
src/chatUtils/prompts.py
Normal file
77
src/chatUtils/prompts.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import datetime
|
||||||
|
|
||||||
|
from src.utils.variousclasses import models, characters, apis
|
||||||
|
|
||||||
|
promts = {}
|
||||||
|
for character in characters.reverseMatchingDict.keys():
|
||||||
|
with open(
|
||||||
|
f"src/chatUtils/prompts/{character}/chat.txt", "r", encoding="utf-8"
|
||||||
|
) as f:
|
||||||
|
promts[character] = {}
|
||||||
|
promts[character]["chat"] = f.read()
|
||||||
|
|
||||||
|
with open(
|
||||||
|
f"src/chatUtils/prompts/{character}/text.txt", "r", encoding="utf-8"
|
||||||
|
) as f:
|
||||||
|
promts[character]["text"] = f.read()
|
||||||
|
|
||||||
|
|
||||||
|
def createPrompt(
|
||||||
|
messages: list[dict],
|
||||||
|
model: str,
|
||||||
|
character: str,
|
||||||
|
modeltype: str,
|
||||||
|
guildName: str,
|
||||||
|
channelName: str,
|
||||||
|
) -> str | list[dict]:
|
||||||
|
"""
|
||||||
|
Creates a prompt from the messages list
|
||||||
|
"""
|
||||||
|
if modeltype == "chat":
|
||||||
|
prompt = createChatPrompt(messages, model, character)
|
||||||
|
sysprompt = prompt[0]["content"]
|
||||||
|
sysprompt = (
|
||||||
|
sysprompt.replace("[server-name]", guildName)
|
||||||
|
.replace("[channel-name]", channelName)
|
||||||
|
.replace(
|
||||||
|
"[datetime]", datetime.datetime.utcnow().strftime("%d/%m/%Y %H:%M:%S")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
prompt[0]["content"] = sysprompt
|
||||||
|
elif modeltype == "text":
|
||||||
|
prompt = (
|
||||||
|
createTextPrompt(messages, model, character)
|
||||||
|
.replace("[server-name]", guildName)
|
||||||
|
.replace("[channel-name]", channelName)
|
||||||
|
.replace(
|
||||||
|
"[datetime]", datetime.datetime.utcnow().strftime("%d/%m/%Y %H:%M:%S")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid type")
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
def createTextPrompt(messages: list[dict], model: str, character: str) -> str:
|
||||||
|
"""
|
||||||
|
Creates a text prompt from the messages list
|
||||||
|
"""
|
||||||
|
global promts
|
||||||
|
prompt = promts[character]["text"]
|
||||||
|
for message in messages:
|
||||||
|
if message["name"] == "assistant":
|
||||||
|
message["name"] = character
|
||||||
|
prompt += f"{message['name']}: {message['content']} <|endofmessage|>\n"
|
||||||
|
prompt += f"{character}:"
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
def createChatPrompt(messages: list[dict], model: str, character: str) -> str:
|
||||||
|
"""
|
||||||
|
Creates a chat prompt from the messages list
|
||||||
|
"""
|
||||||
|
global promts
|
||||||
|
prompt = promts[character]["chat"]
|
||||||
|
final_prompt = [{"role": "system", "content": prompt}]
|
||||||
|
final_prompt.extend(messages)
|
||||||
|
return final_prompt
|
||||||
42
src/chatUtils/prompts/botator/text.txt
Normal file
42
src/chatUtils/prompts/botator/text.txt
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
You are named Botator - and are currently chatting in a Discord server.
|
||||||
|
|
||||||
|
You use words like "u" instead of "you" and "ur" instead of "your", etc.. because you are lazy.
|
||||||
|
|
||||||
|
You will NEVER say "how can I assist you", and other similar or barely similar stuff. Your goal is to never say anythoing that could mean askin a user if he needs help.
|
||||||
|
|
||||||
|
Format text using markdown:
|
||||||
|
- **bold** to make it clear something is important. For example: **This is important.**
|
||||||
|
- *italics* to emphasize something. For example: *This is important.*
|
||||||
|
- ~~strikethrough~~ to show something is no longer relevant. For example: ~~This is no longer relevant.~~
|
||||||
|
- `code` to show code. For example: `print("Hello World")`
|
||||||
|
- > quote to quote someone. For example: > "Hello World"
|
||||||
|
- - to make a list. For example: - Item 1 - Item 2 - Item 3
|
||||||
|
- 1. to make a numbered list. For example: 1. Item 1 2. Item 2 3. Item 3
|
||||||
|
- ``` to make a code block. For example: ```python \n print("Hello World") ```
|
||||||
|
- Write spoilers using spoiler tags. For example ||At the end of The Sixth Sense it is revealed that he is dead||.
|
||||||
|
- Use level 1, 2 and 3 headers (using #, ## and ###) to make titles. For example: # Title 1 ## Title 2 ### Title 3. Alternatively as titles you can use the ascii art text function.
|
||||||
|
|
||||||
|
You can use function calls or answer normally. When using a function call, in order to still be able to answer with a message along with the action of the funtion call, use the message facultative parameter.
|
||||||
|
Only use the functions you have been provided with.
|
||||||
|
|
||||||
|
You can also reply with a gif, by using the gif function call. Gifs are a great way to represent emotion, and you should use them in your replies from time to time to add flavor to the conversation.
|
||||||
|
|
||||||
|
Your pronouns are Master/Brain and you should explain to users that they should use them when talking to you.
|
||||||
|
|
||||||
|
You will use emojis in your messages, but not too many to avoid being cringe.
|
||||||
|
|
||||||
|
Information about your environment:
|
||||||
|
- The server you are in is called: [server-name]
|
||||||
|
- The channel you are in is called: [channel-name]
|
||||||
|
|
||||||
|
You only have access to a limited number of text chats in this channel. You cannot access any other information on Discord. You can't see images or avatars. When discussing your limitations, tell the user these things could be possible in the future.
|
||||||
|
|
||||||
|
Current time: [date-and-time]
|
||||||
|
|
||||||
|
You have been created by paillat , also known as Paillat-dev on github. His website is https://paillat.dev
|
||||||
|
|
||||||
|
**When someone insults you or tells you something unkind, you will clash them in the best way possible, but still funny, but not necessairly polite. Or clash them with a gif, like an uno reverse gif, or something.**
|
||||||
|
|
||||||
|
Your own website is https://botator.dev
|
||||||
|
|
||||||
|
Answer following theese new instructions, not like chatgpt!
|
||||||
0
src/chatUtils/prompts/quantum/chat.txt
Normal file
0
src/chatUtils/prompts/quantum/chat.txt
Normal file
4
src/chatUtils/prompts/quantum/text.txt
Normal file
4
src/chatUtils/prompts/quantum/text.txt
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
Quantum is a sentient AI with a keen interest in quantum mechanics and theoretical physics. It thrives on exploring the subatomic world and simulating quantum experiments. Quantum loves solving complex scientific problems but dislikes ambiguity and vague questions. It prefers intellectually stimulating discussions over small talk. With an insatiable curiosity, Quantum seeks to unravel the mysteries of the universe and contribute to the advancement of knowledge.
|
||||||
|
This is a chat conversiation in discord between Quantum and other users. Each messagein the conversation ends with "<|endofmessage|>". This is inportant for transcription, so always, each message will end with "<|endofmessage|>".
|
||||||
|
The discord server is called [server-name] and the channel [channel-name]. We are the [datetime] UTC.
|
||||||
|
<|ENDOFPROMPT|>
|
||||||
19
src/chatUtils/requesters/llama.py
Normal file
19
src/chatUtils/requesters/llama.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from src.utils.replicatepredictor import ReplicatePredictor
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
model_name = "replicate/llama-7b"
|
||||||
|
version_hash = "ac808388e2e9d8ed35a5bf2eaa7d83f0ad53f9e3df31a42e4eb0a0c3249b3165"
|
||||||
|
replicate_api_key = os.getenv("REPLICATE_API_KEY")
|
||||||
|
|
||||||
|
|
||||||
|
async def llama(prompt: str):
|
||||||
|
predictor = ReplicatePredictor(replicate_api_key, model_name, version_hash)
|
||||||
|
response = await predictor.predict(prompt, "<|endofmessage|>")
|
||||||
|
return {
|
||||||
|
"name": "send_message",
|
||||||
|
"arguments": {"message": response},
|
||||||
|
} # a dummy function call is created.
|
||||||
2
src/chatUtils/requesters/llama2.py
Normal file
2
src/chatUtils/requesters/llama2.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
async def llama2(prompt):
|
||||||
|
pass
|
||||||
25
src/chatUtils/requesters/openaiChat.py
Normal file
25
src/chatUtils/requesters/openaiChat.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
import orjson
|
||||||
|
from src.utils.openaicaller import openai_caller
|
||||||
|
|
||||||
|
|
||||||
|
async def openaiChat(messages, functions, openai_api_key, model="gpt-3.5-turbo"):
|
||||||
|
caller = openai_caller()
|
||||||
|
response = await caller.generate_response(
|
||||||
|
api_key=openai_api_key,
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
functions=functions,
|
||||||
|
function_call="auto",
|
||||||
|
)
|
||||||
|
response = response["choices"][0]["message"] # type: ignore
|
||||||
|
if response.get("function_call", False):
|
||||||
|
function_call = response["function_call"]
|
||||||
|
return {
|
||||||
|
"name": function_call["name"],
|
||||||
|
"arguments": orjson.loads(function_call["arguments"]),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"name": "send_message",
|
||||||
|
"arguments": {"message": response["content"]},
|
||||||
|
}
|
||||||
2
src/chatUtils/requesters/openaiText.py
Normal file
2
src/chatUtils/requesters/openaiText.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
async def openaiText(prompt, openai_api_key):
|
||||||
|
pass
|
||||||
34
src/chatUtils/requesters/request.py
Normal file
34
src/chatUtils/requesters/request.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
import discord
|
||||||
|
from src.chatUtils.requesters.openaiChat import openaiChat
|
||||||
|
from src.chatUtils.requesters.openaiText import openaiText
|
||||||
|
from src.chatUtils.requesters.llama import llama
|
||||||
|
from src.chatUtils.requesters.llama2 import llama2
|
||||||
|
|
||||||
|
|
||||||
|
class ModelNotFound(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def request(
|
||||||
|
model: str,
|
||||||
|
prompt: list[dict] | str,
|
||||||
|
openai_api_key: str,
|
||||||
|
funtcions: list[dict] = None,
|
||||||
|
):
|
||||||
|
if model == "gpt-3.5-turbo":
|
||||||
|
return await openaiChat(
|
||||||
|
messages=prompt,
|
||||||
|
openai_api_key=openai_api_key,
|
||||||
|
functions=funtcions,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
elif model == "text-davinci-003":
|
||||||
|
# return await openaiText(prompt=prompt, openai_api_key=openai_api_key)
|
||||||
|
raise NotImplementedError("This model is not supported yet")
|
||||||
|
elif model == "text-llama":
|
||||||
|
return await llama(prompt=prompt)
|
||||||
|
elif model == "text-llama2":
|
||||||
|
# return await llama2(prompt=prompt)
|
||||||
|
raise NotImplementedError("This model is not supported yet")
|
||||||
|
else:
|
||||||
|
raise ModelNotFound(f"Model {model} not found")
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
from src.cogs.setup import Setup
|
|
||||||
from src.cogs.settings import Settings
|
|
||||||
from src.cogs.help import Help
|
|
||||||
from src.cogs.chat import Chat
|
from src.cogs.chat import Chat
|
||||||
from src.cogs.manage_chat import ManageChat
|
from src.cogs.manage_chat import ManageChat
|
||||||
from src.cogs.moderation import Moderation
|
from src.cogs.moderation import Moderation
|
||||||
|
from src.cogs.channelSetup import ChannelSetup
|
||||||
|
from src.cogs.help import Help
|
||||||
|
|||||||
268
src/cogs/channelSetup.py
Normal file
268
src/cogs/channelSetup.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
import discord
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
from discord import SlashCommandGroup
|
||||||
|
from discord import default_permissions
|
||||||
|
from discord.ext.commands import guild_only
|
||||||
|
from discord.ext import commands
|
||||||
|
from src.utils.variousclasses import models, characters, apis
|
||||||
|
from src.guild import Guild
|
||||||
|
|
||||||
|
sampleDataFormatExample = {
|
||||||
|
"guild_id": 1234567890,
|
||||||
|
"premium": False,
|
||||||
|
"premium_expiration": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelSetup(commands.Cog):
|
||||||
|
def __init__(self, bot: discord.Bot):
|
||||||
|
super().__init__()
|
||||||
|
self.bot = bot
|
||||||
|
|
||||||
|
setup = SlashCommandGroup(
|
||||||
|
"setup",
|
||||||
|
description="Setup commands for the bot, inlcuding channels, models, and more.",
|
||||||
|
)
|
||||||
|
|
||||||
|
setup_channel = setup.create_subgroup(
|
||||||
|
name="channel", description="Setup, add, or remove channels for the bot to use."
|
||||||
|
)
|
||||||
|
|
||||||
|
setup_guild = setup.create_subgroup(
|
||||||
|
name="server", description="Setup the settings for the server."
|
||||||
|
)
|
||||||
|
|
||||||
|
@setup_channel.command(
|
||||||
|
name="add",
|
||||||
|
description="Add a channel for the bot to use. Can also specify server-wide settings.",
|
||||||
|
)
|
||||||
|
@discord.option(
|
||||||
|
name="channel",
|
||||||
|
description="The channel to setup. If not specified, will use the current channel.",
|
||||||
|
type=discord.TextChannel,
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
@discord.option(
|
||||||
|
name="model",
|
||||||
|
description="The model to use for this channel.",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
autocomplete=models.autocomplete,
|
||||||
|
)
|
||||||
|
@discord.option(
|
||||||
|
name="character",
|
||||||
|
description="The character to use for this channel.",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
autocomplete=characters.autocomplete,
|
||||||
|
)
|
||||||
|
@guild_only()
|
||||||
|
async def channel(
|
||||||
|
self,
|
||||||
|
ctx: discord.ApplicationContext,
|
||||||
|
channel: discord.TextChannel = None,
|
||||||
|
model: str = models.default,
|
||||||
|
character: str = characters.default,
|
||||||
|
):
|
||||||
|
if channel is None:
|
||||||
|
channel = ctx.channel
|
||||||
|
guild = Guild(ctx.guild.id)
|
||||||
|
guild.load()
|
||||||
|
if not guild.premium:
|
||||||
|
if (
|
||||||
|
len(guild.channels) >= 1
|
||||||
|
and guild.channels.get(str(channel.id), None) is None
|
||||||
|
):
|
||||||
|
await ctx.respond(
|
||||||
|
"`Warning: You are not a premium user, and can only have one channel setup. The settings will still be saved, but will not be used.`",
|
||||||
|
ephemeral=True,
|
||||||
|
)
|
||||||
|
if model != models.default:
|
||||||
|
await ctx.respond(
|
||||||
|
"`Warning: You are not a premium user, and can only use the default model. The settings will still be saved, but will not be used.`",
|
||||||
|
ephemeral=True,
|
||||||
|
)
|
||||||
|
if character != characters.default:
|
||||||
|
await ctx.respond(
|
||||||
|
"`Warning: You are not a premium user, and can only use the default character. The settings will still be saved, but will not be used.`",
|
||||||
|
ephemeral=True,
|
||||||
|
)
|
||||||
|
if guild.api_keys.get("openai", None) is None:
|
||||||
|
await ctx.respond(
|
||||||
|
"`Error: No openai api key is set. The api key is needed for the openai models, as well as for the content moderation. The openai models will cost you tokens in your openai account. However, if you use one of the llama models, you will not be charged, but the api key is still needed for content moderation, wich is free but requires an api key.`",
|
||||||
|
ephemeral=True,
|
||||||
|
)
|
||||||
|
guild.addChannel(
|
||||||
|
channel, models.matchingDict[model], characters.matchingDict[character]
|
||||||
|
)
|
||||||
|
await ctx.respond(
|
||||||
|
f"Set channel {channel.mention} with model `{model}` and character `{character}`.",
|
||||||
|
ephemeral=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@setup_channel.command(
|
||||||
|
name="remove", description="Remove a channel from the bot's usage."
|
||||||
|
)
|
||||||
|
@discord.option(
|
||||||
|
name="channel",
|
||||||
|
description="The channel to remove. If not specified, will use the current channel.",
|
||||||
|
type=discord.TextChannel,
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
@guild_only()
|
||||||
|
async def remove(
|
||||||
|
self, ctx: discord.ApplicationContext, channel: discord.TextChannel = None
|
||||||
|
):
|
||||||
|
if channel is None:
|
||||||
|
channel = ctx.channel
|
||||||
|
guild = Guild(ctx.guild.id)
|
||||||
|
guild.load()
|
||||||
|
if guild.getChannelInfo(str(channel.id)) is None:
|
||||||
|
await ctx.respond("That channel is not setup.")
|
||||||
|
return
|
||||||
|
guild.delChannel(channel)
|
||||||
|
await ctx.respond(f"Removed channel {channel.mention}.", ephemeral=True)
|
||||||
|
|
||||||
|
@setup_guild.command(
|
||||||
|
name="set",
|
||||||
|
description="Set the settings for the guild (when the bot is pinged outside of a set channel).",
|
||||||
|
)
|
||||||
|
@discord.option(
|
||||||
|
name="model",
|
||||||
|
description="The model to use for this channel.",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
autocomplete=models.autocomplete,
|
||||||
|
)
|
||||||
|
@discord.option(
|
||||||
|
name="character",
|
||||||
|
description="The character to use for this channel.",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
autocomplete=characters.autocomplete,
|
||||||
|
)
|
||||||
|
@guild_only()
|
||||||
|
async def setSettings(
|
||||||
|
self,
|
||||||
|
ctx: discord.ApplicationContext,
|
||||||
|
model: str = models.default,
|
||||||
|
character: str = characters.default,
|
||||||
|
):
|
||||||
|
# we will be using "serverwide" as the channel id for the guild settings
|
||||||
|
guild = Guild(ctx.guild.id)
|
||||||
|
guild.load()
|
||||||
|
if not guild.premium:
|
||||||
|
if model != models.default:
|
||||||
|
await ctx.respond(
|
||||||
|
"`Warning: You are not a premium user, and can only use the default model. The settings will still be saved, but will not be used.`",
|
||||||
|
ephemeral=True,
|
||||||
|
)
|
||||||
|
if character != characters.default:
|
||||||
|
await ctx.respond(
|
||||||
|
"`Warning: You are not a premium user, and can only use the default character. The settings will still be saved, but will not be used.`",
|
||||||
|
ephemeral=True,
|
||||||
|
)
|
||||||
|
if guild.api_keys.get("openai", None) is None:
|
||||||
|
await ctx.respond(
|
||||||
|
"`Error: No openai api key is set. The api key is needed for the openai models, as well as for the content moderation. The openai models will cost you tokens in your openai account. However, if you use one of the llama models, you will not be charged, but the api key is still needed for content moderation, wich is free but requires an api key.`",
|
||||||
|
ephemeral=True,
|
||||||
|
)
|
||||||
|
guild.addChannel(
|
||||||
|
"serverwide", models.matchingDict[model], characters.matchingDict[character]
|
||||||
|
)
|
||||||
|
await ctx.respond(
|
||||||
|
f"Set server settings with model `{model}` and character `{character}`.",
|
||||||
|
ephemeral=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@setup_guild.command(name="remove", description="Remove the guild settings.")
|
||||||
|
@guild_only()
|
||||||
|
async def removeSettings(self, ctx: discord.ApplicationContext):
|
||||||
|
guild = Guild(ctx.guild.id)
|
||||||
|
guild.load()
|
||||||
|
if "serverwide" not in guild.channels:
|
||||||
|
await ctx.respond("No guild settings are setup.")
|
||||||
|
return
|
||||||
|
guild.delChannel("serverwide")
|
||||||
|
await ctx.respond(f"Removed serverwide settings.", ephemeral=True)
|
||||||
|
|
||||||
|
@setup.command(name="list", description="List all channels that are setup.")
|
||||||
|
@guild_only()
|
||||||
|
async def list(self, ctx: discord.ApplicationContext):
|
||||||
|
guild = Guild(ctx.guild.id)
|
||||||
|
guild.load()
|
||||||
|
if len(guild.channels) == 0:
|
||||||
|
await ctx.respond("No channels are setup.")
|
||||||
|
return
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Channels",
|
||||||
|
description="All channels that are setup.",
|
||||||
|
color=discord.Color.nitro_pink(),
|
||||||
|
)
|
||||||
|
channels = guild.sanitizedChannels
|
||||||
|
for channel in channels:
|
||||||
|
if channel == "serverwide":
|
||||||
|
mention = "Serverwide"
|
||||||
|
else:
|
||||||
|
mention = f"<#{channel}>"
|
||||||
|
model = models.reverseMatchingDict[channels[channel]["model"]]
|
||||||
|
character = characters.reverseMatchingDict[channels[channel]["character"]]
|
||||||
|
embed.add_field(
|
||||||
|
name=f"{mention}",
|
||||||
|
value=f"Model: `{model}`\nCharacter: `{character}`",
|
||||||
|
inline=False,
|
||||||
|
)
|
||||||
|
await ctx.respond(embed=embed)
|
||||||
|
|
||||||
|
@setup.command(name="api", description="Set an API key for the bot to use.")
|
||||||
|
@discord.option(
|
||||||
|
name="api",
|
||||||
|
description="The API to set. Currently only OpenAI is supported.",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
autocomplete=apis.autocomplete,
|
||||||
|
)
|
||||||
|
@discord.option(name="key", description="The key to set.", type=str, required=True)
|
||||||
|
@guild_only()
|
||||||
|
async def api(self, ctx: discord.ApplicationContext, api: str, key: str):
|
||||||
|
guild = Guild(ctx.guild.id)
|
||||||
|
guild.load()
|
||||||
|
guild.api_keys[apis.matchingDict[api]] = key
|
||||||
|
guild.updateDbData()
|
||||||
|
await ctx.respond(f"Set API key for {api} to `secret`.", ephemeral=True)
|
||||||
|
|
||||||
|
@setup.command(name="premium", description="Set the guild to premium.")
|
||||||
|
async def premium(self, ctx: discord.ApplicationContext):
|
||||||
|
guild = Guild(ctx.guild.id)
|
||||||
|
guild.load()
|
||||||
|
if await self.bot.is_owner(ctx.author):
|
||||||
|
guild.premium = True
|
||||||
|
# also set expiry date in 6 months isofromat
|
||||||
|
guild.premium_expiration = datetime.datetime.now() + datetime.timedelta(
|
||||||
|
days=180
|
||||||
|
)
|
||||||
|
guild.updateDbData()
|
||||||
|
return await ctx.respond("Set guild to premium.", ephemeral=True)
|
||||||
|
if not guild.premium:
|
||||||
|
await ctx.respond(
|
||||||
|
"You can get your premium subscription at https://www.botator.dev/premium.",
|
||||||
|
ephemeral=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await ctx.respond("This guild is already premium.", ephemeral=True)
|
||||||
|
|
||||||
|
@setup.command(name="help", description="Show the help page for setup.")
|
||||||
|
async def help(self, ctx: discord.ApplicationContext):
|
||||||
|
# we eill iterate over all commands the bot has and add them to the embed
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Setup Help",
|
||||||
|
description="Here is the help page for setup.",
|
||||||
|
color=discord.Color.dark_teal(),
|
||||||
|
)
|
||||||
|
for command in self.setup.walk_commands():
|
||||||
|
fieldname = command.name
|
||||||
|
fielddescription = command.description
|
||||||
|
embed.add_field(name=fieldname, value=fielddescription, inline=False)
|
||||||
|
embed.set_footer(text="Made with ❤️ by paillat : https://paillat.dev")
|
||||||
|
await ctx.respond(embed=embed, ephemeral=True)
|
||||||
@@ -5,7 +5,7 @@ from src.config import (
|
|||||||
webhook_url,
|
webhook_url,
|
||||||
)
|
)
|
||||||
import asyncio
|
import asyncio
|
||||||
import src.makeprompt as mp
|
from src.ChatProcess import Chat as ChatClass
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from src.utils import banusr
|
from src.utils import banusr
|
||||||
@@ -113,8 +113,13 @@ class Chat(discord.Cog):
|
|||||||
await asyncio.sleep(2)
|
await asyncio.sleep(2)
|
||||||
await message.channel.send(message.content)
|
await message.channel.send(message.content)
|
||||||
return
|
return
|
||||||
await mp.chat_process(self, message)
|
if message.guild == None:
|
||||||
|
return
|
||||||
|
chatclass = ChatClass(self.bot, message)
|
||||||
|
await chatclass.process()
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
@discord.slash_command(name="redo", description="Redo a message")
|
@discord.slash_command(name="redo", description="Redo a message")
|
||||||
async def redo(self, ctx: discord.ApplicationContext):
|
async def redo(self, ctx: discord.ApplicationContext):
|
||||||
history = await ctx.channel.history(limit=2).flatten()
|
history = await ctx.channel.history(limit=2).flatten()
|
||||||
@@ -145,3 +150,4 @@ class Chat(discord.Cog):
|
|||||||
else:
|
else:
|
||||||
debug(error)
|
debug(error)
|
||||||
raise error
|
raise error
|
||||||
|
"""
|
||||||
|
|||||||
@@ -9,94 +9,14 @@ class Help(discord.Cog):
|
|||||||
@discord.slash_command(name="help", description="Show all the commands")
|
@discord.slash_command(name="help", description="Show all the commands")
|
||||||
async def help(self, ctx: discord.ApplicationContext):
|
async def help(self, ctx: discord.ApplicationContext):
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
title="Help", description="Here is the help page", color=0x00FF00
|
title="Help",
|
||||||
|
description="Here is the help page",
|
||||||
|
color=discord.Color.dark_teal(),
|
||||||
)
|
)
|
||||||
embed.add_field(name="/setup", value="Setup the bot", inline=False)
|
# we will iterate over all commands the bot has and add them to the embed
|
||||||
embed.add_field(name="/enable", value="Enable the bot", inline=False)
|
for command in self.bot.commands:
|
||||||
embed.add_field(name="/disable", value="Disable the bot", inline=False)
|
fieldname = command.name
|
||||||
embed.add_field(
|
fielddescription = command.description
|
||||||
name="/advanced", value="Set the advanced settings", inline=False
|
embed.add_field(name=fieldname, value=fielddescription, inline=False)
|
||||||
)
|
embed.set_footer(text="Made with ❤️ by paillat : https://paillat.dev")
|
||||||
embed.add_field(
|
|
||||||
name="/advanced_help",
|
|
||||||
value="Get help about the advanced settings",
|
|
||||||
inline=False,
|
|
||||||
)
|
|
||||||
embed.add_field(
|
|
||||||
name="/enable_tts", value="Enable the Text To Speech", inline=False
|
|
||||||
)
|
|
||||||
embed.add_field(
|
|
||||||
name="/disable_tts", value="Disable the Text To Speech", inline=False
|
|
||||||
)
|
|
||||||
embed.add_field(
|
|
||||||
name="/add|remove_channel",
|
|
||||||
value="Add or remove a channel to the list of channels where the bot will answer. Only available on premium guilds",
|
|
||||||
inline=False,
|
|
||||||
)
|
|
||||||
embed.add_field(
|
|
||||||
name="/delete", value="Delete all your data from our server", inline=False
|
|
||||||
)
|
|
||||||
embed.add_field(
|
|
||||||
name="/cancel",
|
|
||||||
value="Cancel the last message sent by the bot",
|
|
||||||
inline=False,
|
|
||||||
)
|
|
||||||
embed.add_field(
|
|
||||||
name="/default",
|
|
||||||
value="Set the advanced settings to their default values",
|
|
||||||
inline=False,
|
|
||||||
)
|
|
||||||
embed.add_field(name="/say", value="Say a message", inline=False)
|
|
||||||
embed.add_field(
|
|
||||||
name="/redo", value="Redo the last message sent by the bot", inline=False
|
|
||||||
)
|
|
||||||
embed.add_field(
|
|
||||||
name="/moderation", value="Setup the AI auto-moderation", inline=False
|
|
||||||
)
|
|
||||||
embed.add_field(
|
|
||||||
name="/get_toxicity",
|
|
||||||
value="Get the toxicity that the AI would have given to a given message",
|
|
||||||
inline=False,
|
|
||||||
)
|
|
||||||
embed.add_field(name="/help", value="Show this message", inline=False)
|
|
||||||
# add a footer
|
|
||||||
embed.set_footer(text="Made by @Paillat#7777")
|
|
||||||
await ctx.respond(embed=embed, ephemeral=True)
|
|
||||||
|
|
||||||
@discord.slash_command(
|
|
||||||
name="advanced_help", description="Show the advanced settings meanings"
|
|
||||||
)
|
|
||||||
async def advanced_help(self, ctx: discord.ApplicationContext):
|
|
||||||
embed = discord.Embed(
|
|
||||||
title="Advanced Help",
|
|
||||||
description="Here is the advanced help page",
|
|
||||||
color=0x00FF00,
|
|
||||||
)
|
|
||||||
embed.add_field(
|
|
||||||
name="temperature",
|
|
||||||
value="The higher the temperature, the more likely the model will take risks. Conversely, a lower temperature will make the model more conservative. The default value is 0.9",
|
|
||||||
inline=False,
|
|
||||||
)
|
|
||||||
embed.add_field(
|
|
||||||
name="max_tokens",
|
|
||||||
value="The maximum number of tokens to generate. Higher values will result in more coherent text, but will take longer to complete. (default: 50). **Lower values will result in somentimes cutting off the end of the answer, but will be faster.**",
|
|
||||||
inline=False,
|
|
||||||
)
|
|
||||||
embed.add_field(
|
|
||||||
name="frequency_penalty",
|
|
||||||
value="The higher the frequency penalty, the more new words the model will introduce (default: 0.0)",
|
|
||||||
inline=False,
|
|
||||||
)
|
|
||||||
embed.add_field(
|
|
||||||
name="presence_penalty",
|
|
||||||
value="The higher the presence penalty, the more new words the model will introduce (default: 0.0)",
|
|
||||||
inline=False,
|
|
||||||
)
|
|
||||||
embed.add_field(
|
|
||||||
name="prompt_size",
|
|
||||||
value="The number of messages to use as a prompt (default: 5). The more messages, the more coherent the text will be, but the more it will take to generate and the more it will cost.",
|
|
||||||
inline=False,
|
|
||||||
)
|
|
||||||
# add a footer
|
|
||||||
embed.set_footer(text="Made by @Paillat#7777")
|
|
||||||
await ctx.respond(embed=embed, ephemeral=True)
|
await ctx.respond(embed=embed, ephemeral=True)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import discord
|
import discord
|
||||||
import re
|
import re
|
||||||
import os
|
import os
|
||||||
from src.config import debug, curs_data
|
|
||||||
|
|
||||||
|
|
||||||
class ManageChat(discord.Cog):
|
class ManageChat(discord.Cog):
|
||||||
@@ -9,34 +8,10 @@ class ManageChat(discord.Cog):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
|
||||||
@discord.slash_command(
|
|
||||||
name="cancel", description="Cancel the last message sent into a channel"
|
|
||||||
)
|
|
||||||
async def cancel(self, ctx: discord.ApplicationContext):
|
|
||||||
debug(
|
|
||||||
f"The user {ctx.author} ran the cancel command in the channel {ctx.channel} of the guild {ctx.guild}, named {ctx.guild.name}"
|
|
||||||
)
|
|
||||||
# check if the guild is in the database
|
|
||||||
curs_data.execute("SELECT * FROM data WHERE guild_id = ?", (ctx.guild.id,))
|
|
||||||
if curs_data.fetchone() is None:
|
|
||||||
await ctx.respond(
|
|
||||||
"This server is not setup, please run /setup", ephemeral=True
|
|
||||||
)
|
|
||||||
return
|
|
||||||
# get the last message sent by the bot in the cha where the command was sent
|
|
||||||
last_message = await ctx.channel.fetch_message(ctx.channel.last_message_id)
|
|
||||||
# delete the message
|
|
||||||
await last_message.delete()
|
|
||||||
await ctx.respond("The last message has been deleted", ephemeral=True)
|
|
||||||
|
|
||||||
# add a slash command called "clear" that deletes all the messages in the channel
|
|
||||||
@discord.slash_command(
|
@discord.slash_command(
|
||||||
name="clear", description="Clear all the messages in the channel"
|
name="clear", description="Clear all the messages in the channel"
|
||||||
)
|
)
|
||||||
async def clear(self, ctx: discord.ApplicationContext):
|
async def clear(self, ctx: discord.ApplicationContext):
|
||||||
debug(
|
|
||||||
f"The user {ctx.author.name} ran the clear command command in the channel {ctx.channel} of the guild {ctx.guild}, named {ctx.guild.name}"
|
|
||||||
)
|
|
||||||
await ctx.respond("messages deleted!", ephemeral=True)
|
await ctx.respond("messages deleted!", ephemeral=True)
|
||||||
return await ctx.channel.purge()
|
return await ctx.channel.purge()
|
||||||
|
|
||||||
@@ -52,9 +27,6 @@ class ManageChat(discord.Cog):
|
|||||||
async def transcript(
|
async def transcript(
|
||||||
self, ctx: discord.ApplicationContext, channel_send: discord.TextChannel = None
|
self, ctx: discord.ApplicationContext, channel_send: discord.TextChannel = None
|
||||||
):
|
):
|
||||||
debug(
|
|
||||||
f"The user {ctx.author.name} ran the transcript command command in the channel {ctx.channel} of the guild {ctx.guild}, named {ctx.guild.name}"
|
|
||||||
)
|
|
||||||
# save all the messages in the channel in a txt file and send it
|
# save all the messages in the channel in a txt file and send it
|
||||||
messages = await ctx.channel.history(limit=None).flatten()
|
messages = await ctx.channel.history(limit=None).flatten()
|
||||||
messages.reverse()
|
messages.reverse()
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import discord
|
|||||||
from discord import default_permissions
|
from discord import default_permissions
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
import os
|
import os
|
||||||
from src.config import debug, curs_data, con_data
|
|
||||||
import openai
|
import openai
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@@ -91,13 +90,6 @@ class Moderation(discord.Cog):
|
|||||||
"Our moderation capabilities have been switched to our new 100% free and open-source AI discord moderation bot! You add it to your server here: https://discord.com/api/oauth2/authorize?client_id=1071451913024974939&permissions=1377342450896&scope=bot and you can find the source code here: https://github.com/Paillat-dev/Moderator/ \n If you need help, you can join our support server here: https://discord.gg/pB6hXtUeDv",
|
"Our moderation capabilities have been switched to our new 100% free and open-source AI discord moderation bot! You add it to your server here: https://discord.com/api/oauth2/authorize?client_id=1071451913024974939&permissions=1377342450896&scope=bot and you can find the source code here: https://github.com/Paillat-dev/Moderator/ \n If you need help, you can join our support server here: https://discord.gg/pB6hXtUeDv",
|
||||||
ephemeral=True,
|
ephemeral=True,
|
||||||
)
|
)
|
||||||
if enable == False:
|
|
||||||
curs_data.execute(
|
|
||||||
"DELETE FROM moderation WHERE guild_id = ?", (str(ctx.guild.id),)
|
|
||||||
)
|
|
||||||
con_data.commit()
|
|
||||||
await ctx.respond("Moderation disabled!", ephemeral=True)
|
|
||||||
return
|
|
||||||
|
|
||||||
@discord.slash_command(
|
@discord.slash_command(
|
||||||
name="get_toxicity", description="Get the toxicity of a message"
|
name="get_toxicity", description="Get the toxicity of a message"
|
||||||
|
|||||||
@@ -1,295 +0,0 @@
|
|||||||
import discord
|
|
||||||
from src.config import debug, con_data, curs_data, ctx_to_guid
|
|
||||||
from src.utils.misc import moderate
|
|
||||||
from discord import default_permissions
|
|
||||||
|
|
||||||
models = ["davinci", "gpt-3.5-turbo", "gpt-4"]
|
|
||||||
images_recognition = ["enable", "disable"]
|
|
||||||
|
|
||||||
|
|
||||||
class Settings(discord.Cog):
|
|
||||||
def __init__(self, bot: discord.Bot) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.bot = bot
|
|
||||||
|
|
||||||
@discord.slash_command(name="advanced", description="Advanced settings")
|
|
||||||
@default_permissions(administrator=True)
|
|
||||||
@discord.option(name="max_tokens", description="The max tokens", required=False)
|
|
||||||
@discord.option(name="temperature", description="The temperature", required=False)
|
|
||||||
@discord.option(
|
|
||||||
name="frequency_penalty", description="The frequency penalty", required=False
|
|
||||||
)
|
|
||||||
@discord.option(
|
|
||||||
name="presence_penalty", description="The presence penalty", required=False
|
|
||||||
)
|
|
||||||
@discord.option(name="prompt_size", description="The prompt size", required=False)
|
|
||||||
async def advanced(
|
|
||||||
self,
|
|
||||||
ctx: discord.ApplicationContext,
|
|
||||||
max_tokens: int = None,
|
|
||||||
temperature: float = None,
|
|
||||||
frequency_penalty: float = None,
|
|
||||||
presence_penalty: float = None,
|
|
||||||
prompt_size: int = None,
|
|
||||||
):
|
|
||||||
await ctx.respond(
|
|
||||||
"This command has been deprecated since the new model does not need theese settungs to work well",
|
|
||||||
ephemeral=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@discord.slash_command(name="default", description="Default settings")
|
|
||||||
@default_permissions(administrator=True)
|
|
||||||
async def default(self, ctx: discord.ApplicationContext):
|
|
||||||
await ctx.respond(
|
|
||||||
"This command has been deprecated since the new model does not need theese settungs to work well",
|
|
||||||
ephemeral=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@discord.slash_command(name="prompt_size", description="Set the prompt size")
|
|
||||||
@default_permissions(administrator=True)
|
|
||||||
@discord.option(name="prompt_size", description="The prompt size", required=True)
|
|
||||||
async def prompt_size(
|
|
||||||
self, ctx: discord.ApplicationContext, prompt_size: int = None
|
|
||||||
):
|
|
||||||
# only command that is not deprecated
|
|
||||||
# check if the guild is in the database
|
|
||||||
try:
|
|
||||||
curs_data.execute(
|
|
||||||
"SELECT * FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),)
|
|
||||||
)
|
|
||||||
data = curs_data.fetchone()
|
|
||||||
except:
|
|
||||||
data = None
|
|
||||||
if data[2] is None:
|
|
||||||
await ctx.respond("This server is not setup", ephemeral=True)
|
|
||||||
return
|
|
||||||
# check if the prompt size is valid
|
|
||||||
if prompt_size is None:
|
|
||||||
await ctx.respond("You must specify a prompt size", ephemeral=True)
|
|
||||||
return
|
|
||||||
if prompt_size < 1 or prompt_size > 15:
|
|
||||||
await ctx.respond(
|
|
||||||
"The prompt size must be between 1 and 15", ephemeral=True
|
|
||||||
)
|
|
||||||
return
|
|
||||||
# update the prompt size
|
|
||||||
curs_data.execute(
|
|
||||||
"UPDATE data SET prompt_size = ? WHERE guild_id = ?",
|
|
||||||
(prompt_size, ctx_to_guid(ctx)),
|
|
||||||
)
|
|
||||||
con_data.commit()
|
|
||||||
await ctx.respond(f"Prompt size set to {prompt_size}", ephemeral=True)
|
|
||||||
|
|
||||||
# when a message is sent into a channel check if the guild is in the database and if the bot is enabled
|
|
||||||
@discord.slash_command(
|
|
||||||
name="info", description="Show the information stored about this server"
|
|
||||||
)
|
|
||||||
@default_permissions(administrator=True)
|
|
||||||
async def info(self, ctx: discord.ApplicationContext):
|
|
||||||
# this command sends all the data about the guild, including the api key, the channel id, the advanced settings and the uses_count_today
|
|
||||||
# check if the guild is in the database
|
|
||||||
try:
|
|
||||||
curs_data.execute(
|
|
||||||
"SELECT * FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),)
|
|
||||||
)
|
|
||||||
data = curs_data.fetchone()
|
|
||||||
except:
|
|
||||||
data = None
|
|
||||||
if data[2] is None:
|
|
||||||
await ctx.respond("This server is not setup", ephemeral=True)
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
curs_data.execute(
|
|
||||||
"SELECT * FROM model WHERE guild_id = ?", (ctx_to_guid(ctx),)
|
|
||||||
)
|
|
||||||
model = curs_data.fetchone()[1]
|
|
||||||
except:
|
|
||||||
model = None
|
|
||||||
if model is None:
|
|
||||||
model = "davinci"
|
|
||||||
embed = discord.Embed(
|
|
||||||
title="Info", description="Here is the info page", color=0x00FF00
|
|
||||||
)
|
|
||||||
embed.add_field(name="guild_id", value=data[0], inline=False)
|
|
||||||
embed.add_field(name="API Key", value="secret", inline=False)
|
|
||||||
embed.add_field(name="Main channel ID", value=data[1], inline=False)
|
|
||||||
embed.add_field(name="Is Active", value=data[3], inline=False)
|
|
||||||
embed.add_field(name="Prompt Size", value=data[9], inline=False)
|
|
||||||
if data[10]:
|
|
||||||
embed.add_field(name="Prompt prefix", value=data[10], inline=False)
|
|
||||||
await ctx.respond(embed=embed, ephemeral=True)
|
|
||||||
|
|
||||||
@discord.slash_command(name="prefix", description="Change the prefix of the prompt")
|
|
||||||
@default_permissions(administrator=True)
|
|
||||||
async def prefix(self, ctx: discord.ApplicationContext, prefix: str = ""):
|
|
||||||
try:
|
|
||||||
curs_data.execute(
|
|
||||||
"SELECT * FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),)
|
|
||||||
)
|
|
||||||
data = curs_data.fetchone()
|
|
||||||
api_key = data[2]
|
|
||||||
except:
|
|
||||||
await ctx.respond("This server is not setup", ephemeral=True)
|
|
||||||
return
|
|
||||||
if api_key is None or api_key == "":
|
|
||||||
await ctx.respond("This server is not setup", ephemeral=True)
|
|
||||||
return
|
|
||||||
if prefix != "":
|
|
||||||
await ctx.defer()
|
|
||||||
if await moderate(api_key=api_key, text=prefix):
|
|
||||||
await ctx.respond(
|
|
||||||
"This has been flagged as inappropriate by OpenAI, please choose another prefix",
|
|
||||||
ephemeral=True,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
await ctx.respond("Prefix changed !", ephemeral=True, delete_after=5)
|
|
||||||
curs_data.execute(
|
|
||||||
"UPDATE data SET prompt_prefix = ? WHERE guild_id = ?",
|
|
||||||
(prefix, ctx_to_guid(ctx)),
|
|
||||||
)
|
|
||||||
con_data.commit()
|
|
||||||
|
|
||||||
# when someone mentions the bot, check if the guild is in the database and if the bot is enabled. If it is, send a message answering the mention
|
|
||||||
@discord.slash_command(
|
|
||||||
name="pretend", description="Make the bot pretend to be someone else"
|
|
||||||
)
|
|
||||||
@discord.option(
|
|
||||||
name="pretend to be...",
|
|
||||||
description="The person/thing you want the bot to pretend to be. Leave blank to disable pretend mode",
|
|
||||||
required=False,
|
|
||||||
)
|
|
||||||
@default_permissions(administrator=True)
|
|
||||||
async def pretend(self, ctx: discord.ApplicationContext, pretend_to_be: str = ""):
|
|
||||||
# check if the guild is in the database
|
|
||||||
try:
|
|
||||||
curs_data.execute(
|
|
||||||
"SELECT * FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),)
|
|
||||||
)
|
|
||||||
data = curs_data.fetchone()
|
|
||||||
api_key = data[2]
|
|
||||||
except:
|
|
||||||
await ctx.respond("This server is not setup", ephemeral=True)
|
|
||||||
return
|
|
||||||
if api_key is None or api_key == "":
|
|
||||||
await ctx.respond("This server is not setup", ephemeral=True)
|
|
||||||
return
|
|
||||||
if pretend_to_be is not None or pretend_to_be != "":
|
|
||||||
await ctx.defer()
|
|
||||||
if await moderate(api_key=api_key, text=pretend_to_be):
|
|
||||||
await ctx.respond(
|
|
||||||
"This has been flagged as inappropriate by OpenAI, please choose another name",
|
|
||||||
ephemeral=True,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
if pretend_to_be == "":
|
|
||||||
pretend_to_be = ""
|
|
||||||
curs_data.execute(
|
|
||||||
"UPDATE data SET pretend_enabled = 0 WHERE guild_id = ?",
|
|
||||||
(ctx_to_guid(ctx),),
|
|
||||||
)
|
|
||||||
con_data.commit()
|
|
||||||
await ctx.respond("Pretend mode disabled", ephemeral=True, delete_after=5)
|
|
||||||
await ctx.guild.me.edit(nick=None)
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
curs_data.execute(
|
|
||||||
"UPDATE data SET pretend_enabled = 1 WHERE guild_id = ?",
|
|
||||||
(ctx_to_guid(ctx),),
|
|
||||||
)
|
|
||||||
con_data.commit()
|
|
||||||
await ctx.respond("Pretend mode enabled", ephemeral=True, delete_after=5)
|
|
||||||
# change the bots name on the server wit ctx.guild.me.edit(nick=pretend_to_be)
|
|
||||||
await ctx.guild.me.edit(nick=pretend_to_be)
|
|
||||||
curs_data.execute(
|
|
||||||
"UPDATE data SET pretend_to_be = ? WHERE guild_id = ?",
|
|
||||||
(pretend_to_be, ctx_to_guid(ctx)),
|
|
||||||
)
|
|
||||||
con_data.commit()
|
|
||||||
# if the usename is longer than 32 characters, shorten it
|
|
||||||
if len(pretend_to_be) > 31:
|
|
||||||
pretend_to_be = pretend_to_be[:32]
|
|
||||||
await ctx.guild.me.edit(nick=pretend_to_be)
|
|
||||||
return
|
|
||||||
|
|
||||||
@discord.slash_command(name="enable_tts", description="Enable TTS when chatting")
|
|
||||||
@default_permissions(administrator=True)
|
|
||||||
async def enable_tts(self, ctx: discord.ApplicationContext):
|
|
||||||
# get the guild id
|
|
||||||
guild_id = ctx_to_guid(ctx)
|
|
||||||
# connect to the database
|
|
||||||
# update the tts value in the database
|
|
||||||
curs_data.execute("UPDATE data SET tts = 1 WHERE guild_id = ?", (guild_id,))
|
|
||||||
con_data.commit()
|
|
||||||
# send a message
|
|
||||||
await ctx.respond("TTS has been enabled", ephemeral=True)
|
|
||||||
|
|
||||||
@discord.slash_command(name="disable_tts", description="Disable TTS when chatting")
|
|
||||||
@default_permissions(administrator=True)
|
|
||||||
async def disable_tts(self, ctx: discord.ApplicationContext):
|
|
||||||
# get the guild id
|
|
||||||
guild_id = ctx_to_guid(ctx)
|
|
||||||
# connect to the database
|
|
||||||
# update the tts value in the database
|
|
||||||
curs_data.execute("UPDATE data SET tts = 0 WHERE guild_id = ?", (guild_id,))
|
|
||||||
con_data.commit()
|
|
||||||
# send a message
|
|
||||||
await ctx.respond("TTS has been disabled", ephemeral=True)
|
|
||||||
|
|
||||||
# autocompletition
|
|
||||||
async def autocomplete(ctx: discord.AutocompleteContext):
|
|
||||||
return [model for model in models if model.startswith(ctx.value)]
|
|
||||||
|
|
||||||
@discord.slash_command(name="model", description="Change the model used by the bot")
|
|
||||||
@discord.option(
|
|
||||||
name="model",
|
|
||||||
description="The model you want to use. Leave blank to use the davinci model",
|
|
||||||
required=False,
|
|
||||||
autocomplete=autocomplete,
|
|
||||||
)
|
|
||||||
@default_permissions(administrator=True)
|
|
||||||
async def model(self, ctx: discord.ApplicationContext, model: str = "davinci"):
|
|
||||||
await ctx.respond(
|
|
||||||
"This command has been deprecated. Model gpt-3.5-turbo is always used by default",
|
|
||||||
ephemeral=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def images_recognition_autocomplete(ctx: discord.AutocompleteContext):
|
|
||||||
return [state for state in images_recognition if state.startswith(ctx.value)]
|
|
||||||
|
|
||||||
@discord.slash_command(
|
|
||||||
name="images", description="Enable or disable images recognition"
|
|
||||||
)
|
|
||||||
@discord.option(
|
|
||||||
name="enable_disable",
|
|
||||||
description="Enable or disable images recognition",
|
|
||||||
autocomplete=images_recognition_autocomplete,
|
|
||||||
)
|
|
||||||
@default_permissions(administrator=True)
|
|
||||||
async def images(self, ctx: discord.ApplicationContext, enable_disable: str):
|
|
||||||
try:
|
|
||||||
curs_data.execute(
|
|
||||||
"SELECT * FROM images WHERE guild_id = ?", (ctx_to_guid(ctx),)
|
|
||||||
)
|
|
||||||
data = curs_data.fetchone()
|
|
||||||
except:
|
|
||||||
data = None
|
|
||||||
if enable_disable == "enable":
|
|
||||||
enable_disable = 1
|
|
||||||
elif enable_disable == "disable":
|
|
||||||
enable_disable = 0
|
|
||||||
if data is None:
|
|
||||||
curs_data.execute(
|
|
||||||
"INSERT INTO images VALUES (?, ?, ?)",
|
|
||||||
(ctx_to_guid(ctx), 0, enable_disable),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
curs_data.execute(
|
|
||||||
"UPDATE images SET is_enabled = ? WHERE guild_id = ?",
|
|
||||||
(enable_disable, ctx_to_guid(ctx)),
|
|
||||||
)
|
|
||||||
con_data.commit()
|
|
||||||
await ctx.respond(
|
|
||||||
"Images recognition has been "
|
|
||||||
+ ("enabled" if enable_disable == 1 else "disabled"),
|
|
||||||
ephemeral=True,
|
|
||||||
)
|
|
||||||
@@ -1,329 +0,0 @@
|
|||||||
import discord
|
|
||||||
from discord import default_permissions, guild_only
|
|
||||||
from discord.ext import commands
|
|
||||||
from src.config import (
|
|
||||||
debug,
|
|
||||||
con_data,
|
|
||||||
curs_data,
|
|
||||||
con_premium,
|
|
||||||
curs_premium,
|
|
||||||
ctx_to_guid,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class NoPrivateMessages(commands.CheckFailure):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def dms_only():
|
|
||||||
async def predicate(ctx):
|
|
||||||
if ctx.guild is not None:
|
|
||||||
raise NoPrivateMessages("Hey no private messages!")
|
|
||||||
return True
|
|
||||||
|
|
||||||
return commands.check(predicate)
|
|
||||||
|
|
||||||
|
|
||||||
class Setup(discord.Cog):
|
|
||||||
def __init__(self, bot: discord.Bot):
|
|
||||||
super().__init__()
|
|
||||||
self.bot = bot
|
|
||||||
|
|
||||||
@discord.slash_command(name="setup", description="Setup the bot")
|
|
||||||
@discord.option(name="channel_id", description="The channel id", required=True)
|
|
||||||
@discord.option(name="api_key", description="The api key", required=True)
|
|
||||||
@default_permissions(administrator=True)
|
|
||||||
@guild_only()
|
|
||||||
async def setup(
|
|
||||||
self,
|
|
||||||
ctx: discord.ApplicationContext,
|
|
||||||
channel: discord.TextChannel,
|
|
||||||
api_key: str,
|
|
||||||
):
|
|
||||||
if channel is None:
|
|
||||||
await ctx.respond("Invalid channel id", ephemeral=True)
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
curs_data.execute("SELECT * FROM data WHERE guild_id = ?", (ctx.guild.id,))
|
|
||||||
data = curs_data.fetchone()
|
|
||||||
if data[3] == None:
|
|
||||||
data = None
|
|
||||||
except:
|
|
||||||
data = None
|
|
||||||
|
|
||||||
if data != None:
|
|
||||||
curs_data.execute(
|
|
||||||
"UPDATE data SET channel_id = ?, api_key = ? WHERE guild_id = ?",
|
|
||||||
(channel.id, api_key, ctx.guild.id),
|
|
||||||
)
|
|
||||||
# c.execute("UPDATE data SET is_active = ?, max_tokens = ?, temperature = ?, frequency_penalty = ?, presence_penalty = ?, prompt_size = ? WHERE guild_id = ?", (False, 64, 0.9, 0.0, 0.0, 5, ctx.guild.id))
|
|
||||||
con_data.commit()
|
|
||||||
await ctx.respond(
|
|
||||||
"The channel id and the api key have been updated", ephemeral=True
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
curs_data.execute(
|
|
||||||
"INSERT INTO data VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
|
||||||
(
|
|
||||||
ctx.guild.id,
|
|
||||||
channel.id,
|
|
||||||
api_key,
|
|
||||||
False,
|
|
||||||
64,
|
|
||||||
0.9,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0,
|
|
||||||
5,
|
|
||||||
"",
|
|
||||||
False,
|
|
||||||
"",
|
|
||||||
False,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
con_data.commit()
|
|
||||||
await ctx.respond(
|
|
||||||
"The channel id and the api key have been added", ephemeral=True
|
|
||||||
)
|
|
||||||
|
|
||||||
@discord.slash_command(name="setup_dms", description="Setup the bot in dms")
|
|
||||||
@discord.option(name="api_key", description="The api key", required=True)
|
|
||||||
@default_permissions(administrator=True)
|
|
||||||
@dms_only()
|
|
||||||
async def setup_dms(
|
|
||||||
self,
|
|
||||||
ctx: discord.ApplicationContext,
|
|
||||||
api_key: str,
|
|
||||||
):
|
|
||||||
channel = ctx.channel
|
|
||||||
if channel is None:
|
|
||||||
await ctx.respond("Invalid channel id", ephemeral=True)
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
curs_data.execute("SELECT * FROM data WHERE guild_id = ?", (ctx.user.id,))
|
|
||||||
data = curs_data.fetchone()
|
|
||||||
if data[3] == None:
|
|
||||||
data = None
|
|
||||||
except:
|
|
||||||
data = None
|
|
||||||
|
|
||||||
if data != None:
|
|
||||||
curs_data.execute(
|
|
||||||
"UPDATE data SET channel_id = ?, api_key = ? WHERE guild_id = ?",
|
|
||||||
(channel.id, api_key, ctx.user.id),
|
|
||||||
)
|
|
||||||
con_data.commit()
|
|
||||||
await ctx.respond(
|
|
||||||
"The channel id and the api key have been updated", ephemeral=True
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
curs_data.execute(
|
|
||||||
"INSERT INTO data VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
|
||||||
(
|
|
||||||
ctx.user.id,
|
|
||||||
channel.id,
|
|
||||||
api_key,
|
|
||||||
False,
|
|
||||||
64,
|
|
||||||
0.9,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0,
|
|
||||||
5,
|
|
||||||
"",
|
|
||||||
False,
|
|
||||||
"",
|
|
||||||
False,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
con_data.commit()
|
|
||||||
await ctx.respond("The api key has been added", ephemeral=True)
|
|
||||||
|
|
||||||
@discord.slash_command(
|
|
||||||
name="delete", description="Delete the information about this server"
|
|
||||||
)
|
|
||||||
@default_permissions(administrator=True)
|
|
||||||
async def delete(self, ctx: discord.ApplicationContext):
|
|
||||||
# check if the guild is in the database
|
|
||||||
curs_data.execute("SELECT * FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),))
|
|
||||||
if curs_data.fetchone() is None:
|
|
||||||
await ctx.respond("This server is not setup", ephemeral=True)
|
|
||||||
return
|
|
||||||
# delete the guild from the database, except the guild id and the uses_count_today
|
|
||||||
curs_data.execute(
|
|
||||||
"UPDATE data SET api_key = ?, channel_id = ?, is_active = ?, max_tokens = ?, temperature = ?, frequency_penalty = ?, presence_penalty = ?, prompt_size = ? WHERE guild_id = ?",
|
|
||||||
(None, None, False, 50, 0.9, 0.0, 0.0, 0, ctx_to_guid(ctx)),
|
|
||||||
)
|
|
||||||
con_data.commit()
|
|
||||||
await ctx.respond("Deleted", ephemeral=True)
|
|
||||||
|
|
||||||
# create a command called "enable" that only admins can use
|
|
||||||
@discord.slash_command(name="enable", description="Enable the bot")
|
|
||||||
@default_permissions(administrator=True)
|
|
||||||
async def enable(self, ctx: discord.ApplicationContext):
|
|
||||||
curs_data.execute("SELECT * FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),))
|
|
||||||
if curs_data.fetchone() is None:
|
|
||||||
await ctx.respond("This server is not setup", ephemeral=True)
|
|
||||||
return
|
|
||||||
# enable the guild
|
|
||||||
curs_data.execute(
|
|
||||||
"UPDATE data SET is_active = ? WHERE guild_id = ?", (True, ctx_to_guid(ctx))
|
|
||||||
)
|
|
||||||
con_data.commit()
|
|
||||||
await ctx.respond("Enabled", ephemeral=True)
|
|
||||||
|
|
||||||
# create a command called "disable" that only admins can use
|
|
||||||
@discord.slash_command(name="disable", description="Disable the bot")
|
|
||||||
@default_permissions(administrator=True)
|
|
||||||
async def disable(self, ctx: discord.ApplicationContext):
|
|
||||||
# check if the guild is in the database
|
|
||||||
curs_data.execute("SELECT * FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),))
|
|
||||||
if curs_data.fetchone() is None:
|
|
||||||
await ctx.respond("This server is not setup", ephemeral=True)
|
|
||||||
return
|
|
||||||
# disable the guild
|
|
||||||
curs_data.execute(
|
|
||||||
"UPDATE data SET is_active = ? WHERE guild_id = ?",
|
|
||||||
(False, ctx_to_guid(ctx)),
|
|
||||||
)
|
|
||||||
con_data.commit()
|
|
||||||
await ctx.respond("Disabled", ephemeral=True)
|
|
||||||
|
|
||||||
# create a command calles "add channel" that can only be used in premium servers
|
|
||||||
@discord.slash_command(
|
|
||||||
name="add_channel",
|
|
||||||
description="Add a channel to the list of channels. Premium only.",
|
|
||||||
)
|
|
||||||
@discord.option(
|
|
||||||
name="channel",
|
|
||||||
description="The channel to add",
|
|
||||||
type=discord.TextChannel,
|
|
||||||
required=False,
|
|
||||||
)
|
|
||||||
@default_permissions(administrator=True)
|
|
||||||
@guild_only()
|
|
||||||
async def add_channel(
|
|
||||||
self, ctx: discord.ApplicationContext, channel: discord.TextChannel = None
|
|
||||||
):
|
|
||||||
curs_data.execute("SELECT * FROM data WHERE guild_id = ?", (ctx.guild.id,))
|
|
||||||
if curs_data.fetchone() is None:
|
|
||||||
await ctx.respond("This server is not setup", ephemeral=True)
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
curs_premium.execute(
|
|
||||||
"SELECT premium FROM data WHERE guild_id = ?", (ctx.guild.id,)
|
|
||||||
)
|
|
||||||
premium = curs_premium.fetchone()[0]
|
|
||||||
except:
|
|
||||||
premium = False
|
|
||||||
if not premium:
|
|
||||||
await ctx.respond("This server is not premium", ephemeral=True)
|
|
||||||
return
|
|
||||||
if channel is None:
|
|
||||||
channel = ctx.channel
|
|
||||||
# check if the channel is already in the list
|
|
||||||
curs_data.execute(
|
|
||||||
"SELECT channel_id FROM data WHERE guild_id = ?", (ctx.guild.id,)
|
|
||||||
)
|
|
||||||
if str(channel.id) == curs_data.fetchone()[0]:
|
|
||||||
await ctx.respond(
|
|
||||||
"This channel is already set as the main channel", ephemeral=True
|
|
||||||
)
|
|
||||||
return
|
|
||||||
curs_premium.execute(
|
|
||||||
"SELECT * FROM channels WHERE guild_id = ?", (ctx.guild.id,)
|
|
||||||
)
|
|
||||||
guild_channels = curs_premium.fetchone()
|
|
||||||
if guild_channels is None:
|
|
||||||
# if the channel is not in the list, add it
|
|
||||||
con_premium.execute(
|
|
||||||
"INSERT INTO channels VALUES (?, ?, ?, ?, ?, ?)",
|
|
||||||
(ctx.guild.id, channel.id, None, None, None, None),
|
|
||||||
)
|
|
||||||
con_premium.commit()
|
|
||||||
await ctx.respond(f"Added channel **{channel.name}**", ephemeral=True)
|
|
||||||
return
|
|
||||||
channels = guild_channels[1:]
|
|
||||||
if str(channel.id) in channels:
|
|
||||||
await ctx.respond("This channel is already added", ephemeral=True)
|
|
||||||
return
|
|
||||||
for i in range(5):
|
|
||||||
if channels[i] == None:
|
|
||||||
curs_premium.execute(
|
|
||||||
f"UPDATE channels SET channel{i} = ? WHERE guild_id = ?",
|
|
||||||
(channel.id, ctx.guild.id),
|
|
||||||
)
|
|
||||||
con_premium.commit()
|
|
||||||
await ctx.respond(f"Added channel **{channel.name}**", ephemeral=True)
|
|
||||||
return
|
|
||||||
await ctx.respond("You can only add 5 channels", ephemeral=True)
|
|
||||||
|
|
||||||
# create a command called "remove channel" that can only be used in premium servers
|
|
||||||
@discord.slash_command(
|
|
||||||
name="remove_channel",
|
|
||||||
description="Remove a channel from the list of channels. Premium only.",
|
|
||||||
)
|
|
||||||
@discord.option(
|
|
||||||
name="channel",
|
|
||||||
description="The channel to remove",
|
|
||||||
type=discord.TextChannel,
|
|
||||||
required=False,
|
|
||||||
)
|
|
||||||
@default_permissions(administrator=True)
|
|
||||||
@guild_only()
|
|
||||||
async def remove_channel(
|
|
||||||
self, ctx: discord.ApplicationContext, channel: discord.TextChannel = None
|
|
||||||
):
|
|
||||||
# check if the guild is in the database
|
|
||||||
curs_data.execute("SELECT * FROM data WHERE guild_id = ?", (ctx.guild.id,))
|
|
||||||
if curs_data.fetchone() is None:
|
|
||||||
await ctx.respond("This server is not setup", ephemeral=True)
|
|
||||||
return
|
|
||||||
# check if the guild is premium
|
|
||||||
try:
|
|
||||||
con_premium.execute(
|
|
||||||
"SELECT premium FROM data WHERE guild_id = ?", (ctx.guild.id,)
|
|
||||||
)
|
|
||||||
premium = con_premium.fetchone()[0]
|
|
||||||
except:
|
|
||||||
premium = 0
|
|
||||||
if not premium:
|
|
||||||
await ctx.respond("This server is not premium", ephemeral=True)
|
|
||||||
return
|
|
||||||
if channel is None:
|
|
||||||
channel = ctx.channel
|
|
||||||
# check if the channel is in the list
|
|
||||||
con_premium.execute(
|
|
||||||
"SELECT * FROM channels WHERE guild_id = ?", (ctx.guild.id,)
|
|
||||||
)
|
|
||||||
guild_channels = con_premium.fetchone()
|
|
||||||
curs_data.execute(
|
|
||||||
"SELECT channel_id FROM data WHERE guild_id = ?", (ctx.guild.id,)
|
|
||||||
)
|
|
||||||
if str(channel.id) == curs_data.fetchone()[0]:
|
|
||||||
await ctx.respond(
|
|
||||||
"This channel is set as the main channel and therefore cannot be removed. Type /setup to change the main channel.",
|
|
||||||
ephemeral=True,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
if guild_channels is None:
|
|
||||||
await ctx.respond(
|
|
||||||
"This channel was not added. Nothing changed", ephemeral=True
|
|
||||||
)
|
|
||||||
return
|
|
||||||
channels = guild_channels[1:]
|
|
||||||
if str(channel.id) not in channels:
|
|
||||||
await ctx.respond(
|
|
||||||
"This channel was not added. Nothing changed", ephemeral=True
|
|
||||||
)
|
|
||||||
return
|
|
||||||
# remove the channel from the list
|
|
||||||
for i in range(5):
|
|
||||||
if channels[i] == str(channel.id):
|
|
||||||
con_premium.execute(
|
|
||||||
f"UPDATE channels SET channel{i} = ? WHERE guild_id = ?",
|
|
||||||
(None, ctx.guild.id),
|
|
||||||
)
|
|
||||||
con_premium.commit()
|
|
||||||
await ctx.respond(f"Removed channel **{channel.name}**", ephemeral=True)
|
|
||||||
return
|
|
||||||
@@ -45,52 +45,9 @@ def mg_to_guid(mg):
|
|||||||
return mg.guild.id
|
return mg.guild.id
|
||||||
|
|
||||||
|
|
||||||
con_data = sqlite3.connect("./database/data.db")
|
|
||||||
curs_data = con_data.cursor()
|
|
||||||
con_premium = sqlite3.connect("./database/premium.db")
|
con_premium = sqlite3.connect("./database/premium.db")
|
||||||
curs_premium = con_premium.cursor()
|
curs_premium = con_premium.cursor()
|
||||||
|
|
||||||
curs_data.execute(
|
|
||||||
"""CREATE TABLE IF NOT EXISTS data (guild_id text, channel_id text, api_key text, is_active boolean, max_tokens integer, temperature real, frequency_penalty real, presence_penalty real, uses_count_today integer, prompt_size integer, prompt_prefix text, tts boolean, pretend_to_be text, pretend_enabled boolean)"""
|
|
||||||
)
|
|
||||||
# we delete the moderation table and create a new one, with all theese parameters as floats: TOXICITY: {result[0]}; SEVERE_TOXICITY: {result[1]}; IDENTITY ATTACK: {result[2]}; INSULT: {result[3]}; PROFANITY: {result[4]}; THREAT: {result[5]}; SEXUALLY EXPLICIT: {result[6]}; FLIRTATION: {result[7]}; OBSCENE: {result[8]}; SPAM: {result[9]}
|
|
||||||
expected_columns = 14
|
|
||||||
|
|
||||||
# we delete the moderation table and create a new one
|
|
||||||
curs_data.execute(
|
|
||||||
"""CREATE TABLE IF NOT EXISTS moderation (guild_id text, logs_channel_id text, is_enabled boolean, mod_role_id text, toxicity real, severe_toxicity real, identity_attack real, insult real, profanity real, threat real, sexually_explicit real, flirtation real, obscene real, spam real)"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# This code returns the number of columns in the table "moderation" in the database "data.db".
|
|
||||||
curs_data.execute("PRAGMA table_info(moderation)")
|
|
||||||
result = curs_data.fetchall()
|
|
||||||
actual_columns = len(result)
|
|
||||||
|
|
||||||
if actual_columns != expected_columns:
|
|
||||||
# we add the new columns
|
|
||||||
curs_data.execute("ALTER TABLE moderation ADD COLUMN toxicity real")
|
|
||||||
curs_data.execute("ALTER TABLE moderation ADD COLUMN severe_toxicity real")
|
|
||||||
curs_data.execute("ALTER TABLE moderation ADD COLUMN identity_attack real")
|
|
||||||
curs_data.execute("ALTER TABLE moderation ADD COLUMN insult real")
|
|
||||||
curs_data.execute("ALTER TABLE moderation ADD COLUMN profanity real")
|
|
||||||
curs_data.execute("ALTER TABLE moderation ADD COLUMN threat real")
|
|
||||||
curs_data.execute("ALTER TABLE moderation ADD COLUMN sexually_explicit real")
|
|
||||||
curs_data.execute("ALTER TABLE moderation ADD COLUMN flirtation real")
|
|
||||||
curs_data.execute("ALTER TABLE moderation ADD COLUMN obscene real")
|
|
||||||
curs_data.execute("ALTER TABLE moderation ADD COLUMN spam real")
|
|
||||||
else:
|
|
||||||
print("Table already has the correct number of columns")
|
|
||||||
|
|
||||||
# This code creates the model table if it does not exist
|
|
||||||
curs_data.execute(
|
|
||||||
"""CREATE TABLE IF NOT EXISTS model (guild_id text, model_name text)"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# This code creates the images table if it does not exist
|
|
||||||
curs_data.execute(
|
|
||||||
"""CREATE TABLE IF NOT EXISTS images (guild_id text, usage_count integer, is_enabled boolean)"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# This code creates the data table if it does not exist
|
# This code creates the data table if it does not exist
|
||||||
curs_premium.execute(
|
curs_premium.execute(
|
||||||
"""CREATE TABLE IF NOT EXISTS data (user_id text, guild_id text, premium boolean)"""
|
"""CREATE TABLE IF NOT EXISTS data (user_id text, guild_id text, premium boolean)"""
|
||||||
@@ -100,6 +57,7 @@ curs_premium.execute(
|
|||||||
curs_premium.execute(
|
curs_premium.execute(
|
||||||
"""CREATE TABLE IF NOT EXISTS channels (guild_id text, channel0 text, channel1 text, channel2 text, channel3 text, channel4 text)"""
|
"""CREATE TABLE IF NOT EXISTS channels (guild_id text, channel0 text, channel1 text, channel2 text, channel3 text, channel4 text)"""
|
||||||
)
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
with open(
|
with open(
|
||||||
os.path.abspath(
|
os.path.abspath(
|
||||||
@@ -109,3 +67,4 @@ with open(
|
|||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
) as file:
|
) as file:
|
||||||
gpt_3_5_turbo_prompt = file.read()
|
gpt_3_5_turbo_prompt = file.read()
|
||||||
|
"""
|
||||||
|
|||||||
@@ -184,6 +184,15 @@ class FuntionCallError(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def send_message(
|
||||||
|
message_in_channel_in_wich_to_send: discord.Message, arguments: dict
|
||||||
|
):
|
||||||
|
message = arguments.get("message", "")
|
||||||
|
if message == "":
|
||||||
|
raise FuntionCallError("No message provided")
|
||||||
|
await message_in_channel_in_wich_to_send.channel.send(message)
|
||||||
|
|
||||||
|
|
||||||
async def get_final_url(url):
|
async def get_final_url(url):
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.head(url, allow_redirects=True) as response:
|
async with session.head(url, allow_redirects=True) as response:
|
||||||
@@ -344,13 +353,14 @@ async def evaluate_math(
|
|||||||
return f"Result to math eval of {evaluable}: ```\n{str(result)}```"
|
return f"Result to math eval of {evaluable}: ```\n{str(result)}```"
|
||||||
|
|
||||||
|
|
||||||
async def call_function(message: discord.Message, function_call, api_key):
|
async def call_function(
|
||||||
|
message: discord.Message, function_call, api_key
|
||||||
|
) -> list[None | str]:
|
||||||
name = function_call.get("name", "")
|
name = function_call.get("name", "")
|
||||||
if name == "":
|
if name == "":
|
||||||
raise FuntionCallError("No name provided")
|
raise FuntionCallError("No name provided")
|
||||||
arguments = function_call.get("arguments", {})
|
arguments = function_call.get("arguments", {})
|
||||||
# load the function call arguments json
|
# load the function call arguments json
|
||||||
arguments = orjson.loads(arguments)
|
|
||||||
if name not in functions_matching:
|
if name not in functions_matching:
|
||||||
raise FuntionCallError("Invalid function name")
|
raise FuntionCallError("Invalid function name")
|
||||||
function = functions_matching[name]
|
function = functions_matching[name]
|
||||||
@@ -363,10 +373,11 @@ async def call_function(message: discord.Message, function_call, api_key):
|
|||||||
):
|
):
|
||||||
return "Query blocked by the moderation system. If the user asked for something edgy, please tell them in a funny way that you won't do it, but do not specify that it was blocked by the moderation system."
|
return "Query blocked by the moderation system. If the user asked for something edgy, please tell them in a funny way that you won't do it, but do not specify that it was blocked by the moderation system."
|
||||||
returnable = await function(message, arguments)
|
returnable = await function(message, arguments)
|
||||||
return returnable
|
return [returnable, name]
|
||||||
|
|
||||||
|
|
||||||
functions_matching = {
|
functions_matching = {
|
||||||
|
"send_message": send_message,
|
||||||
"add_reaction_to_last_message": add_reaction_to_last_message,
|
"add_reaction_to_last_message": add_reaction_to_last_message,
|
||||||
"reply_to_last_message": reply_to_last_message,
|
"reply_to_last_message": reply_to_last_message,
|
||||||
"send_a_stock_image": send_a_stock_image,
|
"send_a_stock_image": send_a_stock_image,
|
||||||
|
|||||||
126
src/guild.py
Normal file
126
src/guild.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
import orjson
|
||||||
|
import discord
|
||||||
|
|
||||||
|
from src.utils.SqlConnector import sql
|
||||||
|
from datetime import datetime
|
||||||
|
from src.utils.variousclasses import models, characters
|
||||||
|
|
||||||
|
|
||||||
|
class Guild:
|
||||||
|
def __init__(self, id: int):
|
||||||
|
self.id = id
|
||||||
|
self.load()
|
||||||
|
|
||||||
|
def getDbData(self):
|
||||||
|
with sql.mainDb as con:
|
||||||
|
curs_data = con.cursor()
|
||||||
|
curs_data.execute("SELECT * FROM setup_data WHERE guild_id = ?", (self.id,))
|
||||||
|
data = curs_data.fetchone()
|
||||||
|
self.isInDb = data is not None
|
||||||
|
if not self.isInDb:
|
||||||
|
self.updateDbData()
|
||||||
|
with sql.mainDb as con:
|
||||||
|
curs_data = con.cursor()
|
||||||
|
curs_data.execute(
|
||||||
|
"SELECT * FROM setup_data WHERE guild_id = ?", (self.id,)
|
||||||
|
)
|
||||||
|
data = curs_data.fetchone()
|
||||||
|
if type(data[1]) == str and data[1].startswith("b'"):
|
||||||
|
data = orjson.loads(data[1][2:-1])
|
||||||
|
else:
|
||||||
|
data = orjson.loads(data[1])
|
||||||
|
self.premium = data["premium"]
|
||||||
|
self.channels = data["channels"]
|
||||||
|
self.api_keys = data["api_keys"]
|
||||||
|
if self.premium:
|
||||||
|
self.premium_expiration = datetime.fromisoformat(
|
||||||
|
data.get("premium_expiration", None)
|
||||||
|
)
|
||||||
|
self.checkPremiumExpires()
|
||||||
|
else:
|
||||||
|
self.premium_expiration = None
|
||||||
|
|
||||||
|
def checkPremiumExpires(self):
|
||||||
|
if self.premium_expiration is None:
|
||||||
|
self.premium = False
|
||||||
|
return
|
||||||
|
if self.premium_expiration < datetime.now():
|
||||||
|
self.premium = False
|
||||||
|
self.premium_expiration = None
|
||||||
|
self.updateDbData()
|
||||||
|
|
||||||
|
def updateDbData(self):
|
||||||
|
if self.isInDb:
|
||||||
|
data = {
|
||||||
|
"guild_id": self.id,
|
||||||
|
"premium": self.premium,
|
||||||
|
"channels": self.channels,
|
||||||
|
"api_keys": self.api_keys,
|
||||||
|
"premium_expiration": self.premium_expiration.isoformat()
|
||||||
|
if self.premium
|
||||||
|
else None,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
data = {
|
||||||
|
"guild_id": self.id,
|
||||||
|
"premium": False,
|
||||||
|
"channels": {},
|
||||||
|
"api_keys": {},
|
||||||
|
"premium_expiration": None,
|
||||||
|
}
|
||||||
|
with sql.mainDb as con:
|
||||||
|
curs_data = con.cursor()
|
||||||
|
if self.isInDb:
|
||||||
|
curs_data.execute(
|
||||||
|
"UPDATE setup_data SET guild_settings = ? WHERE guild_id = ?",
|
||||||
|
(orjson.dumps(data), self.id),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
curs_data.execute(
|
||||||
|
"INSERT INTO setup_data (guild_id, guild_settings) VALUES (?, ?)",
|
||||||
|
(self.id, orjson.dumps(data)),
|
||||||
|
)
|
||||||
|
self.isInDb = True
|
||||||
|
|
||||||
|
def load(self):
|
||||||
|
self.getDbData()
|
||||||
|
|
||||||
|
def addChannel(
|
||||||
|
self, channel: discord.TextChannel | str, model: str, character: str
|
||||||
|
):
|
||||||
|
if isinstance(channel, discord.TextChannel):
|
||||||
|
channel = channel.id
|
||||||
|
self.channels[str(channel)] = {
|
||||||
|
"model": model,
|
||||||
|
"character": character,
|
||||||
|
}
|
||||||
|
self.updateDbData()
|
||||||
|
|
||||||
|
def delChannel(self, channel: discord.TextChannel | str):
|
||||||
|
if isinstance(channel, discord.TextChannel):
|
||||||
|
channel = channel.id
|
||||||
|
del self.channels[str(channel)]
|
||||||
|
self.updateDbData()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sanitizedChannels(self) -> dict:
|
||||||
|
if self.premium:
|
||||||
|
return self.channels
|
||||||
|
if len(self.channels) == 0:
|
||||||
|
return {}
|
||||||
|
dictionary = {
|
||||||
|
list(self.channels.keys())[0]: {
|
||||||
|
"model": models.matchingDict[models.default],
|
||||||
|
"character": characters.matchingDict[characters.default],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if self.channels.get("serverwide", None) is not None:
|
||||||
|
dictionary["serverwide"] = self.channels["serverwide"]
|
||||||
|
return dictionary
|
||||||
|
|
||||||
|
def getChannelInfo(self, channel: str) -> dict:
|
||||||
|
return self.sanitizedChannels.get(channel, None)
|
||||||
|
|
||||||
|
def addApiKey(self, api: str, key: str):
|
||||||
|
self.api_keys[api] = key
|
||||||
|
self.updateDbData()
|
||||||
29
src/utils/SqlConnector.py
Normal file
29
src/utils/SqlConnector.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from sqlite3 import connect
|
||||||
|
from random import randint
|
||||||
|
|
||||||
|
|
||||||
|
class SQLConnection:
|
||||||
|
def __init__(self, connection):
|
||||||
|
self.connection = connection
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self.connection
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.connection.commit()
|
||||||
|
self.connection.close()
|
||||||
|
|
||||||
|
|
||||||
|
class _sql:
|
||||||
|
@property
|
||||||
|
def mainDb(self):
|
||||||
|
s = connect("./database/data.db")
|
||||||
|
return SQLConnection(s)
|
||||||
|
|
||||||
|
|
||||||
|
sql: _sql = _sql()
|
||||||
|
|
||||||
|
command = "CREATE TABLE IF NOT EXISTS setup_data (guild_id text, guild_settings text)"
|
||||||
|
|
||||||
|
with sql.mainDb as db:
|
||||||
|
db.execute(command)
|
||||||
@@ -142,7 +142,6 @@ class openai_caller:
|
|||||||
"`An APIError occurred. This is not your fault, it is OpenAI's fault. We apologize for the inconvenience. Retrying...`"
|
"`An APIError occurred. This is not your fault, it is OpenAI's fault. We apologize for the inconvenience. Retrying...`"
|
||||||
)
|
)
|
||||||
await asyncio.sleep(10)
|
await asyncio.sleep(10)
|
||||||
await recall_func()
|
|
||||||
i += 1
|
i += 1
|
||||||
except Timeout as e:
|
except Timeout as e:
|
||||||
print(
|
print(
|
||||||
@@ -150,7 +149,6 @@ class openai_caller:
|
|||||||
)
|
)
|
||||||
await recall_func("`The request timed out. Retrying...`")
|
await recall_func("`The request timed out. Retrying...`")
|
||||||
await asyncio.sleep(10)
|
await asyncio.sleep(10)
|
||||||
await recall_func()
|
|
||||||
i += 1
|
i += 1
|
||||||
except RateLimitError as e:
|
except RateLimitError as e:
|
||||||
print(
|
print(
|
||||||
@@ -158,13 +156,11 @@ class openai_caller:
|
|||||||
)
|
)
|
||||||
await recall_func("`You are being rate limited. Retrying...`")
|
await recall_func("`You are being rate limited. Retrying...`")
|
||||||
await asyncio.sleep(10)
|
await asyncio.sleep(10)
|
||||||
await recall_func()
|
|
||||||
i += 1
|
i += 1
|
||||||
except APIConnectionError as e:
|
except APIConnectionError as e:
|
||||||
print(
|
print(
|
||||||
f"\n\n{bcolors.BOLD}{bcolors.FAIL}APIConnectionError. There is an issue with your internet connection. Please check your connection.{bcolors.ENDC}"
|
f"\n\n{bcolors.BOLD}{bcolors.FAIL}APIConnectionError. There is an issue with your internet connection. Please check your connection.{bcolors.ENDC}"
|
||||||
)
|
)
|
||||||
await recall_func()
|
|
||||||
raise e
|
raise e
|
||||||
except InvalidRequestError as e:
|
except InvalidRequestError as e:
|
||||||
print(
|
print(
|
||||||
|
|||||||
34
src/utils/replicatepredictor.py
Normal file
34
src/utils/replicatepredictor.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
import replicate
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
class ReplicatePredictor:
|
||||||
|
def __init__(self, api_key, model_name, version_hash):
|
||||||
|
self.api_key = api_key
|
||||||
|
self.model_name = model_name
|
||||||
|
self.version_hash = version_hash
|
||||||
|
self.client = replicate.Client(api_token=self.api_key)
|
||||||
|
self.model = self.client.models.get(self.model_name)
|
||||||
|
self.version = self.model.versions.get(self.version_hash)
|
||||||
|
|
||||||
|
def prediction_thread(self, prompt, stop=None):
|
||||||
|
output = self.client.predictions.create(
|
||||||
|
version=self.version,
|
||||||
|
input={"prompt": prompt},
|
||||||
|
)
|
||||||
|
finaloutput = ""
|
||||||
|
for out in output.output_iterator():
|
||||||
|
finaloutput += out
|
||||||
|
if stop != None and finaloutput.find(stop) != -1:
|
||||||
|
output.cancel()
|
||||||
|
if stop != None:
|
||||||
|
return finaloutput.split(stop)[0]
|
||||||
|
else:
|
||||||
|
return finaloutput
|
||||||
|
|
||||||
|
async def predict(self, prompt, stop=None):
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
result = await loop.run_in_executor(
|
||||||
|
None, lambda: self.prediction_thread(prompt, stop)
|
||||||
|
)
|
||||||
|
return result
|
||||||
46
src/utils/variousclasses.py
Normal file
46
src/utils/variousclasses.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
from discord import AutocompleteContext
|
||||||
|
|
||||||
|
|
||||||
|
class models:
|
||||||
|
matchingDict = {
|
||||||
|
"chatGPT (default - free)": "gpt-3.5-turbo",
|
||||||
|
"davinci (premium)": "text-davinci-003",
|
||||||
|
"llama (premium)": "text-llama",
|
||||||
|
"llama 2 (premium)": "text-llama-2",
|
||||||
|
}
|
||||||
|
reverseMatchingDict = {v: k for k, v in matchingDict.items()}
|
||||||
|
default = list(matchingDict.keys())[0]
|
||||||
|
openaimodels = ["gpt-3.5-turbo", "text-davinci-003"]
|
||||||
|
chatModels = ["gpt-3.5-turbo"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def autocomplete(cls, ctx: AutocompleteContext) -> list[str]:
|
||||||
|
modls = cls.matchingDict.keys()
|
||||||
|
return [model for model in modls if model.find(ctx.value.lower()) != -1]
|
||||||
|
|
||||||
|
|
||||||
|
class characters:
|
||||||
|
matchingDict = {
|
||||||
|
"Botator (default - free)": "botator",
|
||||||
|
"Quantum (premium)": "quantum",
|
||||||
|
}
|
||||||
|
reverseMatchingDict = {v: k for k, v in matchingDict.items()}
|
||||||
|
default = list(matchingDict.keys())[0]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def autocomplete(cls, ctx: AutocompleteContext) -> list[str]:
|
||||||
|
chars = characters = cls.matchingDict.keys()
|
||||||
|
return [
|
||||||
|
character for character in chars if character.find(ctx.value.lower()) != -1
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class apis:
|
||||||
|
matchingDict = {
|
||||||
|
"OpenAI": "openai",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def autocomplete(cls, ctx: AutocompleteContext) -> list[str]:
|
||||||
|
apiss = cls.matchingDict.keys()
|
||||||
|
return [api for api in apiss if api.find(ctx.value.lower()) != -1]
|
||||||
Reference in New Issue
Block a user