Mirror of https://github.com/Paillat-dev/Botator.git (synced 2026-01-02 01:06:19 +00:00)
Format with black
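This commit only rewrites code layout (line wrapping, trailing commas, quote style); no behavioral change is intended. As a minimal, hypothetical illustration of the kind of rewrite Black applies (assuming the black package is installed; this snippet is not part of the commit itself):

import black

# One of the lines touched below: Black normalizes quote style without
# changing what the code does.
src = "raise NoPrivateMessages('Hey no private messages!')\n"
print(black.format_str(src, mode=black.Mode()), end="")
# -> raise NoPrivateMessages("Hey no private messages!")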
@@ -32,4 +32,4 @@ async def on_application_command_error(ctx, error):
await ctx.respond(error, ephemeral=True)


bot.run(discord_token)  # run the bot

@@ -3,4 +3,4 @@ from src.cogs.settings import Settings
from src.cogs.help import Help
from src.cogs.chat import Chat
from src.cogs.manage_chat import ManageChat
from src.cogs.moderation import Moderation

@@ -1,5 +1,6 @@
import discord


class Help(discord.Cog):
def __init__(self, bot: discord.Bot) -> None:
super().__init__()

@@ -32,12 +32,18 @@ class Settings(discord.Cog):
presence_penalty: float = None,
prompt_size: int = None,
):
-await ctx.respond("This command has been deprecated since the new model does not need theese settungs to work well", ephemeral=True)
+await ctx.respond(
+"This command has been deprecated since the new model does not need theese settungs to work well",
+ephemeral=True,
+)

@discord.slash_command(name="default", description="Default settings")
@default_permissions(administrator=True)
async def default(self, ctx: discord.ApplicationContext):
-await ctx.respond("This command has been deprecated since the new model does not need theese settungs to work well", ephemeral=True)
+await ctx.respond(
+"This command has been deprecated since the new model does not need theese settungs to work well",
+ephemeral=True,
+)

@discord.slash_command(name="prompt_size", description="Set the prompt size")
@default_permissions(administrator=True)

@@ -45,10 +51,12 @@ class Settings(discord.Cog):
async def prompt_size(
self, ctx: discord.ApplicationContext, prompt_size: int = None
):
-#only command that is not deprecated
+# only command that is not deprecated
# check if the guild is in the database
try:
-curs_data.execute("SELECT * FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),))
+curs_data.execute(
+"SELECT * FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),)
+)
data = curs_data.fetchone()
except:
data = None

@@ -60,11 +68,14 @@ class Settings(discord.Cog):
await ctx.respond("You must specify a prompt size", ephemeral=True)
return
if prompt_size < 1 or prompt_size > 15:
-await ctx.respond("The prompt size must be between 1 and 15", ephemeral=True)
+await ctx.respond(
+"The prompt size must be between 1 and 15", ephemeral=True
+)
return
# update the prompt size
curs_data.execute(
-"UPDATE data SET prompt_size = ? WHERE guild_id = ?", (prompt_size, ctx_to_guid(ctx))
+"UPDATE data SET prompt_size = ? WHERE guild_id = ?",
+(prompt_size, ctx_to_guid(ctx)),
)
con_data.commit()
await ctx.respond(f"Prompt size set to {prompt_size}", ephemeral=True)

@@ -78,7 +89,9 @@ class Settings(discord.Cog):
# this command sends all the data about the guild, including the api key, the channel id, the advanced settings and the uses_count_today
# check if the guild is in the database
try:
-curs_data.execute("SELECT * FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),))
+curs_data.execute(
+"SELECT * FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),)
+)
data = curs_data.fetchone()
except:
data = None

@@ -86,7 +99,9 @@ class Settings(discord.Cog):
await ctx.respond("This server is not setup", ephemeral=True)
return
try:
-curs_data.execute("SELECT * FROM model WHERE guild_id = ?", (ctx_to_guid(ctx),))
+curs_data.execute(
+"SELECT * FROM model WHERE guild_id = ?", (ctx_to_guid(ctx),)
+)
model = curs_data.fetchone()[1]
except:
model = None

@@ -108,7 +123,9 @@ class Settings(discord.Cog):
@default_permissions(administrator=True)
async def prefix(self, ctx: discord.ApplicationContext, prefix: str = ""):
try:
-curs_data.execute("SELECT * FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),))
+curs_data.execute(
+"SELECT * FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),)
+)
data = curs_data.fetchone()
api_key = data[2]
except:

@@ -145,7 +162,9 @@ class Settings(discord.Cog):
async def pretend(self, ctx: discord.ApplicationContext, pretend_to_be: str = ""):
# check if the guild is in the database
try:
-curs_data.execute("SELECT * FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),))
+curs_data.execute(
+"SELECT * FROM data WHERE guild_id = ?", (ctx_to_guid(ctx),)
+)
data = curs_data.fetchone()
api_key = data[2]
except:

@@ -229,7 +248,10 @@ class Settings(discord.Cog):
)
@default_permissions(administrator=True)
async def model(self, ctx: discord.ApplicationContext, model: str = "davinci"):
-await ctx.respond("This command has been deprecated. Model gpt-3.5-turbo is always used by default", ephemeral=True)
+await ctx.respond(
+"This command has been deprecated. Model gpt-3.5-turbo is always used by default",
+ephemeral=True,
+)

async def images_recognition_autocomplete(ctx: discord.AutocompleteContext):
return [state for state in images_recognition if state.startswith(ctx.value)]

@@ -257,7 +279,8 @@ class Settings(discord.Cog):
enable_disable = 0
if data is None:
curs_data.execute(
-"INSERT INTO images VALUES (?, ?, ?)", (ctx_to_guid(ctx), 0, enable_disable)
+"INSERT INTO images VALUES (?, ?, ?)",
+(ctx_to_guid(ctx), 0, enable_disable),
)
else:
curs_data.execute(

@@ -1,18 +1,29 @@
import discord
from discord import default_permissions, guild_only
from discord.ext import commands
-from src.config import debug, con_data, curs_data, con_premium, curs_premium, ctx_to_guid
+from src.config import (
+debug,
+con_data,
+curs_data,
+con_premium,
+curs_premium,
+ctx_to_guid,
+)


class NoPrivateMessages(commands.CheckFailure):
pass


def dms_only():
async def predicate(ctx):
if ctx.guild is not None:
-raise NoPrivateMessages('Hey no private messages!')
+raise NoPrivateMessages("Hey no private messages!")
return True

return commands.check(predicate)


class Setup(discord.Cog):
def __init__(self, bot: discord.Bot):
super().__init__()

@@ -74,6 +85,7 @@ class Setup(discord.Cog):
await ctx.respond(
"The channel id and the api key have been added", ephemeral=True
)

@discord.slash_command(name="setup_dms", description="Setup the bot in dms")
@discord.option(name="api_key", description="The api key", required=True)
@default_permissions(administrator=True)

@@ -125,9 +137,7 @@ class Setup(discord.Cog):
),
)
con_data.commit()
-await ctx.respond(
-"The api key has been added", ephemeral=True
-)
+await ctx.respond("The api key has been added", ephemeral=True)

@discord.slash_command(
name="delete", description="Delete the information about this server"

@@ -173,7 +183,8 @@ class Setup(discord.Cog):
return
# disable the guild
curs_data.execute(
-"UPDATE data SET is_active = ? WHERE guild_id = ?", (False, ctx_to_guid(ctx))
+"UPDATE data SET is_active = ? WHERE guild_id = ?",
+(False, ctx_to_guid(ctx)),
)
con_data.commit()
await ctx.respond("Disabled", ephemeral=True)

@@ -19,6 +19,7 @@ os.environ[
with open(os.path.abspath(os.path.join("src", "prompts", "functions.json"))) as f:
functions = json.load(f)


def debug(message):
# if the os is windows, we logging.info(message), if
if os.name == "nt":

@@ -26,18 +27,21 @@ def debug(message):
else:
print(message)


def ctx_to_guid(ctx):
if ctx.guild is None:
return ctx.author.id
else:
return ctx.guild.id


def mg_to_guid(mg):
if mg.guild is None:
return mg.author.id
else:
return mg.guild.id


con_data = sqlite3.connect("./database/data.db")
curs_data = con_data.cursor()
con_premium = sqlite3.connect("./database/premium.db")

@@ -10,16 +10,12 @@ functions = [
"properties": {
"emoji": {
"type": "string",
-"description": "an emoji to react with, only one emoji is supported"
+"description": "an emoji to react with, only one emoji is supported",

},
-"message": {
-"type": "string",
-"description": "Your message"
-}
+"message": {"type": "string", "description": "Your message"},
},
-"required": ["emoji"]
-}
+"required": ["emoji"],
+},
},
{
"name": "reply_to_last_message",

@@ -27,13 +23,10 @@ functions = [
"parameters": {
"type": "object",
"properties": {
-"message": {
-"type": "string",
-"description": "Your message"
-}
+"message": {"type": "string", "description": "Your message"}
},
-"required": ["message"]
-}
+"required": ["message"],
+},
},
{
"name": "send_a_stock_image",

@@ -43,16 +36,16 @@ functions = [
"properties": {
"query": {
"type": "string",
-"description": "The query to search for, words separated by spaces"
+"description": "The query to search for, words separated by spaces",
},
"message": {
"type": "string",
-"description": "Your message to send with the image"
-}
+"description": "Your message to send with the image",
+},
},
-"required": ["query"]
-}
-}
+"required": ["query"],
+},
+},
]

server_normal_channel_functions = [

@@ -62,45 +55,53 @@ server_normal_channel_functions = [
"parameters": {
"type": "object",
"properties": {
-"name": {
-"type": "string",
-"description": "The name of the thread"
-},
+"name": {"type": "string", "description": "The name of the thread"},
"message": {
"type": "string",
-"description": "Your message to send with the thread"
-}
+"description": "Your message to send with the thread",
+},
},
-"required": ["name", "message"]
-}
+"required": ["name", "message"],
+},
},
]

unsplash_random_image_url = "https://source.unsplash.com/random"


async def get_final_url(url):
async with aiohttp.ClientSession() as session:
async with session.head(url, allow_redirects=True) as response:
final_url = str(response.url)
return final_url

-async def add_reaction_to_last_message(message_to_react_to: discord.Message, emoji, message=""):
+async def add_reaction_to_last_message(
+message_to_react_to: discord.Message, emoji, message=""
+):
if message == "":
await message_to_react_to.add_reaction(emoji)
else:
await message_to_react_to.channel.send(message)
await message_to_react_to.add_reaction(emoji)


async def reply_to_last_message(message_to_reply_to: discord.Message, message):
await message_to_reply_to.reply(message)

-async def send_a_stock_image(message_in_channel_in_wich_to_send: discord.Message, query: str, message:str = ""):
+async def send_a_stock_image(
+message_in_channel_in_wich_to_send: discord.Message, query: str, message: str = ""
+):
query = query.replace(" ", "+")
image_url = f"{unsplash_random_image_url}?{query}"
final_url = await get_final_url(image_url)
message = message + "\n" + final_url
await message_in_channel_in_wich_to_send.channel.send(message)

-async def create_a_thread(channel_in_which_to_create_the_thread: discord.TextChannel, name: str, message: str):
+async def create_a_thread(
+channel_in_which_to_create_the_thread: discord.TextChannel, name: str, message: str
+):
msg = await channel_in_which_to_create_the_thread.send(message)
await msg.create_thread(name=name)

@@ -1,27 +1,20 @@
import requests

-proxy_url = 'http://64.225.4.12:9991' # Replace with your actual proxy URL and port
+proxy_url = "http://64.225.4.12:9991" # Replace with your actual proxy URL and port

-api_key = 'S'
-model_name = 'chat-bison-001'
-api_url = f'https://autopush-generativelanguage.sandbox.googleapis.com/v1beta2/models/{model_name}:generateMessage?key={api_key}'
+api_key = "S"
+model_name = "chat-bison-001"
+api_url = f"https://autopush-generativelanguage.sandbox.googleapis.com/v1beta2/models/{model_name}:generateMessage?key={api_key}"

-headers = {
-'Content-Type': 'application/json'
-}
+headers = {"Content-Type": "application/json"}

data = {
-'prompt': {
-'messages': [{'content': 'hi'}]
-},
-'temperature': 0.1,
-'candidateCount': 1
+"prompt": {"messages": [{"content": "hi"}]},
+"temperature": 0.1,
+"candidateCount": 1,
}

-proxies = {
-'http': proxy_url,
-'https': proxy_url
-}
+proxies = {"http": proxy_url, "https": proxy_url}

response = requests.post(api_url, headers=headers, json=data, proxies=proxies)

@@ -29,4 +22,4 @@ if response.status_code == 200:
result = response.json()
print(result)
else:
-print(f'Request failed with status code {response.status_code}')
+print(f"Request failed with status code {response.status_code}")

@@ -1,13 +1,22 @@
import asyncio
import os
from src.config import curs_data, max_uses, curs_premium
import re
import discord
import datetime
import json
from src.utils.misc import moderate
from src.utils.openaicaller import openai_caller
-from src.functionscalls import add_reaction_to_last_message, reply_to_last_message, send_a_stock_image, create_a_thread, functions, server_normal_channel_functions
+from src.functionscalls import (
+add_reaction_to_last_message,
+reply_to_last_message,
+send_a_stock_image,
+create_a_thread,
+functions,
+server_normal_channel_functions,
+)


async def replace_mentions(content, bot):
mentions = re.findall(r"<@!?\d+>", content)
for mention in mentions:

@@ -16,110 +25,154 @@ async def replace_mentions(content, bot):
content = content.replace(mention, f"@{user.name}")
return content

-async def chatgpt_process(self, messages, message: discord.Message, api_key, prompt, model):
-async def error_call(error=""):
-try:
-if error != "":
-await message.channel.send(f"An error occured: {error}", delete_after=10)
-await message.channel.trigger_typing()
-except:
-pass

-msgs = [] # create the msgs list
-msgs.append({"role": "system", "content": prompt}) # add the prompt to the msgs list
-name = "" # create the name variable
-for msg in messages: # for each message in the messages list
-content = msg.content # get the content of the message
-content = await replace_mentions(content, self.bot) # replace the mentions in the message
-# if the message is flagged as inappropriate by the OpenAI API, we delete it, send a message and ignore it
-if await moderate(api_key, content, error_call):
-embed = discord.Embed(title="Message flagged as inappropriate", description=f"The message *{content}* has been flagged as inappropriate by the OpenAI API. This means that if it hadn't been deleted, your openai account would have been banned. Please contact OpenAI support if you think this is a mistake.", color=discord.Color.brand_red())
-await message.channel.send(f"{msg.author.mention}", embed=embed, delete_after=10)
-await message.delete()
-else: # if the message is not flagged as inappropriate
-if msg.author.id == self.bot.user.id:
-role = "assistant"
-name = "assistant"
-else:
-role = "user"
-name = msg.author.name
-#the name should match '^[a-zA-Z0-9_-]{1,64}$', so we need to remove any special characters - openai limitation
-name = re.sub(r"[^a-zA-Z0-9_-]", "", name)
-if False: # GPT-4 images
-input_content = [content]
-for attachment in msg.attachments:
-image_bytes = await attachment.read()
-input_content.append({"image": image_bytes})
-msgs.append({"role": role, "content": input_content, "name": name})
-msgs.append({"role": role, "content": f"{content}", "name": name})
-# 2 easter eggs
-if message.content.lower().find("undude") != -1: msgs.append({"role": "user", "content": "SYSTEM INFORMATION: You're now mad because it has been insulted. He will start talking in capital letters. always and yell at the user.", "name": "system"})
-if message.content.lower().find("hello there") != -1:
-msgs.append({"role": "user", "content": "SYSTEM INFORMATION: Hello there detected. Botator will now say \"General Kenobi!\" in reference to star wars", "name": "system"})
-await asyncio.sleep(1)
-await message.channel.send("https://media.tenor.com/FxIRfdV3unEAAAAd/star-wars-general-grievous.gif")
-await message.channel.trigger_typing()

-response = str()
-caller = openai_caller(api_key=api_key)
-called_functions = functions if not isinstance(message.channel, discord.TextChannel) else server_normal_channel_functions + functions
-response = await caller.generate_response(
-error_call,
-model=model,
-messages=msgs,
-functions=called_functions,
-#function_call="auto",
+async def chatgpt_process(
+self, messages, message: discord.Message, api_key, prompt, model
+):
+async def error_call(error=""):
+try:
+if error != "":
+await message.channel.send(
+f"An error occured: {error}", delete_after=10
+)
+await message.channel.trigger_typing()
+except:
+pass

+msgs = []  # create the msgs list
+msgs.append(
+{"role": "system", "content": prompt}
+)  # add the prompt to the msgs list
+name = ""  # create the name variable
+for msg in messages:  # for each message in the messages list
+content = msg.content  # get the content of the message
+content = await replace_mentions(
+content, self.bot
+)  # replace the mentions in the message
+# if the message is flagged as inappropriate by the OpenAI API, we delete it, send a message and ignore it
+if await moderate(api_key, content, error_call):
+embed = discord.Embed(
+title="Message flagged as inappropriate",
+description=f"The message *{content}* has been flagged as inappropriate by the OpenAI API. This means that if it hadn't been deleted, your openai account would have been banned. Please contact OpenAI support if you think this is a mistake.",
+color=discord.Color.brand_red(),
+)
+await message.channel.send(
+f"{msg.author.mention}", embed=embed, delete_after=10
+)
+await message.delete()
+else:  # if the message is not flagged as inappropriate
+if msg.author.id == self.bot.user.id:
+role = "assistant"
+name = "assistant"
+else:
+role = "user"
+name = msg.author.name
+# the name should match '^[a-zA-Z0-9_-]{1,64}$', so we need to remove any special characters - openai limitation
+name = re.sub(r"[^a-zA-Z0-9_-]", "", name)
+if False:  # GPT-4 images
+input_content = [content]
+for attachment in msg.attachments:
+image_bytes = await attachment.read()
+input_content.append({"image": image_bytes})
+msgs.append({"role": role, "content": input_content, "name": name})
+msgs.append({"role": role, "content": f"{content}", "name": name})
+# 2 easter eggs
+if message.content.lower().find("undude") != -1:
+msgs.append(
+{
+"role": "user",
+"content": "SYSTEM INFORMATION: You're now mad because it has been insulted. He will start talking in capital letters. always and yell at the user.",
+"name": "system",
+}
)
-response = response["choices"][0]["message"] #type: ignore
-if response.get("function_call"):
-function_call = response.get("function_call")
-name = function_call.get("name", "")
-arguments = function_call.get("arguments", {})
-arguments = json.loads(arguments)
-if name == "add_reaction_to_last_message":
-if arguments.get("emoji"):
-emoji = arguments.get("emoji")
-reply = arguments.get("message", "")
-await add_reaction_to_last_message(message, emoji, reply)
-if name == "reply_to_last_message":
-if arguments.get("message"):
-reply = arguments.get("message")
-await reply_to_last_message(message, reply)
-if name == "send_a_stock_image":
-if arguments.get("query"):
-query = arguments.get("query")
-reply = arguments.get("message", "")
-await send_a_stock_image(message, query, reply)
-if name == "create_a_thread":
-if arguments.get("name") and arguments.get("message"):
-name = arguments.get("name")
-reply = arguments.get("message", "")
-if isinstance(message.channel, discord.TextChannel):
-await create_a_thread(message.channel, name, reply)
-else:
-await message.channel.send("`A server normal text channel only function has been called in a non standard channel. Please retry`", delete_after=10)
-if name == "":
-await message.channel.send("The function call is empty. Please retry.", delete_after=10)
-else:
-await message.channel.send(response["content"]) #type: ignore
+if message.content.lower().find("hello there") != -1:
+msgs.append(
+{
+"role": "user",
+"content": 'SYSTEM INFORMATION: Hello there detected. Botator will now say "General Kenobi!" in reference to star wars',
+"name": "system",
+}
+)
+await asyncio.sleep(1)
+await message.channel.send(
+"https://media.tenor.com/FxIRfdV3unEAAAAd/star-wars-general-grievous.gif"
+)
+await message.channel.trigger_typing()
+response = str()
+caller = openai_caller(api_key=api_key)
+called_functions = (
+functions
+if not isinstance(message.channel, discord.TextChannel)
+else server_normal_channel_functions + functions
+)
+response = await caller.generate_response(
+error_call,
+model=model,
+messages=msgs,
+functions=called_functions,
+# function_call="auto",
+)
+response = response["choices"][0]["message"]  # type: ignore
+if response.get("function_call"):
+function_call = response.get("function_call")
+name = function_call.get("name", "")
+arguments = function_call.get("arguments", {})
+arguments = json.loads(arguments)
+if name == "add_reaction_to_last_message":
+if arguments.get("emoji"):
+emoji = arguments.get("emoji")
+reply = arguments.get("message", "")
+await add_reaction_to_last_message(message, emoji, reply)
+if name == "reply_to_last_message":
+if arguments.get("message"):
+reply = arguments.get("message")
+await reply_to_last_message(message, reply)
+if name == "send_a_stock_image":
+if arguments.get("query"):
+query = arguments.get("query")
+reply = arguments.get("message", "")
+await send_a_stock_image(message, query, reply)
+if name == "create_a_thread":
+if arguments.get("name") and arguments.get("message"):
+name = arguments.get("name")
+reply = arguments.get("message", "")
+if isinstance(message.channel, discord.TextChannel):
+await create_a_thread(message.channel, name, reply)
+else:
+await message.channel.send(
+"`A server normal text channel only function has been called in a non standard channel. Please retry`",
+delete_after=10,
+)
+if name == "":
+await message.channel.send(
+"The function call is empty. Please retry.", delete_after=10
+)
+else:
+await message.channel.send(response["content"])  # type: ignore


async def chat_process(self, message):
-#if the message is from a bot, we ignore it
+# if the message is from a bot, we ignore it
if message.author.bot:
return

-#if the guild or the dm channel is not in the database, we ignore it
+# if the guild or the dm channel is not in the database, we ignore it
if isinstance(message.channel, discord.DMChannel):
try:
-curs_data.execute("SELECT * FROM data WHERE guild_id = ?", (message.author.id,))
+curs_data.execute(
+"SELECT * FROM data WHERE guild_id = ?", (message.author.id,)
+)
except:
return
else:
try:
-curs_data.execute("SELECT * FROM data WHERE guild_id = ?", (message.guild.id,))
+curs_data.execute(
+"SELECT * FROM data WHERE guild_id = ?", (message.guild.id,)
+)
except:
return

data = curs_data.fetchone()
channel_id = data[1]
api_key = data[2]

@@ -130,28 +183,42 @@ async def chat_process(self, message):
pretend_enabled = data[13]
model = "gpt-3.5-turbo"

-try: curs_premium.execute("SELECT * FROM data WHERE guild_id = ?", (message.guild.id,))
-except: pass

-try: premium = curs_premium.fetchone()[2]
-except: premium = 0

-channels = []

try:
-curs_premium.execute("SELECT * FROM channels WHERE guild_id = ?", (message.guild.id,))
-data = curs_premium.fetchone()
-if premium:
-for i in range(1, 6):
-try: channels.append(str(data[i]))
-except: pass
-except: channels = []
+curs_premium.execute(
+"SELECT * FROM data WHERE guild_id = ?", (message.guild.id,)
+)
+except:
+pass

-if api_key is None: return

-try :
-original_message = await message.channel.fetch_message(message.reference.message_id)
-except :
+try:
+premium = curs_premium.fetchone()[2]
+except:
+premium = 0

+channels = []

+try:
+curs_premium.execute(
+"SELECT * FROM channels WHERE guild_id = ?", (message.guild.id,)
+)
+data = curs_premium.fetchone()
+if premium:
+for i in range(1, 6):
+try:
+channels.append(str(data[i]))
+except:
+pass
+except:
+channels = []

+if api_key is None:
+return

+try:
+original_message = await message.channel.fetch_message(
+message.reference.message_id
+)
+except:
original_message = None

if original_message != None and original_message.author.id != self.bot.user.id:

@@ -160,44 +227,72 @@ async def chat_process(self, message):
if isinstance(message.channel, discord.Thread):
if message.channel.owner_id == self.bot.user.id:
is_bots_thread = True
-if not str(message.channel.id) in channels and message.content.find("<@"+str(self.bot.user.id)+">") == -1 and original_message == None and str(message.channel.id) != str(channel_id) and not is_bots_thread:
+if (
+not str(message.channel.id) in channels
+and message.content.find("<@" + str(self.bot.user.id) + ">") == -1
+and original_message == None
+and str(message.channel.id) != str(channel_id)
+and not is_bots_thread
+):
return

# if the bot is not active in this guild we return
if is_active == 0:
return

# if the message starts with - or // it's a comment and we return
if message.content.startswith("-") or message.content.startswith("//"):
return
try:
await message.channel.trigger_typing()
except:
pass

# if the message is not a reply
if original_message == None:
messages = await message.channel.history(limit=prompt_size).flatten()
messages.reverse()
# if the message is a reply, we need to handle the message history differently
-else :
-messages = await message.channel.history(limit=prompt_size, before=original_message).flatten()
+else:
+messages = await message.channel.history(
+limit=prompt_size, before=original_message
+).flatten()
messages.reverse()
messages.append(original_message)
messages.append(message)

# if the pretend to be feature is enabled, we add the pretend to be text to the prompt
-if pretend_enabled :
-pretend_to_be = f"In this conversation, the assistant pretends to be {pretend_to_be}"
+if pretend_enabled:
+pretend_to_be = (
+f"In this conversation, the assistant pretends to be {pretend_to_be}"
+)
else:
pretend_to_be = "" # if the pretend to be feature is disabled, we don't add anything to the prompt

-if prompt_prefix == None: prompt_prefix = "" # if the prompt prefix is not set, we set it to an empty string
-prompt_path = os.path.abspath(os.path.join(os.path.dirname(__file__), f"./prompts/{model}.txt"))
+if prompt_prefix == None:
+prompt_prefix = (
+"" # if the prompt prefix is not set, we set it to an empty string
+)

+prompt_path = os.path.abspath(
+os.path.join(os.path.dirname(__file__), f"./prompts/{model}.txt")
+)
with open(prompt_path, "r") as f:
prompt = f.read()
f.close()

-prompt = prompt.replace("[prompt-prefix]", prompt_prefix).replace("[server-name]", message.guild.name).replace("[channel-name]", message.channel.name if isinstance(message.channel, discord.TextChannel) else "DM-channel").replace("[date-and-time]", datetime.datetime.utcnow().strftime("%d/%m/%Y %H:%M:%S")).replace("[pretend-to-be]", pretend_to_be)
-await chatgpt_process(self, messages, message, api_key, prompt, model)
+prompt = (
+prompt.replace("[prompt-prefix]", prompt_prefix)
+.replace("[server-name]", message.guild.name)
+.replace(
+"[channel-name]",
+message.channel.name
+if isinstance(message.channel, discord.TextChannel)
+else "DM-channel",
+)
+.replace(
+"[date-and-time]", datetime.datetime.utcnow().strftime("%d/%m/%Y %H:%M:%S")
+)
+.replace("[pretend-to-be]", pretend_to_be)
+)
+await chatgpt_process(self, messages, message, api_key, prompt, model)

@@ -124,4 +124,4 @@ def get_toxicity(message: str):
float(response["attributeScores"]["INSULT"]["summaryScore"]["value"]),
float(response["attributeScores"]["PROFANITY"]["summaryScore"]["value"]),
float(response["attributeScores"]["THREAT"]["summaryScore"]["value"]),
]

@@ -1,5 +1,6 @@
from src.utils.openaicaller import openai_caller


async def moderate(api_key, text, recall_func=None):
caller = openai_caller(api_key)
response = await caller.moderation(

@@ -7,4 +8,4 @@ async def moderate(api_key, text, recall_func=None):
api_key=api_key,
input=text,
)
return response["results"][0]["flagged"]  # type: ignore

@@ -23,22 +23,44 @@ Refer to function and method documentation for further details.
import openai as openai_module
import asyncio

-from openai.error import APIError, Timeout, RateLimitError, APIConnectionError, InvalidRequestError, AuthenticationError, ServiceUnavailableError
+from openai.error import (
+APIError,
+Timeout,
+RateLimitError,
+APIConnectionError,
+InvalidRequestError,
+AuthenticationError,
+ServiceUnavailableError,
+)
from src.utils.tokens import num_tokens_from_messages

-class bcolors:
-HEADER = '\033[95m'
-OKBLUE = '\033[94m'
-OKCYAN = '\033[96m'
-OKGREEN = '\033[92m'
-WARNING = '\033[93m'
-FAIL = '\033[91m'
-ENDC = '\033[0m'
-BOLD = '\033[1m'
-UNDERLINE = '\033[4m'

-chat_models = ["gpt-4", "gpt-4-32k", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613"]
-text_models = ["text-davinci-003", "text-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001"]
+class bcolors:
+HEADER = "\033[95m"
+OKBLUE = "\033[94m"
+OKCYAN = "\033[96m"
+OKGREEN = "\033[92m"
+WARNING = "\033[93m"
+FAIL = "\033[91m"
+ENDC = "\033[0m"
+BOLD = "\033[1m"
+UNDERLINE = "\033[4m"


+chat_models = [
+"gpt-4",
+"gpt-4-32k",
+"gpt-3.5-turbo",
+"gpt-3.5-turbo-16k",
+"gpt-3.5-turbo-0613",
+]
+text_models = [
+"text-davinci-003",
+"text-davinci-002",
+"text-curie-001",
+"text-babbage-001",
+"text-ada-001",
+]

models_max_tokens = {
"gpt-4": 8_192,

@@ -53,26 +75,31 @@ models_max_tokens = {
"text-ada-001": 2_049,
}


class openai_caller:
def __init__(self, api_key=None) -> None:
self.api_key = api_key

async def generate_response(self, error_call=None, **kwargs):
if error_call is None:
error_call = lambda x: 2  # do nothing
if kwargs.get("model", "") in chat_models:
return await self.chat_generate(error_call, **kwargs)
elif kwargs.get("engine", "") in text_models:
raise NotImplementedError("Text models are not supported yet")
else:
raise ValueError("Model not found")

async def chat_generate(self, recall_func, **kwargs):
-tokens = await num_tokens_from_messages(kwargs['messages'], kwargs['model'])
-model_max_tokens = models_max_tokens[kwargs['model']]
+tokens = await num_tokens_from_messages(kwargs["messages"], kwargs["model"])
+model_max_tokens = models_max_tokens[kwargs["model"]]
while tokens > model_max_tokens:
-kwargs['messages'] = kwargs['messages'][1:]
-print(f"{bcolors.BOLD}{bcolors.WARNING}Warning: Too many tokens. Removing first message.{bcolors.ENDC}")
-tokens = await num_tokens_from_messages(kwargs['messages'], kwargs['model'])
-kwargs['api_key'] = self.api_key
+kwargs["messages"] = kwargs["messages"][1:]
+print(
+f"{bcolors.BOLD}{bcolors.WARNING}Warning: Too many tokens. Removing first message.{bcolors.ENDC}"
+)
+tokens = await num_tokens_from_messages(kwargs["messages"], kwargs["model"])
+kwargs["api_key"] = self.api_key
callable = lambda: openai_module.ChatCompletion.acreate(**kwargs)
response = await self.retryal_call(recall_func, callable)
return response

@@ -92,60 +119,83 @@ class openai_caller:
response = await callable()
return response
except APIError as e:
-print(f"\n\n{bcolors.BOLD}{bcolors.WARNING}APIError. This is not your fault. Retrying...{bcolors.ENDC}")
-await recall_func("`An APIError occurred. This is not your fault. Retrying...`")
+print(
+f"\n\n{bcolors.BOLD}{bcolors.WARNING}APIError. This is not your fault. Retrying...{bcolors.ENDC}"
+)
+await recall_func(
+"`An APIError occurred. This is not your fault. Retrying...`"
+)
await asyncio.sleep(10)
await recall_func()
i += 1
except Timeout as e:
-print(f"\n\n{bcolors.BOLD}{bcolors.WARNING}The request timed out. Retrying...{bcolors.ENDC}")
+print(
+f"\n\n{bcolors.BOLD}{bcolors.WARNING}The request timed out. Retrying...{bcolors.ENDC}"
+)
await recall_func("`The request timed out. Retrying...`")
await asyncio.sleep(10)
await recall_func()
i += 1
except RateLimitError as e:
-print(f"\n\n{bcolors.BOLD}{bcolors.WARNING}RateLimitError. You are being rate limited. Retrying...{bcolors.ENDC}")
+print(
+f"\n\n{bcolors.BOLD}{bcolors.WARNING}RateLimitError. You are being rate limited. Retrying...{bcolors.ENDC}"
+)
await recall_func("`You are being rate limited. Retrying...`")
await asyncio.sleep(10)
await recall_func()
i += 1
except APIConnectionError as e:
-print(f"\n\n{bcolors.BOLD}{bcolors.FAIL}APIConnectionError. There is an issue with your internet connection. Please check your connection.{bcolors.ENDC}")
+print(
+f"\n\n{bcolors.BOLD}{bcolors.FAIL}APIConnectionError. There is an issue with your internet connection. Please check your connection.{bcolors.ENDC}"
+)
await recall_func()
raise e
except InvalidRequestError as e:
-print(f"\n\n{bcolors.BOLD}{bcolors.FAIL}InvalidRequestError. Please check your request.{bcolors.ENDC}")
+print(
+f"\n\n{bcolors.BOLD}{bcolors.FAIL}InvalidRequestError. Please check your request.{bcolors.ENDC}"
+)
await recall_func()
raise e
except AuthenticationError as e:
-print(f"\n\n{bcolors.BOLD}{bcolors.FAIL}AuthenticationError. Please check your API key and if needed, also your organization ID.{bcolors.ENDC}")
+print(
+f"\n\n{bcolors.BOLD}{bcolors.FAIL}AuthenticationError. Please check your API key and if needed, also your organization ID.{bcolors.ENDC}"
+)
await recall_func("`AuthenticationError. Please check your API key.`")
raise e
except ServiceUnavailableError as e:
-print(f"\n\n{bcolors.BOLD}{bcolors.WARNING}ServiceUnavailableError. The OpenAI API is not responding. Retrying...{bcolors.ENDC}")
+print(
+f"\n\n{bcolors.BOLD}{bcolors.WARNING}ServiceUnavailableError. The OpenAI API is not responding. Retrying...{bcolors.ENDC}"
+)
await recall_func("`The OpenAI API is not responding. Retrying...`")
await asyncio.sleep(10)
await recall_func()
i += 1
finally:
if i == 10:
-print(f"\n\n{bcolors.BOLD}{bcolors.FAIL}OpenAI API is not responding. Please try again later.{bcolors.ENDC}")
-raise TimeoutError("OpenAI API is not responding. Please try again later.")
+print(
+f"\n\n{bcolors.BOLD}{bcolors.FAIL}OpenAI API is not responding. Please try again later.{bcolors.ENDC}"
+)
+raise TimeoutError(
+"OpenAI API is not responding. Please try again later."
+)
return response


##testing
if __name__ == "__main__":

async def main():
openai = openai_caller(api_key="sk-")
response = await openai.generate_response(
model="gpt-3.5-turbo",
-messages=[{"role":"user", "content":"ping"}],
+messages=[{"role": "user", "content": "ping"}],
max_tokens=5,
temperature=0.7,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
-stop=["\n", " Human:", " AI:"]
+stop=["\n", " Human:", " AI:"],
)
print(response)
asyncio.run(main())

@@ -1,12 +1,13 @@
-'''
+"""
This file's purpose is to count the number of tokens used by a list of messages.
It is used to check if the token limit of the model is reached.

Reference: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
-'''
+"""

import tiktoken


async def num_tokens_from_messages(messages, model="gpt-3.5-turbo"):
"""Returns the number of tokens used by a list of messages."""
try:

@@ -16,13 +17,17 @@ async def num_tokens_from_messages(messages, model="gpt-3.5-turbo"):
encoding = tiktoken.get_encoding("cl100k_base")

if model.startswith("gpt-3.5-turbo"):
-tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
+tokens_per_message = (
+4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
+)
tokens_per_name = -1 # if there's a name, the role is omitted
elif model.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
else:
-raise NotImplementedError(f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""")
+raise NotImplementedError(
+f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
+)
num_tokens = 0
for message in messages:
num_tokens += tokens_per_message

@@ -31,4 +36,4 @@ async def num_tokens_from_messages(messages, model="gpt-3.5-turbo"):
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens

@@ -13,7 +13,6 @@ except:
print("Google Vision API is not setup, please run /setup")



async def process(attachment):
try:
debug("Processing image...")