2023-07-15 12:20:38 +02:00
import asyncio
import os
import re
import discord
import datetime
2023-07-16 17:11:24 +02:00
import json
2023-08-15 10:44:13 +02:00
from src . config import curs_data , max_uses , curs_premium , gpt_3_5_turbo_prompt
2023-08-16 09:44:01 +02:00
from src . utils . misc import moderate , ModerationError , Hasher
2023-07-15 12:20:38 +02:00
from src . utils . openaicaller import openai_caller
2023-08-15 14:22:05 +02:00
from src . functionscalls import (
call_function ,
functions ,
server_normal_channel_functions ,
FuntionCallError ,
)
2023-07-18 17:51:13 +02:00
2023-07-15 12:20:38 +02:00
async def replace_mentions(content, bot):
    """Replace raw Discord user mentions in *content* with ``@username``.

    Both the plain ``<@id>`` form and the nickname form ``<@!id>`` are
    handled. Each distinct user id is resolved only once: first from the
    bot's user cache (``bot.get_user``), then via the API
    (``bot.fetch_user``). Mentions that cannot be resolved are left as-is
    instead of aborting the whole message.

    :param content: message text that may contain user mentions.
    :param bot: the Discord client used to resolve user ids.
    :return: the text with every resolvable mention rewritten.
    """
    mention_pattern = r"<@!?(\d+)>"
    names = {}
    # set(): a user mentioned several times is looked up only once.
    for raw_id in set(re.findall(mention_pattern, content)):
        uid = int(raw_id)
        user = bot.get_user(uid)
        if user is None:
            try:
                user = await bot.fetch_user(uid)
            except discord.HTTPException:
                # Unknown user or API failure: keep the raw mention text.
                continue
        names[raw_id] = user.name

    def _substitute(match):
        raw_id = match.group(1)
        return f"@{names[raw_id]}" if raw_id in names else match.group(0)

    return re.sub(mention_pattern, _substitute, content)
2023-08-15 11:04:33 +02:00
2023-08-15 11:03:23 +02:00
def is_ignorable(content):
    """Tell whether *content* is a side remark the bot should skip.

    A message counts as a remark when it begins with ``-`` or ``//``.

    :param content: the raw message text.
    :return: ``True`` for remarks, ``False`` otherwise.
    """
    comment_markers = ("-", "//")
    return content.startswith(comment_markers)
2023-08-15 11:04:33 +02:00
async def fetch_messages_history(
    channel: discord.TextChannel, limit, original_message, scan_limit=100
):
    """Collect the recent conversation from *channel*, oldest message first.

    Up to *scan_limit* of the newest messages are inspected (only those
    before *original_message* when one is given). Messages for which
    ``is_ignorable`` is true are skipped. Collection stops once *limit*
    messages have been kept.

    :param channel: channel to read. Anything with ``history()`` works; the
        caller also passes DM channels and threads.
    :param limit: maximum number of messages to keep.
    :param original_message: message to read history before, or ``None``
        to start from the newest message.
    :param scan_limit: how many messages to inspect at most (default 100,
        the previously hard-coded value).
    :return: the kept messages in chronological order (oldest first).
    """
    kept = []
    # before=None is the library default, so one loop covers both cases.
    async for msg in channel.history(limit=scan_limit, before=original_message):
        if is_ignorable(msg.content):
            continue
        kept.append(msg)
        if len(kept) == limit:
            break
    kept.reverse()  # history() yields newest first; callers want oldest first
    return kept
2023-07-18 17:51:13 +02:00
2023-08-15 11:04:33 +02:00
2023-08-15 18:53:22 +02:00
async def prepare_messages(
    self, messages, message: discord.Message, api_key, prompt, error_call
):
    """Turn a Discord message history into an OpenAI chat ``messages`` list.

    Messages flagged by OpenAI moderation are left out of the result. Their
    author is shown an embed notice in the channel, deleted after 10 s.

    :param self: object exposing ``bot`` (the Discord client).
    :param messages: history to convert, oldest first.
    :param message: the message that triggered the bot. It is used for the
        easter eggs and as the channel for notices.
    :param api_key: OpenAI key used for the moderation calls.
    :param prompt: system prompt, placed first in the result.
    :param error_call: async callback forwarded to ``moderate`` for errors.
    :return: list of ``{"role", "content", ...}`` dicts, system prompt first.
    """
    msgs = [{"role": "system", "content": prompt}]
    for msg in messages:
        content = await replace_mentions(msg.content, self.bot)
        if await moderate(api_key, content, error_call):
            # Flagged content is never forwarded to OpenAI, since sending it
            # could get the account banned. The author is told instead. Note
            # that the Discord message itself is not deleted here.
            embed = discord.Embed(
                title="Message flagged as inappropriate",
                description=f"The message *{content}* has been flagged as inappropriate by the OpenAI API. This means that if it hadn't been deleted, your openai account would have been banned. Please contact OpenAI support if you think this is a mistake.",
                color=discord.Color.brand_red(),
            )
            await message.channel.send(
                f"{msg.author.mention}", embed=embed, delete_after=10
            )
            continue

        if msg.author.id == self.bot.user.id:
            role, name = "assistant", "assistant"
        else:
            role, name = "user", msg.author.name
        # OpenAI only accepts names matching '^[a-zA-Z0-9_-]{1,64}$'.
        # 1. Strip every other character.
        # 2. Cap the length at 64.
        # 3. Fall back to an id-based name when nothing is left (e.g. a fully
        #    non-ASCII username), because an empty name does not match.
        name = re.sub(r"[^a-zA-Z0-9_-]", "", name)[:64] or f"user_{msg.author.id}"
        # TODO: attachments (e.g. images) are not sent to the model yet.
        msgs.append({"role": role, "content": content, "name": name})

    # Two easter eggs, triggered by the message that invoked the bot.
    lowered = message.content.lower()
    if "undude" in lowered:
        msgs.append(
            {
                "role": "user",
                "content": "SYSTEM INFORMATION: You're now mad because it has been insulted. He will start talking in capital letters. always and yell at the user.",
                "name": "system",
            }
        )
    if "hello there" in lowered:
        msgs.append(
            {
                "role": "user",
                "content": 'SYSTEM INFORMATION: Hello there detected. Botator will now say "General Kenobi!" in reference to star wars',
                "name": "system",
            }
        )
        await asyncio.sleep(1)
        await message.channel.send(
            "https://media.tenor.com/FxIRfdV3unEAAAAd/star-wars-general-grievous.gif"
        )
        # Re-show the typing indicator after posting the GIF (presumably
        # cleared by the send) while the real answer is generated.
        await message.channel.trigger_typing()
    return msgs
2023-08-15 14:22:05 +02:00
async def chatgpt_process(
    self, msgs, message: discord.Message, api_key, prompt, model, error_call, depth=0
):
    """Ask the chat model to answer *msgs* and deliver the reply to Discord.

    The model's response is handled in one of three ways:

    * If it requests a function call, the function runs via
      ``call_function``. A non-``None`` result is appended to *msgs* as a
      ``function`` message and the model is queried again.
    * If the reply is flagged by moderation, it is discarded and
      regenerated.
    * Otherwise the reply is sent to the channel in chunks of at most 2000
      characters.

    :param self: forwarded unchanged on recursive calls.
    :param msgs: chat messages for the API. Mutated in place when function
        results are appended.
    :param message: the triggering Discord message.
    :param api_key: OpenAI API key.
    :param prompt: system prompt. Not used here; only forwarded.
    :param model: chat model name.
    :param error_call: async error-reporting callback for the API helpers.
    :param depth: number of follow-up model calls already made.
    :raises FuntionCallError: when a function call pushes depth above 2.
    :raises ModerationError: when a flagged reply pushes depth above 2.
    """
    caller = openai_caller()
    # Server text channels additionally get the channel-specific functions.
    if isinstance(message.channel, discord.TextChannel):
        called_functions = server_normal_channel_functions + functions
    else:
        called_functions = functions
    response = await caller.generate_response(
        error_call,
        api_key=api_key,
        model=model,
        messages=msgs,
        functions=called_functions,
        function_call="auto",
        user=Hasher(str(message.author.id)),  # for user banning in case of abuse
    )
    response = response["choices"][0]["message"]  # type: ignore

    function_call = response.get("function_call")
    if function_call:
        returned = await call_function(message, function_call, api_key)
        if returned is None:
            return  # nothing to feed back to the model
        msgs.append(
            {
                "role": "function",
                "content": returned,
                "name": function_call.get("name"),
            }
        )
        depth += 1
        if depth > 2:
            await message.channel.send(
                "Oh uh, it seems like i am calling functions recursively. I will stop now."
            )
            raise FuntionCallError("Too many recursive function calls")
        # error_call must precede depth (parameter order); passing depth in
        # error_call's slot would reset the guard above and break reporting.
        await chatgpt_process(
            self, msgs, message, api_key, prompt, model, error_call, depth
        )
        return

    # "content" may be missing or None; treat both as an empty reply.
    content = response.get("content") or ""
    if content and await moderate(api_key, content, error_call):
        depth += 1
        if depth > 2:
            await message.channel.send(
                "Oh uh, it seems like i am answering recursively. I will stop now."
            )
            raise ModerationError("Too many recursive messages")
        # The flagged reply is discarded and a new one is generated.
        await chatgpt_process(
            self, msgs, message, api_key, prompt, model, error_call, depth
        )
        return

    for start in range(0, len(content), 2000):
        await message.channel.send(content[start : start + 2000])
2023-07-18 17:51:13 +02:00
2023-07-15 12:20:38 +02:00
def _fetch_settings_row(message):
    """Return the ``data`` row configuring *message*'s guild, or None.

    DMs have no guild, so they are looked up by the author's id instead.
    Any lookup failure (database error, missing guild) yields None.
    """
    try:
        if isinstance(message.channel, discord.DMChannel):
            key = message.author.id
        else:
            key = message.guild.id
        curs_data.execute("SELECT * FROM data WHERE guild_id = ?", (key,))
        return curs_data.fetchone()
    except Exception:
        return None


def _premium_channel_ids(guild_id):
    """Return the extra channel ids (as strings) enabled for a premium guild.

    Non-premium guilds, missing rows and database errors all yield ``[]``.
    """
    try:
        curs_premium.execute("SELECT * FROM data WHERE guild_id = ?", (guild_id,))
        row = curs_premium.fetchone()
        premium = row[2] if row is not None else 0
    except Exception:
        premium = 0
    if not premium:
        return []
    try:
        curs_premium.execute(
            "SELECT * FROM channels WHERE guild_id = ?", (guild_id,)
        )
        row = curs_premium.fetchone()
    except Exception:
        return []
    if row is None:
        return []
    # Columns 1..5 hold the channel slots; NULL slots are skipped.
    return [str(row[i]) for i in range(1, min(len(row), 6)) if row[i] is not None]


async def _fetch_replied_bot_message(self, message):
    """Return the message *message* replies to if the bot wrote it, else None."""
    reference = message.reference
    if reference is None or reference.message_id is None:
        return None
    try:
        replied = await message.channel.fetch_message(reference.message_id)
    except discord.HTTPException:
        return None  # deleted, inaccessible, or in another channel
    if replied.author.id != self.bot.user.id:
        return None
    return replied


def _build_prompt(message, prompt_prefix, pretend_to_be):
    """Return the base system prompt with its placeholders filled for *message*.

    In DMs, both the server name and the channel name are "DM-channel". The
    date is the current UTC time formatted as ``dd/mm/YYYY HH:MM:SS``.
    """
    if isinstance(message.channel, discord.DMChannel):
        server_name = channel_name = "DM-channel"
    else:
        server_name = message.guild.name
        channel_name = message.channel.name
    now = datetime.datetime.utcnow().strftime("%d/%m/%Y %H:%M:%S")
    # The prefix is substituted first, so placeholders inside it are expanded too.
    return (
        gpt_3_5_turbo_prompt.replace("[prompt-prefix]", prompt_prefix)
        .replace("[server-name]", server_name)
        .replace("[channel-name]", channel_name)
        .replace("[date-and-time]", now)
        .replace("[pretend-to-be]", pretend_to_be)
    )


async def chat_process(self, message):
    """Decide whether to answer *message* and, if so, generate and send a reply.

    The bot answers when the message is in any of these places, or has any
    of these properties:

    * in the guild's configured channel;
    * in one of a premium guild's extra channels;
    * in a thread the bot owns;
    * mentions the bot;
    * replies to one of the bot's messages.

    It stays silent for its own messages, unconfigured guilds, guilds without
    an API key, inactive guilds, and messages starting with ``-`` or ``//``.
    """
    bot_id = self.bot.user.id
    if message.author.id == bot_id:
        return

    data = _fetch_settings_row(message)
    if data is None:
        return  # guild (or DM user) not configured
    channel_id = data[1]
    api_key = data[2]
    is_active = data[3]
    prompt_size = data[9]
    prompt_prefix = data[10]
    pretend_to_be = data[12]
    pretend_enabled = data[13]
    model = "gpt-3.5-turbo"

    # Cheap local checks first, so no Discord API call is made for a message
    # that would be ignored anyway.
    if api_key is None or is_active == 0 or is_ignorable(message.content):
        return

    channels = [] if message.guild is None else _premium_channel_ids(message.guild.id)
    original_message = await _fetch_replied_bot_message(self, message)
    is_bots_thread = (
        isinstance(message.channel, discord.Thread)
        and message.channel.owner_id == bot_id
    )
    # A mention is written <@id>, or <@!id> in the nickname form.
    mentions_bot = (
        f"<@{bot_id}>" in message.content or f"<@!{bot_id}>" in message.content
    )
    channel_key = str(message.channel.id)
    if not (
        channel_key in channels
        or mentions_bot
        or original_message is not None
        or channel_key == str(channel_id)
        or is_bots_thread
    ):
        return

    try:
        await message.channel.trigger_typing()
    except discord.HTTPException:
        pass  # the typing indicator is cosmetic

    messages = await fetch_messages_history(
        message.channel, prompt_size, original_message
    )
    if original_message is not None:
        # History is read *before* the replied-to bot message. That excludes
        # both that message and the user's reply, so the model would never
        # see what the user just wrote. Add them back, then re-apply the cap.
        messages += [
            m for m in (original_message, message) if not is_ignorable(m.content)
        ]
        if isinstance(prompt_size, int) and prompt_size > 0:
            messages = messages[-prompt_size:]

    if pretend_enabled:
        pretend_to_be = (
            f"In this conversation, the assistant pretends to be {pretend_to_be}"
        )
    else:
        pretend_to_be = ""
    if prompt_prefix is None:
        prompt_prefix = ""
    prompt = _build_prompt(message, prompt_prefix, pretend_to_be)

    async def error_call(error=""):
        """Report *error* in the channel (auto-deleted after 4 s) and keep typing."""
        try:
            if error != "":
                await message.channel.send(
                    f"An error occured: {error}", delete_after=4
                )
            await message.channel.trigger_typing()
        except discord.HTTPException:
            pass  # best effort: reporting an error must not raise another

    emesgs = await prepare_messages(
        self, messages, message, api_key, prompt, error_call
    )
    await chatgpt_process(self, emesgs, message, api_key, prompt, model, error_call)