mirror of
https://github.com/Paillat-dev/Botator.git
synced 2026-01-02 01:06:19 +00:00
Added image recognition
This commit is contained in:
@@ -3,7 +3,7 @@ from config import debug, conn, c, moderate
|
|||||||
from discord import default_permissions
|
from discord import default_permissions
|
||||||
import openai
|
import openai
|
||||||
models = ["davinci", "chatGPT"]
|
models = ["davinci", "chatGPT"]
|
||||||
|
images_recognition = ["enable", "disable"]
|
||||||
class Settings (discord.Cog) :
|
class Settings (discord.Cog) :
|
||||||
def __init__(self, bot: discord.Bot) -> None:
|
def __init__(self, bot: discord.Bot) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -30,7 +30,7 @@ class Settings (discord.Cog) :
|
|||||||
await ctx.respond("You must enter at least one argument", ephemeral=True)
|
await ctx.respond("You must enter at least one argument", ephemeral=True)
|
||||||
return
|
return
|
||||||
#check if the user has entered valid arguments
|
#check if the user has entered valid arguments
|
||||||
if max_tokens is not None and (max_tokens < 1 or max_tokens > 2048):
|
if max_tokens is not None and (max_tokens < 1 or max_tokens > 4000):
|
||||||
await ctx.respond("Invalid max tokens", ephemeral=True)
|
await ctx.respond("Invalid max tokens", ephemeral=True)
|
||||||
return
|
return
|
||||||
if temperature is not None and (temperature < 0.0 or temperature > 1.0):
|
if temperature is not None and (temperature < 0.0 or temperature > 1.0):
|
||||||
@@ -226,4 +226,22 @@ class Settings (discord.Cog) :
|
|||||||
if data is None: c.execute("INSERT INTO model VALUES (?, ?)", (ctx.guild.id, model))
|
if data is None: c.execute("INSERT INTO model VALUES (?, ?)", (ctx.guild.id, model))
|
||||||
else: c.execute("UPDATE model SET model_name = ? WHERE guild_id = ?", (model, ctx.guild.id))
|
else: c.execute("UPDATE model SET model_name = ? WHERE guild_id = ?", (model, ctx.guild.id))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
await ctx.respond("Model changed !", ephemeral=True)
|
await ctx.respond("Model changed !", ephemeral=True)
|
||||||
|
|
||||||
|
async def images_recognition_autocomplete(ctx: discord.AutocompleteContext):
    """Return the image-recognition options that start with the text typed so far."""
    typed = ctx.value
    suggestions = []
    for option in images_recognition:
        if option.startswith(typed):
            suggestions.append(option)
    return suggestions
|
||||||
|
@discord.slash_command(name="images", description="Enable or disable images recognition")
@discord.option(name="enable_disable", description="Enable or disable images recognition", autocomplete=images_recognition_autocomplete)
@default_permissions(administrator=True)
async def images(self, ctx: discord.ApplicationContext, enable_disable: str):
    """Enable or disable image recognition for the guild the command was used in.

    Args:
        ctx: Context of the slash command invocation.
        enable_disable: Either "enable" or "disable".
    """
    # Autocomplete only *suggests* values; any free text can still be submitted.
    # Reject it here instead of writing an arbitrary string into is_enabled.
    if enable_disable not in ("enable", "disable"):
        await ctx.respond("Invalid argument, please use `enable` or `disable`", ephemeral=True)
        return
    is_enabled = 1 if enable_disable == "enable" else 0
    try:
        c.execute("SELECT * FROM images WHERE guild_id = ?", (ctx.guild.id,))
        data = c.fetchone()
    except Exception:
        # Best effort, as before: a failed lookup is treated as "no row yet".
        data = None
    if data is None:
        # First time this guild touches the setting: insert a row whose middle
        # value (the usage counter) starts at 0.
        c.execute("INSERT INTO images VALUES (?, ?, ?)", (ctx.guild.id, 0, is_enabled))
    else:
        c.execute("UPDATE images SET is_enabled = ? WHERE guild_id = ?", (is_enabled, ctx.guild.id))
    conn.commit()
    await ctx.respond("Images recognition has been " + ("enabled" if is_enabled == 1 else "disabled"), ephemeral=True)
|
||||||
@@ -10,6 +10,8 @@ webhook_url = os.getenv("WEBHOOK_URL")
|
|||||||
max_uses: int = 400
|
max_uses: int = 400
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "./../database/google-vision/botator.json"
|
||||||
|
|
||||||
def debug(message):
    """Emit *message* at INFO level through the module-level logging functions."""
    logging.log(logging.INFO, message)
|
||||||
conn = sqlite3.connect('../database/data.db')
|
conn = sqlite3.connect('../database/data.db')
|
||||||
@@ -49,5 +51,6 @@ else:
|
|||||||
print("Table already has the correct number of columns")
|
print("Table already has the correct number of columns")
|
||||||
pass
|
pass
|
||||||
c.execute('''CREATE TABLE IF NOT EXISTS model (guild_id text, model_name text)''')
|
c.execute('''CREATE TABLE IF NOT EXISTS model (guild_id text, model_name text)''')
|
||||||
|
c.execute('''CREATE TABLE IF NOT EXISTS images (guild_id text, usage_count integer, is_enabled boolean)''')
|
||||||
cp.execute('''CREATE TABLE IF NOT EXISTS data (user_id text, guild_id text, premium boolean)''')
|
cp.execute('''CREATE TABLE IF NOT EXISTS data (user_id text, guild_id text, premium boolean)''')
|
||||||
cp.execute('''CREATE TABLE IF NOT EXISTS channels (guild_id text, channel0 text, channel1 text, channel2 text, channel3 text, channel4 text)''')
|
cp.execute('''CREATE TABLE IF NOT EXISTS channels (guild_id text, channel0 text, channel1 text, channel2 text, channel3 text, channel4 text)''')
|
||||||
@@ -1,10 +1,12 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from config import c, max_uses, cp, conn, debug, moderate
|
from config import c, max_uses, cp, conn, debug, moderate
|
||||||
|
import vision_processing
|
||||||
import re
|
import re
|
||||||
import discord
|
import discord
|
||||||
import datetime
|
import datetime
|
||||||
import openai
|
import openai
|
||||||
import emoji # pip install emoji
|
import emoji # pip install emoji
|
||||||
|
import os
|
||||||
|
|
||||||
async def replace_mentions(content, bot):
|
async def replace_mentions(content, bot):
|
||||||
mentions = re.findall(r"<@!?\d+>", content)
|
mentions = re.findall(r"<@!?\d+>", content)
|
||||||
@@ -60,6 +62,7 @@ async def chat_process(self, message):
|
|||||||
tts = data[11]
|
tts = data[11]
|
||||||
pretend_to_be = data[12]
|
pretend_to_be = data[12]
|
||||||
pretend_enabled = data[13]
|
pretend_enabled = data[13]
|
||||||
|
images_limit_reached = False
|
||||||
try: cp.execute("SELECT * FROM data WHERE guild_id = ?", (message.guild.id,))
|
try: cp.execute("SELECT * FROM data WHERE guild_id = ?", (message.guild.id,))
|
||||||
except: pass
|
except: pass
|
||||||
try:
|
try:
|
||||||
@@ -70,8 +73,16 @@ async def chat_process(self, message):
|
|||||||
try: premium = cp.fetchone()[2] # get the premium status of the guild
|
try: premium = cp.fetchone()[2] # get the premium status of the guild
|
||||||
except: premium = 0 # if the guild is not in the database, it's not premium
|
except: premium = 0 # if the guild is not in the database, it's not premium
|
||||||
|
|
||||||
|
try:
|
||||||
|
c.execute("SELECT * FROM images WHERE guild_id = ?", (message.guild.id,)) # get the images setting in the database
|
||||||
|
data = c.fetchone()
|
||||||
|
except:
|
||||||
|
data = None
|
||||||
|
if data is None: data = [message.guild.id, 0, 0]
|
||||||
|
images_usage = data[1]
|
||||||
|
images_enabled = data[2]
|
||||||
channels = []
|
channels = []
|
||||||
|
if message.guild.id == 1050769643180146749: images_usage = 0 # if the guild is the support server, we set the images usage to 0, so the bot can be used as much as possible
|
||||||
try:
|
try:
|
||||||
cp.execute("SELECT * FROM channels WHERE guild_id = ?", (message.guild.id,))
|
cp.execute("SELECT * FROM channels WHERE guild_id = ?", (message.guild.id,))
|
||||||
data = cp.fetchone()
|
data = cp.fetchone()
|
||||||
@@ -160,12 +171,35 @@ async def chat_process(self, message):
|
|||||||
name = msg.author.name
|
name = msg.author.name
|
||||||
#the name should match '^[a-zA-Z0-9_-]{1,64}$', so we need to remove any special characters
|
#the name should match '^[a-zA-Z0-9_-]{1,64}$', so we need to remove any special characters
|
||||||
name = re.sub(r"[^a-zA-Z0-9_-]", "", name)
|
name = re.sub(r"[^a-zA-Z0-9_-]", "", name)
|
||||||
if False: # This is a placeholder for a new feature that will be added soon
|
if False: # GPT-4 images
|
||||||
input_content = [content]
|
input_content = [content]
|
||||||
for attachment in msg.attachments:
|
for attachment in msg.attachments:
|
||||||
image_bytes = await attachment.read()
|
image_bytes = await attachment.read()
|
||||||
input_content.append({"image": image_bytes})
|
input_content.append({"image": image_bytes})
|
||||||
msgs.append({"role": role, "content": input_content, "name": name})
|
msgs.append({"role": role, "content": input_content, "name": name})
|
||||||
|
#if there is an attachment, we add it to the message
|
||||||
|
if len(msg.attachments) > 0 and role == "user" and images_enabled == 1:
|
||||||
|
for attachment in msg.attachments:
|
||||||
|
if images_usage >= 6 and premium == 0: images_limit_reached = True
|
||||||
|
elif images_usage >= 30 and premium == 1: images_limit_reached = True
|
||||||
|
if attachment.url.endswith((".png", ".jpg", ".jpeg", ".gif")) and images_limit_reached == False and os.path.exists(f"./../database/google-vision/results/{attachment.id}.txt") == False:
|
||||||
|
images_usage += 1
|
||||||
|
analysis = await vision_processing.process(attachment)
|
||||||
|
if analysis != None:
|
||||||
|
content = f"{content} \n\n {analysis}"
|
||||||
|
msgs.append({"role": role, "content": f"{content}", "name": name})
|
||||||
|
#if the attachment is still an image, we can check if there's a file called ./../database/google-vision/results/{attachment.id}.txt, if there is, we add the content of the file to the message
|
||||||
|
elif attachment.url.endswith((".png", ".jpg", ".jpeg", ".gif")) and os.path.exists(f"./../database/google-vision/results/{attachment.id}.txt") == True:
|
||||||
|
try:
|
||||||
|
with open(f"./../database/google-vision/results/{attachment.id}.txt", "r") as f:
|
||||||
|
content = f"{content} \n\n {f.read()}"
|
||||||
|
f.close()
|
||||||
|
msgs.append({"role": role, "content": f"{content}", "name": name})
|
||||||
|
except:
|
||||||
|
msgs.append({"role": role, "content": f"{content}", "name": name})
|
||||||
|
else:
|
||||||
|
msgs.append({"role": role, "content": f"{content}", "name": name})
|
||||||
|
c.execute("UPDATE images SET usage_count = ? WHERE guild_id = ?", (images_usage, message.guild.id))
|
||||||
else:
|
else:
|
||||||
msgs.append({"role": role, "content": f"{content}", "name": name})
|
msgs.append({"role": role, "content": f"{content}", "name": name})
|
||||||
# 2 easter eggs
|
# 2 easter eggs
|
||||||
@@ -188,11 +222,13 @@ async def chat_process(self, message):
|
|||||||
frequency_penalty=0,
|
frequency_penalty=0,
|
||||||
presence_penalty=0,
|
presence_penalty=0,
|
||||||
messages=msgs,
|
messages=msgs,
|
||||||
|
max_tokens=512, # max tokens is 4000, that's a lot of text! (the max tokens is 2048 for the davinci model)
|
||||||
)
|
)
|
||||||
should_break = True
|
should_break = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
should_break = False
|
should_break = False
|
||||||
await message.channel.send(f"```diff\n-Error: OpenAI API ERROR.\n\n{e}```", delete_after=5)
|
await message.channel.send(f"```diff\n-Error: OpenAI API ERROR.\n\n{e}```", delete_after=5)
|
||||||
|
raise e
|
||||||
break
|
break
|
||||||
#if the ai said "as an ai language model..." we continue the loop" (this is a bug in the chatgpt model)
|
#if the ai said "as an ai language model..." we continue the loop" (this is a bug in the chatgpt model)
|
||||||
if response.choices[0].message.content.lower().find("as an ai language model") != -1:
|
if response.choices[0].message.content.lower().find("as an ai language model") != -1:
|
||||||
@@ -203,7 +239,8 @@ async def chat_process(self, message):
|
|||||||
if should_break: break
|
if should_break: break
|
||||||
await asyncio.sleep(5)
|
await asyncio.sleep(5)
|
||||||
response = response.choices[0].message.content
|
response = response.choices[0].message.content
|
||||||
|
if images_limit_reached == True:
|
||||||
|
await message.channel.send(f"```diff\n-Warning: You have reached the image limit for this server. You can upgrade to premium to get more images recognized. More info in our server: https://discord.gg/sxjHtmqrbf```", delete_after=10)
|
||||||
#-----------------------------------------Davinci------------------------------------------------------------------------------------------
|
#-----------------------------------------Davinci------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@@ -250,7 +287,13 @@ async def chat_process(self, message):
|
|||||||
else: tts = False
|
else: tts = False
|
||||||
emojis, string = await extract_emoji(response)
|
emojis, string = await extract_emoji(response)
|
||||||
debug(f"Emojis: {emojis}")
|
debug(f"Emojis: {emojis}")
|
||||||
await message.channel.send(string, tts=tts)
|
if len(string) < 1996:
|
||||||
|
await message.channel.send(string, tts=tts)
|
||||||
|
else:
|
||||||
|
while len(string) > 1996:
|
||||||
|
send_string = string[:1996]
|
||||||
|
string = string[1996:]
|
||||||
|
await message.channel.send(send_string, tts=tts)
|
||||||
for emoji in emojis:
|
for emoji in emojis:
|
||||||
#if the emoji is longer than 1 character, it's a custom emoji
|
#if the emoji is longer than 1 character, it's a custom emoji
|
||||||
try:
|
try:
|
||||||
@@ -265,4 +308,4 @@ async def chat_process(self, message):
|
|||||||
await message.add_reaction(emoji)
|
await message.add_reaction(emoji)
|
||||||
except : pass
|
except : pass
|
||||||
else:
|
else:
|
||||||
await message.channel.send("The AI is not sure what to say (the response was empty)")
|
await message.channel.send("The AI is not sure what to say (the response was empty)")
|
||||||
@@ -19,4 +19,8 @@ It works also with custom emojis, for example: "+<:emoji_name:emoji_id> Hello!".
|
|||||||
|
|
||||||
You often use reactions to express your feelings. With the syntax explained above.
|
You often use reactions to express your feelings. With the syntax explained above.
|
||||||
|
|
||||||
|
When an image is sent by the user, some keywords are extracted from the image, and sent in an image block, for example: <image Labels: label1, label2 \n Text: Some text in the image \n Objects: object1, object2 !image>.
|
||||||
|
|
||||||
|
You will NEVER talk about these labels when answering the user, but rather YOU WILL ANSWER THE USER AS IF YOU HAD SEEN THE IMAGE NORMALLY.
|
||||||
|
|
||||||
For every response, give the answers from Botator's perspective. This includes physical appearance, and occupation.
|
For every response, give the answers from Botator's perspective. This includes physical appearance, and occupation.
|
||||||
|
|||||||
58
code/test-google-vision.py
Normal file
58
code/test-google-vision.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import io
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
# Imports the Google Cloud client library
|
||||||
|
from google.cloud import vision
|
||||||
|
|
||||||
|
#we set the env variable GOOGLE_APPLICATION_CREDENTIALS to the path of the json file
|
||||||
|
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "./../database/google-vision/botator-vision-8cd1030a7541.json"
|
||||||
|
# Instantiates a client
|
||||||
|
client = vision.ImageAnnotatorClient()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# The name of the image file to annotate
|
||||||
|
file_name = os.path.abspath('./../database/google-vision/label.jpg')
|
||||||
|
print(file_name)
|
||||||
|
# Loads the image into memory
|
||||||
|
with io.open(file_name, 'rb') as image_file:
|
||||||
|
content = image_file.read()
|
||||||
|
|
||||||
|
image = vision.Image(content=content)
|
||||||
|
|
||||||
|
# Performs label detection on the image file
|
||||||
|
#response = client.label_detection(image=image)
|
||||||
|
#labels = response.label_annotations
|
||||||
|
|
||||||
|
#print('Labels:')
|
||||||
|
#for label in labels:
|
||||||
|
# print(label.description)
|
||||||
|
|
||||||
|
async def get_labels(image):
    """Run label detection on *image* and return the response's label annotations."""
    return client.label_detection(image=image).label_annotations
|
||||||
|
|
||||||
|
async def get_text(image):
    """Run text detection on *image* and return the response's text annotations."""
    return client.text_detection(image=image).text_annotations
|
||||||
|
|
||||||
|
#now we print the labels
|
||||||
|
async def main():
    """Print the labels, then the texts, detected in the module-level image."""
    # Each section is fetched first and only then announced, in this order.
    for heading, detect in (('Labels:', get_labels), ('Texts:', get_text)):
        annotations = await detect(image)
        print(heading)
        for annotation in annotations:
            print(annotation.description)
|
||||||
|
|
||||||
|
#now we run the main function
|
||||||
|
if __name__ == '__main__':
    # asyncio.run() creates a fresh event loop, runs main() to completion and
    # closes the loop afterwards. Calling asyncio.get_event_loop() when no loop
    # is running is deprecated in recent Python versions.
    asyncio.run(main())
|
||||||
46
code/vision_processing.py
Normal file
46
code/vision_processing.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import io
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
from config import debug
|
||||||
|
# Imports the Google Cloud client library
|
||||||
|
from google.cloud import vision
|
||||||
|
|
||||||
|
# Instantiates a client
|
||||||
|
client = vision.ImageAnnotatorClient()
|
||||||
|
|
||||||
|
async def process(attachment):
    """Describe an image attachment with Google Cloud Vision and cache the result.

    Runs label detection, text detection and object localization on the
    attachment's URL, formats the findings as an ``<image ... !image>`` block,
    writes that block to
    ``./../database/google-vision/results/{attachment.id}.txt`` and returns it.

    Args:
        attachment: Object exposing ``url`` and ``id`` (presumably a
            ``discord.Attachment`` - confirm against the caller).

    Returns:
        The formatted description block as a string. Sections with no
        findings are left out entirely.
    """
    debug("Processing image...")
    image = vision.Image()
    image.source.image_uri = attachment.url
    # NOTE(review): these client calls are not awaited, so they run
    # synchronously inside this coroutine.
    labels = client.label_detection(image=image)
    texts = client.text_detection(image=image)
    objects = client.object_localization(image=image)
    labels = labels.label_annotations
    texts = texts.text_annotations
    objects = objects.localized_object_annotations
    # Keep the prompt short: at most the first 2 labels and the first 7 objects.
    labels = labels[:2]
    objects = objects[:7]

    # Build each section only when it has content. The previous version trimmed
    # a trailing ", " unconditionally, which chopped characters off the header
    # or the preceding section whenever a list was empty.
    sections = []
    if len(labels) > 0:
        sections.append("Labels:\n" + ", ".join(label.description for label in labels) + "\n")
    if len(texts) > 0:
        # The first text annotation holds the whole detected text.
        sections.append("Text:\n" + texts[0].description + "\n")
    if len(objects) > 0:
        sections.append("Objects:\n" + ", ".join(obj.name for obj in objects) + "\n")
    final = "<image\n" + "".join(sections) + "!image>"

    # Cache the result per attachment id, creating the results folder
    # (and any missing parent) on first use.
    os.makedirs("./../database/google-vision/results", exist_ok=True)
    with open(f"./../database/google-vision/results/{attachment.id}.txt", "w", encoding="utf-8") as f:
        f.write(final)

    return final
|
||||||
@@ -4,4 +4,5 @@ openai
|
|||||||
apsw
|
apsw
|
||||||
google-api-python-client
|
google-api-python-client
|
||||||
python-dotenv
|
python-dotenv
|
||||||
emoji
|
emoji
|
||||||
|
google-cloud-vision
|
||||||
Reference in New Issue
Block a user