Merge pull request #53 from Paillat-dev/dev

Dev
This commit is contained in:
2023-08-17 17:21:28 +02:00
committed by GitHub
4 changed files with 51 additions and 16 deletions

View File

@@ -107,7 +107,7 @@ class Chat(discord.Cog):
if message.content.startswith("botator!unban"): if message.content.startswith("botator!unban"):
user2ban = message.content.split(" ")[1] user2ban = message.content.split(" ")[1]
await banusr.unbanuser(user2ban) await banusr.unbanuser(user2ban)
await message.channel.send(f"User {user2ban} unbanned !") await message.chafnnel.send(f"User {user2ban} unbanned !")
return return
if str(message.author.id) in banusr.banend_users: if str(message.author.id) in banusr.banend_users:
await asyncio.sleep(2) await asyncio.sleep(2)
@@ -115,11 +115,6 @@ class Chat(discord.Cog):
return return
await mp.chat_process(self, message) await mp.chat_process(self, message)
@discord.slash_command(name="say", description="Say a message")
async def say(self, ctx: discord.ApplicationContext, message: str):
await ctx.respond("Message sent !", ephemeral=True)
await ctx.send(message)
@discord.slash_command(name="redo", description="Redo a message") @discord.slash_command(name="redo", description="Redo a message")
async def redo(self, ctx: discord.ApplicationContext): async def redo(self, ctx: discord.ApplicationContext):
history = await ctx.channel.history(limit=2).flatten() history = await ctx.channel.history(limit=2).flatten()

View File

@@ -5,6 +5,7 @@ import aiohttp
import random import random
import time import time
from src.utils.misc import moderate
from simpleeval import simple_eval from simpleeval import simple_eval
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from src.config import tenor_api_key from src.config import tenor_api_key
@@ -331,7 +332,7 @@ async def evaluate_math(
return f"Result to math eval of {evaluable}: ```\n{str(result)}```" return f"Result to math eval of {evaluable}: ```\n{str(result)}```"
async def call_function(message: discord.Message, function_call): async def call_function(message: discord.Message, function_call, api_key):
name = function_call.get("name", "") name = function_call.get("name", "")
if name == "": if name == "":
raise FuntionCallError("No name provided") raise FuntionCallError("No name provided")
@@ -341,6 +342,14 @@ async def call_function(message: discord.Message, function_call):
if name not in functions_matching: if name not in functions_matching:
raise FuntionCallError("Invalid function name") raise FuntionCallError("Invalid function name")
function = functions_matching[name] function = functions_matching[name]
if arguments.get("message", "") != "" and await moderate(
api_key=api_key, text=arguments["message"]
):
return "Message blocked by the moderation system. Please try again."
if arguments.get("query", "") != "" and await moderate(
api_key=api_key, text=arguments["query"]
):
return "Query blocked by the moderation system. If the user asked for something edgy, please tell them in a funny way that you won't do it, but do not specify that it was blocked by the moderation system."
returnable = await function(message, arguments) returnable = await function(message, arguments)
return returnable return returnable

View File

@@ -6,7 +6,7 @@ import datetime
import json import json
from src.config import curs_data, max_uses, curs_premium, gpt_3_5_turbo_prompt from src.config import curs_data, max_uses, curs_premium, gpt_3_5_turbo_prompt
from src.utils.misc import moderate from src.utils.misc import moderate, ModerationError, Hasher
from src.utils.openaicaller import openai_caller from src.utils.openaicaller import openai_caller
from src.functionscalls import ( from src.functionscalls import (
call_function, call_function,
@@ -131,11 +131,12 @@ async def chatgpt_process(
messages=msgs, messages=msgs,
functions=called_functions, functions=called_functions,
function_call="auto", function_call="auto",
user=Hasher(str(message.author.id)), # for user banning in case of abuse
) )
response = response["choices"][0]["message"] # type: ignore response = response["choices"][0]["message"] # type: ignore
if response.get("function_call"): if response.get("function_call"):
function_call = response.get("function_call") function_call = response.get("function_call")
returned = await call_function(message, function_call) returned = await call_function(message, function_call, api_key)
if returned != None: if returned != None:
msgs.append( msgs.append(
{ {
@@ -153,6 +154,17 @@ async def chatgpt_process(
await chatgpt_process(self, msgs, message, api_key, prompt, model, depth) await chatgpt_process(self, msgs, message, api_key, prompt, model, depth)
else: else:
content = response.get("content", "") content = response.get("content", "")
if await moderate(api_key, content, error_call):
depth += 1
if depth > 2:
await message.channel.send(
"Oh uh, it seems like i am answering recursively. I will stop now."
)
raise ModerationError("Too many recursive messages")
await chatgpt_process(
self, msgs, message, api_key, prompt, model, error_call, depth
)
else:
while len(content) != 0: while len(content) != 0:
if len(content) > 2000: if len(content) > 2000:
await message.channel.send(content[:2000]) await message.channel.send(content[:2000])

View File

@@ -1,3 +1,5 @@
import hashlib
from src.utils.openaicaller import openai_caller from src.utils.openaicaller import openai_caller
@@ -9,3 +11,20 @@ async def moderate(api_key, text, recall_func=None):
input=text, input=text,
) )
return response["results"][0]["flagged"] # type: ignore return response["results"][0]["flagged"] # type: ignore
class ModerationError(Exception):
pass
class hasher:
def __init__(self):
self.hashes = {}
def __call__(self, text: str) -> str:
if self.hashes.get(text, None) is None:
self.hashes[text] = hashlib.sha256(text.encode()).hexdigest()
return self.hashes[text]
Hasher = hasher()