diff --git a/src/cogs/chat.py b/src/cogs/chat.py index 6b69064..76a5fd1 100644 --- a/src/cogs/chat.py +++ b/src/cogs/chat.py @@ -107,7 +107,7 @@ class Chat(discord.Cog): if message.content.startswith("botator!unban"): user2ban = message.content.split(" ")[1] await banusr.unbanuser(user2ban) - await message.channel.send(f"User {user2ban} unbanned !") + await message.chafnnel.send(f"User {user2ban} unbanned !") return if str(message.author.id) in banusr.banend_users: await asyncio.sleep(2) @@ -115,11 +115,6 @@ class Chat(discord.Cog): return await mp.chat_process(self, message) - @discord.slash_command(name="say", description="Say a message") - async def say(self, ctx: discord.ApplicationContext, message: str): - await ctx.respond("Message sent !", ephemeral=True) - await ctx.send(message) - @discord.slash_command(name="redo", description="Redo a message") async def redo(self, ctx: discord.ApplicationContext): history = await ctx.channel.history(limit=2).flatten() diff --git a/src/functionscalls.py b/src/functionscalls.py index 2d94261..c7da912 100644 --- a/src/functionscalls.py +++ b/src/functionscalls.py @@ -5,6 +5,7 @@ import aiohttp import random import time +from src.utils.misc import moderate from simpleeval import simple_eval from bs4 import BeautifulSoup from src.config import tenor_api_key @@ -331,7 +332,7 @@ async def evaluate_math( return f"Result to math eval of {evaluable}: ```\n{str(result)}```" -async def call_function(message: discord.Message, function_call): +async def call_function(message: discord.Message, function_call, api_key): name = function_call.get("name", "") if name == "": raise FuntionCallError("No name provided") @@ -341,6 +342,14 @@ async def call_function(message: discord.Message, function_call): if name not in functions_matching: raise FuntionCallError("Invalid function name") function = functions_matching[name] + if arguments.get("message", "") != "" and await moderate( + api_key=api_key, text=arguments["message"] + ): + return "Message blocked by the moderation system. Please try again." + if arguments.get("query", "") != "" and await moderate( + api_key=api_key, text=arguments["query"] + ): + return "Query blocked by the moderation system. If the user asked for something edgy, please tell them in a funny way that you won't do it, but do not specify that it was blocked by the moderation system." returnable = await function(message, arguments) return returnable diff --git a/src/makeprompt.py b/src/makeprompt.py index fb39257..f710269 100644 --- a/src/makeprompt.py +++ b/src/makeprompt.py @@ -6,7 +6,7 @@ import datetime import json from src.config import curs_data, max_uses, curs_premium, gpt_3_5_turbo_prompt -from src.utils.misc import moderate +from src.utils.misc import moderate, ModerationError, Hasher from src.utils.openaicaller import openai_caller from src.functionscalls import ( call_function, @@ -131,11 +131,12 @@ async def chatgpt_process( messages=msgs, functions=called_functions, function_call="auto", + user=Hasher(str(message.author.id)), # for user banning in case of abuse ) response = response["choices"][0]["message"] # type: ignore if response.get("function_call"): function_call = response.get("function_call") - returned = await call_function(message, function_call) + returned = await call_function(message, function_call, api_key) if returned != None: msgs.append( { @@ -153,13 +154,24 @@ async def chatgpt_process( await chatgpt_process(self, msgs, message, api_key, prompt, model, depth) else: content = response.get("content", "") - while len(content) != 0: - if len(content) > 2000: - await message.channel.send(content[:2000]) - content = content[2000:] - else: - await message.channel.send(content) - content = "" + if await moderate(api_key, content, error_call): + depth += 1 + if depth > 2: + await message.channel.send( + "Oh uh, it seems like i am answering recursively. I will stop now." + ) + raise ModerationError("Too many recursive messages") + await chatgpt_process( + self, msgs, message, api_key, prompt, model, error_call, depth + ) + else: + while len(content) != 0: + if len(content) > 2000: + await message.channel.send(content[:2000]) + content = content[2000:] + else: + await message.channel.send(content) + content = "" async def chat_process(self, message): diff --git a/src/utils/misc.py b/src/utils/misc.py index 13f1aa9..15344e3 100644 --- a/src/utils/misc.py +++ b/src/utils/misc.py @@ -1,3 +1,5 @@ +import hashlib + from src.utils.openaicaller import openai_caller @@ -9,3 +11,20 @@ async def moderate(api_key, text, recall_func=None): input=text, ) return response["results"][0]["flagged"] # type: ignore + + +class ModerationError(Exception): + pass + + +class hasher: + def __init__(self): + self.hashes = {} + + def __call__(self, text: str) -> str: + if self.hashes.get(text, None) is None: + self.hashes[text] = hashlib.sha256(text.encode()).hexdigest() + return self.hashes[text] + + +Hasher = hasher()