Fixed error with textchannel ony function and deprecated some comands

This commit is contained in:
Paillat
2023-07-16 22:43:18 +02:00
parent 06b0e921db
commit 41b2fde1e6
5 changed files with 71 additions and 183 deletions

View File

@@ -1,5 +1,6 @@
import discord
from src.config import debug, con_data, curs_data, moderate, ctx_to_guid
from src.config import debug, con_data, curs_data, ctx_to_guid
from src.utils.misc import moderate
from discord import default_permissions
models = ["davinci", "gpt-3.5-turbo", "gpt-4"]
@@ -31,145 +32,42 @@ class Settings(discord.Cog):
presence_penalty: float = None,
prompt_size: int = None,
):
curs_data.execute("SELECT * FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),))
if curs_data.fetchone() is None:
await ctx.respond("This server is not setup", ephemeral=True)
return
if (
max_tokens is None
and temperature is None
and frequency_penalty is None
and presence_penalty is None
and prompt_size is None
):
await ctx.respond("You must enter at least one argument", ephemeral=True)
return
if max_tokens is not None and (max_tokens < 1 or max_tokens > 4000):
await ctx.respond("Invalid max tokens", ephemeral=True)
return
if temperature is not None and (temperature < 0.0 or temperature > 1.0):
await ctx.respond("Invalid temperature", ephemeral=True)
return
if frequency_penalty is not None and (
frequency_penalty < 0.0 or frequency_penalty > 2.0
):
await ctx.respond("Invalid frequency penalty", ephemeral=True)
return
if presence_penalty is not None and (
presence_penalty < 0.0 or presence_penalty > 2.0
):
await ctx.respond("Invalid presence penalty", ephemeral=True)
return
if prompt_size is not None and (prompt_size < 1 or prompt_size > 10):
await ctx.respond("Invalid prompt size", ephemeral=True)
return
if max_tokens is None:
if (
curs_data.execute(
"SELECT max_tokens FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),)
).fetchone()[0]
is not None
and max_tokens is None
):
max_tokens = curs_data.execute(
"SELECT max_tokens FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),)
).fetchone()[0]
else:
max_tokens = 64
if temperature is None:
if (
curs_data.execute(
"SELECT temperature FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),)
).fetchone()[0]
is not None
and temperature is None
):
temperature = curs_data.execute(
"SELECT temperature FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),)
).fetchone()[0]
else:
temperature = 0.9
if frequency_penalty is None:
if (
curs_data.execute(
"SELECT frequency_penalty FROM data WHERE guild_id = ?",
(ctx_to_guid(ctx),),
).fetchone()[0]
is not None
and frequency_penalty is None
):
frequency_penalty = curs_data.execute(
"SELECT frequency_penalty FROM data WHERE guild_id = ?",
(ctx_to_guid(ctx),),
).fetchone()[0]
else:
frequency_penalty = 0.0
if presence_penalty is None:
if (
curs_data.execute(
"SELECT presence_penalty FROM data WHERE guild_id = ?",
(ctx_to_guid(ctx),),
).fetchone()[0]
is not None
and presence_penalty is None
):
presence_penalty = curs_data.execute(
"SELECT presence_penalty FROM data WHERE guild_id = ?",
(ctx_to_guid(ctx),),
).fetchone()[0]
else:
presence_penalty = 0.0
if prompt_size is None:
if (
curs_data.execute(
"SELECT prompt_size FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),)
).fetchone()[0]
is not None
and prompt_size is None
):
prompt_size = curs_data.execute(
"SELECT prompt_size FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),)
).fetchone()[0]
else:
prompt_size = 1
# update the database
curs_data.execute(
"UPDATE data SET max_tokens = ?, temperature = ?, frequency_penalty = ?, presence_penalty = ?, prompt_size = ? WHERE guild_id = ?",
(
max_tokens,
temperature,
frequency_penalty,
presence_penalty,
prompt_size,
ctx_to_guid(ctx),
),
)
con_data.commit()
await ctx.respond("Advanced settings updated", ephemeral=True)
# create a command called "delete" that only admins can use wich deletes the guild id, the api key, the channel id and the advanced settings from the database
await ctx.respond("This command has been deprecated since the new model does not need theese settungs to work well", ephemeral=True)
@discord.slash_command(name="default", description="Default settings")
@default_permissions(administrator=True)
async def default(self, ctx: discord.ApplicationContext):
await ctx.respond("This command has been deprecated since the new model does not need theese settungs to work well", ephemeral=True)
@discord.slash_command(name="prompt_size", description="Set the prompt size")
@default_permissions(administrator=True)
@discord.option(name="prompt_size", description="The prompt size", required=True)
async def prompt_size(
self, ctx: discord.ApplicationContext, prompt_size: int = None
):
#only command that is not deprecated
# check if the guild is in the database
curs_data.execute("SELECT * FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),))
if curs_data.fetchone() is None:
await ctx.respond(
"This server is not setup, please run /setup", ephemeral=True
)
try:
curs_data.execute("SELECT * FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),))
data = curs_data.fetchone()
except:
data = None
if data[2] is None:
await ctx.respond("This server is not setup", ephemeral=True)
return
# set the advanced settings (max_tokens, temperature, frequency_penalty, presence_penalty, prompt_size) and also prompt_prefix to their default values
# check if the prompt size is valid
if prompt_size is None:
await ctx.respond("You must specify a prompt size", ephemeral=True)
return
if prompt_size < 1 or prompt_size > 15:
await ctx.respond("The prompt size must be between 1 and 15", ephemeral=True)
return
# update the prompt size
curs_data.execute(
"UPDATE data SET max_tokens = ?, temperature = ?, frequency_penalty = ?, presence_penalty = ?, prompt_size = ? WHERE guild_id = ?",
(64, 0.9, 0.0, 0.0, 5, ctx_to_guid(ctx)),
"UPDATE data SET prompt_size = ? WHERE guild_id = ?", (prompt_size, ctx_to_guid(ctx))
)
con_data.commit()
await ctx.respond(
"The advanced settings have been set to their default values",
ephemeral=True,
)
# create a command called "cancel" that deletes the last message sent by the bot in the response channel
await ctx.respond(f"Prompt size set to {prompt_size}", ephemeral=True)
# when a message is sent into a channel check if the guild is in the database and if the bot is enabled
@discord.slash_command(
@@ -200,14 +98,8 @@ class Settings(discord.Cog):
embed.add_field(name="guild_id", value=data[0], inline=False)
embed.add_field(name="API Key", value="secret", inline=False)
embed.add_field(name="Main channel ID", value=data[1], inline=False)
embed.add_field(name="Model", value=model, inline=False)
embed.add_field(name="Is Active", value=data[3], inline=False)
embed.add_field(name="Max Tokens", value=data[4], inline=False)
embed.add_field(name="Temperature", value=data[5], inline=False)
embed.add_field(name="Frequency Penalty", value=data[6], inline=False)
embed.add_field(name="Presence Penalty", value=data[7], inline=False)
embed.add_field(name="Prompt Size", value=data[9], inline=False)
embed.add_field(name="Uses Count Today", value=data[8], inline=False)
if data[10]:
embed.add_field(name="Prompt prefix", value=data[10], inline=False)
await ctx.respond(embed=embed, ephemeral=True)
@@ -337,20 +229,7 @@ class Settings(discord.Cog):
)
@default_permissions(administrator=True)
async def model(self, ctx: discord.ApplicationContext, model: str = "davinci"):
try:
curs_data.execute("SELECT * FROM model WHERE guild_id = ?", (ctx_to_guid(ctx),))
data = curs_data.fetchone()[1]
except:
data = None
if data is None:
curs_data.execute("INSERT INTO model VALUES (?, ?)", (ctx_to_guid(ctx), model))
else:
curs_data.execute(
"UPDATE model SET model_name = ? WHERE guild_id = ?",
(model, ctx_to_guid(ctx)),
)
con_data.commit()
await ctx.respond("Model changed !", ephemeral=True)
await ctx.respond("This command has been deprecated. Model gpt-3.5-turbo is always used by default", ephemeral=True)
async def images_recognition_autocomplete(ctx: discord.AutocompleteContext):
return [state for state in images_recognition if state.startswith(ctx.value)]

View File

@@ -43,15 +43,6 @@ curs_data = con_data.cursor()
con_premium = sqlite3.connect("./database/premium.db")
curs_premium = con_premium.cursor()
async def moderate(api_key, text):
openai.api_key = api_key
response = await openai.Moderation.acreate(
input=text,
)
return response["results"][0]["flagged"] # type: ignore
curs_data.execute(
"""CREATE TABLE IF NOT EXISTS data (guild_id text, channel_id text, api_key text, is_active boolean, max_tokens integer, temperature real, frequency_penalty real, presence_penalty real, uses_count_today integer, prompt_size integer, prompt_prefix text, tts boolean, pretend_to_be text, pretend_enabled boolean)"""
)

View File

@@ -1,10 +1,11 @@
import asyncio
import os
from src.config import curs_data, max_uses, curs_premium, moderate
from src.config import curs_data, max_uses, curs_premium
import re
import discord
import datetime
import json
from src.utils.misc import moderate
from src.utils.openaicaller import openai_caller
from src.functionscalls import add_reaction_to_last_message, reply_to_last_message, send_a_stock_image, create_a_thread, functions, server_normal_channel_functions
async def replace_mentions(content, bot):
@@ -16,6 +17,14 @@ async def replace_mentions(content, bot):
return content
async def chatgpt_process(self, messages, message: discord.Message, api_key, prompt, model):
async def error_call(error=""):
try:
if error != "":
await message.channel.send(f"An error occured: {error}", delete_after=10)
await message.channel.trigger_typing()
except:
pass
msgs = [] # create the msgs list
msgs.append({"role": "system", "content": prompt}) # add the prompt to the msgs list
name = "" # create the name variable
@@ -23,7 +32,7 @@ async def chatgpt_process(self, messages, message: discord.Message, api_key, pro
content = msg.content # get the content of the message
content = await replace_mentions(content, self.bot) # replace the mentions in the message
# if the message is flagged as inappropriate by the OpenAI API, we delete it, send a message and ignore it
if await moderate(api_key=api_key, text=content):
if await moderate(api_key, content, error_call):
embed = discord.Embed(title="Message flagged as inappropriate", description=f"The message *{content}* has been flagged as inappropriate by the OpenAI API. This means that if it hadn't been deleted, your openai account would have been banned. Please contact OpenAI support if you think this is a mistake.", color=discord.Color.brand_red())
await message.channel.send(f"{msg.author.mention}", embed=embed, delete_after=10)
await message.delete()
@@ -53,23 +62,12 @@ async def chatgpt_process(self, messages, message: discord.Message, api_key, pro
response = str()
caller = openai_caller(api_key=api_key)
async def error_call(error=""):
try:
if error != "":
await message.channel.send(f"An error occured: {error}", delete_after=10)
await message.channel.trigger_typing()
except:
pass
funcs = functions
if isinstance(message.channel, discord.TextChannel):
for func in server_normal_channel_functions:
funcs.append(func)
called_functions = functions if not isinstance(message.channel, discord.TextChannel) else server_normal_channel_functions + functions
response = await caller.generate_response(
error_call,
model=model,
messages=msgs,
functions=functions,
functions=called_functions,
#function_call="auto",
)
response = response["choices"][0]["message"] #type: ignore
@@ -99,7 +97,7 @@ async def chatgpt_process(self, messages, message: discord.Message, api_key, pro
if isinstance(message.channel, discord.TextChannel):
await create_a_thread(message.channel, name, reply)
else:
await message.channel.send("`A server normal text channel only function has been called in a DM channel. Please retry.`", delete_after=10)
await message.channel.send("`A server normal text channel only function has been called in a non standard channel. Please retry`", delete_after=10)
if name == "":
await message.channel.send("The function call is empty. Please retry.", delete_after=10)
else:

10
src/utils/misc.py Normal file
View File

@@ -0,0 +1,10 @@
from src.utils.openaicaller import openai_caller
async def moderate(api_key, text, recall_func=None):
caller = openai_caller(api_key)
response = await caller.moderation(
recall_func,
api_key=api_key,
input=text,
)
return response["results"][0]["flagged"] # type: ignore

View File

@@ -72,15 +72,25 @@ class openai_caller:
kwargs['messages'] = kwargs['messages'][1:]
print(f"{bcolors.BOLD}{bcolors.WARNING}Warning: Too many tokens. Removing first message.{bcolors.ENDC}")
tokens = await num_tokens_from_messages(kwargs['messages'], kwargs['model'])
kwargs['api_key'] = self.api_key
callable = lambda: openai_module.ChatCompletion.acreate(**kwargs)
response = await self.retryal_call(recall_func, callable)
return response
async def moderation(self, recall_func=None, **kwargs):
if recall_func is None:
recall_func = lambda x: 2
callable = lambda: openai_module.Moderation.acreate(**kwargs)
response = await self.retryal_call(recall_func, callable)
return response
async def retryal_call(self, recall_func, callable):
i = 0
response = None
kwargs['api_key'] = self.api_key
while i < 10:
try:
response = await openai_module.ChatCompletion.acreate(
**kwargs
)
break
response = await callable()
return response
except APIError as e:
print(f"\n\n{bcolors.BOLD}{bcolors.WARNING}APIError. This is not your fault. Retrying...{bcolors.ENDC}")
await recall_func("`An APIError occurred. This is not your fault. Retrying...`")
@@ -121,7 +131,7 @@ class openai_caller:
if i == 10:
print(f"\n\n{bcolors.BOLD}{bcolors.FAIL}OpenAI API is not responding. Please try again later.{bcolors.ENDC}")
raise TimeoutError("OpenAI API is not responding. Please try again later.")
return response # type: ignore
return response
##testing
if __name__ == "__main__":