a big refactor

Kaveen Kumarasinghe 1 year ago
parent a6f6dcb74d
commit ba3ea3231b

@ -1,7 +1,7 @@
import discord import discord
from pycord.multicog import add_to_group from pycord.multicog import add_to_group
from models.env_service_model import EnvService from services.environment_service import EnvService
from models.check_model import Check from models.check_model import Check
from models.autocomplete_model import Settings_autocompleter, File_autocompleter from models.autocomplete_model import Settings_autocompleter, File_autocompleter

@ -0,0 +1,101 @@
import asyncio
import os
import tempfile
import traceback
from io import BytesIO
import aiohttp
import discord
from PIL import Image
# We don't use the converser cog here because we want to be able to redo for the last images and text prompts at the same time
from sqlitedict import SqliteDict
from cogs.text_service_cog import GPT3ComCon
from services.environment_service import EnvService
from models.user_model import RedoUser
from services.image_service import ImageService
from services.text_service import TextService
users_to_interactions = {}
ALLOWED_GUILDS = EnvService.get_allowed_guilds()
USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys()
USER_KEY_DB = None
if USER_INPUT_API_KEYS:
USER_KEY_DB = SqliteDict("user_key_db.sqlite")
class DrawDallEService(discord.Cog, name="DrawDallEService"):
def __init__(
self, bot, usage_service, model, message_queue, deletion_queue, converser_cog
):
super().__init__()
self.bot = bot
self.usage_service = usage_service
self.model = model
self.message_queue = message_queue
self.deletion_queue = deletion_queue
self.converser_cog = converser_cog
print("Draw service initialized")
self.redo_users = {}
async def draw_command(self, ctx: discord.ApplicationContext, prompt: str):
user_api_key = None
if USER_INPUT_API_KEYS:
user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx, USER_KEY_DB)
if not user_api_key:
return
await ctx.defer()
user = ctx.user
if user == self.bot.user:
return
try:
asyncio.ensure_future(
ImageService.encapsulated_send(
self,
user.id, prompt, ctx, custom_api_key=user_api_key
)
)
except Exception as e:
print(e)
traceback.print_exc()
await ctx.respond("Something went wrong. Please try again later.")
await ctx.send_followup(e)
async def local_size_command(self, ctx: discord.ApplicationContext):
await ctx.defer()
# Get the size of the dall-e images folder that we have on the current system.
image_path = self.model.IMAGE_SAVE_PATH
total_size = 0
for dirpath, dirnames, filenames in os.walk(image_path):
for f in filenames:
fp = os.path.join(dirpath, f)
total_size += os.path.getsize(fp)
# Format the size to be in MB and send.
total_size = total_size / 1000000
await ctx.respond(f"The size of the local images folder is {total_size} MB.")
async def clear_local_command(self, ctx):
await ctx.defer()
# Delete all the local images in the images folder.
image_path = self.model.IMAGE_SAVE_PATH
for dirpath, dirnames, filenames in os.walk(image_path):
for f in filenames:
try:
fp = os.path.join(dirpath, f)
os.remove(fp)
except Exception as e:
print(e)
await ctx.respond("Local images cleared.")

@ -4,10 +4,11 @@ import traceback
import discord import discord
from sqlitedict import SqliteDict from sqlitedict import SqliteDict
from cogs.gpt_3_commands_and_converser import GPT3ComCon from services.environment_service import EnvService
from models.env_service_model import EnvService
from models.user_model import RedoUser from models.user_model import RedoUser
from pycord.multicog import add_to_group from services.image_service import ImageService
from services.text_service import TextService
ALLOWED_GUILDS = EnvService.get_allowed_guilds() ALLOWED_GUILDS = EnvService.get_allowed_guilds()
USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys() USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys()
@ -55,7 +56,7 @@ class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"):
async def optimize_command(self, ctx: discord.ApplicationContext, prompt: str): async def optimize_command(self, ctx: discord.ApplicationContext, prompt: str):
user_api_key = None user_api_key = None
if USER_INPUT_API_KEYS: if USER_INPUT_API_KEYS:
user_api_key = await GPT3ComCon.get_user_api_key(ctx.user.id, ctx) user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx)
if not user_api_key: if not user_api_key:
return return
@ -77,7 +78,7 @@ class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"):
try: try:
response = await self.model.send_request( response = await self.model.send_request(
final_prompt, final_prompt,
tokens=70, tokens=60,
top_p_override=1.0, top_p_override=1.0,
temp_override=0.9, temp_override=0.9,
presence_penalty_override=0.5, presence_penalty_override=0.5,
@ -217,7 +218,8 @@ class DrawButton(discord.ui.Button["OptimizeView"]):
prompt = re.sub(r"Optimized Prompt: ?", "", prompt) prompt = re.sub(r"Optimized Prompt: ?", "", prompt)
# Call the image service cog to draw the image # Call the image service cog to draw the image
await self.image_service_cog.encapsulated_send( await ImageService.encapsulated_send(
self.image_service_cog,
user_id, user_id,
prompt, prompt,
interaction, interaction,
@ -255,7 +257,8 @@ class RedoButton(discord.ui.Button["OptimizeView"]):
msg = await interaction.response.send_message( msg = await interaction.response.send_message(
"Redoing your original request...", ephemeral=True, delete_after=20 "Redoing your original request...", ephemeral=True, delete_after=20
) )
await self.converser_cog.encapsulated_send( await TextService.encapsulated_send(
self.converser_cog,
id=user_id, id=user_id,
prompt=prompt, prompt=prompt,
ctx=ctx, ctx=ctx,

@ -1,6 +1,5 @@
import asyncio import asyncio
import datetime import datetime
import json
import re import re
import traceback import traceback
import sys import sys
@ -10,22 +9,17 @@ from pathlib import Path
import aiofiles import aiofiles
import json import json
import aiohttp
import discord import discord
from discord.ext import pages
from pycord.multicog import add_to_group from services.environment_service import EnvService
from services.message_queue_service import Message
from models.deletion_service_model import Deletion from services.moderations_service import Moderation
from models.env_service_model import EnvService from models.user_model import Thread, EmbeddedConversationItem
from models.message_model import Message
from models.moderations_service_model import Moderation
from models.openai_model import Model
from models.user_model import RedoUser, Thread, EmbeddedConversationItem
from models.check_model import Check
from models.autocomplete_model import Settings_autocompleter, File_autocompleter
from collections import defaultdict from collections import defaultdict
from sqlitedict import SqliteDict from sqlitedict import SqliteDict
from services.text_service import SetupModal, TextService
original_message = {} original_message = {}
ALLOWED_GUILDS = EnvService.get_allowed_guilds() ALLOWED_GUILDS = EnvService.get_allowed_guilds()
if sys.platform == "win32": if sys.platform == "win32":
@ -169,21 +163,6 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
self.message_queue = message_queue self.message_queue = message_queue
self.conversation_thread_owners = {} self.conversation_thread_owners = {}
@staticmethod
async def get_user_api_key(user_id, ctx):
user_api_key = None if user_id not in USER_KEY_DB else USER_KEY_DB[user_id]
if user_api_key is None or user_api_key == "":
modal = SetupModal(title="API Key Setup")
if isinstance(ctx, discord.ApplicationContext):
await ctx.send_modal(modal)
await ctx.send_followup(
"You must set up your API key before using this command."
)
else:
await ctx.reply(
"You must set up your API key before typing in a GPT3 powered channel, type `/setup` to enter your API key."
)
return user_api_key
async def load_file(self, file, ctx): async def load_file(self, file, ctx):
try: try:
@ -544,49 +523,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
Moderation(after, timestamp) Moderation(after, timestamp)
) )
if after.author.id in self.redo_users: await TextService.process_conversation_edit(self, after, original_message)
if after.id == original_message[after.author.id]:
response_message = self.redo_users[after.author.id].response
ctx = self.redo_users[after.author.id].ctx
await response_message.edit(content="Redoing prompt 🔄...")
edited_content = await self.mention_to_username(after, after.content)
if after.channel.id in self.conversation_threads:
# Remove the last two elements from the history array and add the new <username>: prompt
self.conversation_threads[
after.channel.id
].history = self.conversation_threads[after.channel.id].history[:-2]
pinecone_dont_reinsert = None
if not self.pinecone_service:
self.conversation_threads[after.channel.id].history.append(
EmbeddedConversationItem(
f"\n{after.author.display_name}: {after.content}<|endofstatement|>\n",
0,
)
)
self.conversation_threads[after.channel.id].count += 1
overrides = self.conversation_threads[after.channel.id].get_overrides()
await self.encapsulated_send(
id=after.channel.id,
prompt=edited_content,
ctx=ctx,
response_message=response_message,
temp_override=overrides["temperature"],
top_p_override=overrides["top_p"],
frequency_penalty_override=overrides["frequency_penalty"],
presence_penalty_override=overrides["presence_penalty"],
model=self.conversation_threads[after.channel.id].model,
edited_request=True,
)
if not self.pinecone_service:
self.redo_users[after.author.id].prompt = edited_content
async def check_and_launch_moderations(self, guild_id, alert_channel_override=None): async def check_and_launch_moderations(self, guild_id, alert_channel_override=None):
# Create the moderations service. # Create the moderations service.
print("Checking and attempting to launch moderations service...") print("Checking and attempting to launch moderations service...")
@ -630,114 +567,10 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
Moderation(message, timestamp) Moderation(message, timestamp)
) )
conversing = self.check_conversing(
message.author.id, message.channel.id, content
)
# If the user is conversing and they want to end it, end it immediately before we continue any further.
if conversing and message.content.lower() in self.END_PROMPTS:
await self.end_conversation(message)
return
if conversing:
user_api_key = None
if USER_INPUT_API_KEYS:
user_api_key = await GPT3ComCon.get_user_api_key(
message.author.id, message
)
if not user_api_key:
return
prompt = await self.mention_to_username(message, content)
await self.check_conversation_limit(message)
# If the user is in a conversation thread
if message.channel.id in self.conversation_threads:
# Since this is async, we don't want to allow the user to send another prompt while a conversation
# prompt is processing, that'll mess up the conversation history!
if message.author.id in self.awaiting_responses:
message = await message.reply(
"You are already waiting for a response from GPT3. Please wait for it to respond before sending another message."
)
# get the current date, add 10 seconds to it, and then turn it into a timestamp.
# we need to use our deletion service because this isn't an interaction, it's a regular message.
deletion_time = datetime.datetime.now() + datetime.timedelta(
seconds=10
)
deletion_time = deletion_time.timestamp()
deletion_message = Deletion(message, deletion_time)
await self.deletion_queue.put(deletion_message)
return
if message.channel.id in self.awaiting_thread_responses:
message = await message.reply(
"This thread is already waiting for a response from GPT3. Please wait for it to respond before sending another message."
)
# get the current date, add 10 seconds to it, and then turn it into a timestamp.
# we need to use our deletion service because this isn't an interaction, it's a regular message.
deletion_time = datetime.datetime.now() + datetime.timedelta(
seconds=10
)
deletion_time = deletion_time.timestamp()
deletion_message = Deletion(message, deletion_time)
await self.deletion_queue.put(deletion_message)
return # Process the message if the user is in a conversation
if await TextService.process_conversation_message(self, message, USER_INPUT_API_KEYS, USER_KEY_DB):
self.awaiting_responses.append(message.author.id) original_message[message.author.id] = message.id
self.awaiting_thread_responses.append(message.channel.id)
original_message[message.author.id] = message.id
if not self.pinecone_service:
self.conversation_threads[message.channel.id].history.append(
EmbeddedConversationItem(
f"\n'{message.author.display_name}': {prompt} <|endofstatement|>\n",
0,
)
)
# increment the conversation counter for the user
self.conversation_threads[message.channel.id].count += 1
# Send the request to the model
# If conversing, the prompt to send is the history, otherwise, it's just the prompt
if (
self.pinecone_service
or message.channel.id not in self.conversation_threads
):
primary_prompt = prompt
else:
primary_prompt = "".join(
[
item.text
for item in self.conversation_threads[
message.channel.id
].history
]
)
# set conversation overrides
overrides = self.conversation_threads[message.channel.id].get_overrides()
await self.encapsulated_send(
message.channel.id,
primary_prompt,
message,
temp_override=overrides["temperature"],
top_p_override=overrides["top_p"],
frequency_penalty_override=overrides["frequency_penalty"],
presence_penalty_override=overrides["presence_penalty"],
model=self.conversation_threads[message.channel.id].model,
custom_api_key=user_api_key,
)
def cleanse_response(self, response_text): def cleanse_response(self, response_text):
response_text = response_text.replace("GPTie:\n", "") response_text = response_text.replace("GPTie:\n", "")
@ -767,433 +600,6 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
pass pass
return message return message
# ctx can be of type AppContext(interaction) or Message
async def encapsulated_send(
self,
id,
prompt,
ctx,
response_message=None,
temp_override=None,
top_p_override=None,
frequency_penalty_override=None,
presence_penalty_override=None,
from_ask_command=False,
instruction=None,
from_edit_command=False,
codex=False,
model=None,
custom_api_key=None,
edited_request=False,
redo_request=False,
):
new_prompt = (
prompt + "\nGPTie: "
if not from_ask_command and not from_edit_command
else prompt
)
from_context = isinstance(ctx, discord.ApplicationContext)
if not instruction:
tokens = self.usage_service.count_tokens(new_prompt)
else:
tokens = self.usage_service.count_tokens(
new_prompt
) + self.usage_service.count_tokens(instruction)
try:
# Pinecone is enabled, we will create embeddings for this conversation.
if self.pinecone_service and ctx.channel.id in self.conversation_threads:
# Delete "GPTie: <|endofstatement|>" from the user's conversation history if it exists
# check if the text attribute for any object inside self.conversation_threads[converation_id].history
# contains ""GPTie: <|endofstatement|>"", if so, delete
for item in self.conversation_threads[ctx.channel.id].history:
if item.text.strip() == "GPTie:<|endofstatement|>":
self.conversation_threads[ctx.channel.id].history.remove(item)
# The conversation_id is the id of the thread
conversation_id = ctx.channel.id
# Create an embedding and timestamp for the prompt
new_prompt = prompt.encode("ascii", "ignore").decode()
prompt_less_author = f"{new_prompt} <|endofstatement|>\n"
user_displayname = ctx.author.display_name
new_prompt = (
f"\n'{user_displayname}': {new_prompt} <|endofstatement|>\n"
)
new_prompt = new_prompt.encode("ascii", "ignore").decode()
timestamp = int(
str(datetime.datetime.now().timestamp()).replace(".", "")
)
new_prompt_item = EmbeddedConversationItem(new_prompt, timestamp)
if not redo_request:
self.conversation_threads[conversation_id].history.append(
new_prompt_item
)
if edited_request:
new_prompt = "".join(
[
item.text
for item in self.conversation_threads[
ctx.channel.id
].history
]
)
self.redo_users[ctx.author.id].prompt = new_prompt
else:
# Create and upsert the embedding for the conversation id, prompt, timestamp
await self.pinecone_service.upsert_conversation_embedding(
self.model,
conversation_id,
new_prompt,
timestamp,
custom_api_key=custom_api_key,
)
embedding_prompt_less_author = await self.model.send_embedding_request(
prompt_less_author, custom_api_key=custom_api_key
) # Use the version of the prompt without the author's name for better clarity on retrieval.
# Now, build the new prompt by getting the X most similar with pinecone
similar_prompts = self.pinecone_service.get_n_similar(
conversation_id,
embedding_prompt_less_author,
n=self.model.num_conversation_lookback,
)
# When we are in embeddings mode, only the pre-text is contained in self.conversation_threads[message.channel.id].history, so we
# can use that as a base to build our new prompt
prompt_with_history = [
self.conversation_threads[ctx.channel.id].history[0]
]
# Append the similar prompts to the prompt with history
prompt_with_history += [
EmbeddedConversationItem(prompt, timestamp)
for prompt, timestamp in similar_prompts
]
# iterate UP TO the last X prompts in the history
for i in range(
1,
min(
len(self.conversation_threads[ctx.channel.id].history),
self.model.num_static_conversation_items,
),
):
prompt_with_history.append(
self.conversation_threads[ctx.channel.id].history[-i]
)
# remove duplicates from prompt_with_history and set the conversation history
prompt_with_history = list(dict.fromkeys(prompt_with_history))
self.conversation_threads[
ctx.channel.id
].history = prompt_with_history
# Sort the prompt_with_history by increasing timestamp if pinecone is enabled
if self.pinecone_service:
prompt_with_history.sort(key=lambda x: x.timestamp)
# Ensure that the last prompt in this list is the prompt we just sent (new_prompt_item)
if prompt_with_history[-1] != new_prompt_item:
try:
prompt_with_history.remove(new_prompt_item)
except ValueError:
pass
prompt_with_history.append(new_prompt_item)
prompt_with_history = "".join(
[item.text for item in prompt_with_history]
)
new_prompt = prompt_with_history + "\nGPTie: "
tokens = self.usage_service.count_tokens(new_prompt)
# No pinecone, we do conversation summarization for long term memory instead
elif (
id in self.conversation_threads
and tokens > self.model.summarize_threshold
and not from_ask_command
and not from_edit_command
and not self.pinecone_service # This should only happen if we are not doing summarizations.
):
# We don't need to worry about the differences between interactions and messages in this block,
# because if we are in this block, we can only be using a message object for ctx
if self.model.summarize_conversations:
await ctx.reply(
"I'm currently summarizing our current conversation so we can keep chatting, "
"give me one moment!"
)
await self.summarize_conversation(ctx, new_prompt)
# Check again if the prompt is about to go past the token limit
new_prompt = (
"".join(
[
item.text
for item in self.conversation_threads[id].history
]
)
+ "\nGPTie: "
)
tokens = self.usage_service.count_tokens(new_prompt)
if (
tokens > self.model.summarize_threshold - 150
): # 150 is a buffer for the second stage
await ctx.reply(
"I tried to summarize our current conversation so we could keep chatting, "
"but it still went over the token "
"limit. Please try again later."
)
await self.end_conversation(ctx)
return
else:
await ctx.reply("The conversation context limit has been reached.")
await self.end_conversation(ctx)
return
# Send the request to the model
if from_edit_command:
response = await self.model.send_edit_request(
input=new_prompt,
instruction=instruction,
temp_override=temp_override,
top_p_override=top_p_override,
codex=codex,
custom_api_key=custom_api_key,
)
else:
response = await self.model.send_request(
new_prompt,
tokens=tokens,
temp_override=temp_override,
top_p_override=top_p_override,
frequency_penalty_override=frequency_penalty_override,
presence_penalty_override=presence_penalty_override,
model=model,
custom_api_key=custom_api_key,
)
# Clean the request response
response_text = self.cleanse_response(str(response["choices"][0]["text"]))
if from_ask_command:
# Append the prompt to the beginning of the response, in italics, then a new line
response_text = response_text.strip()
response_text = f"***{prompt}***\n\n{response_text}"
elif from_edit_command:
if codex:
response_text = response_text.strip()
response_text = f"***Prompt: {prompt}***\n***Instruction: {instruction}***\n\n```\n{response_text}\n```"
else:
response_text = response_text.strip()
response_text = f"***Prompt: {prompt}***\n***Instruction: {instruction}***\n\n{response_text}\n"
# If gpt3 tries writing a user mention try to replace it with their name
response_text = await self.mention_to_username(ctx, response_text)
# If the user is conversing, add the GPT response to their conversation history.
if (
id in self.conversation_threads
and not from_ask_command
and not self.pinecone_service
):
if not redo_request:
self.conversation_threads[id].history.append(
EmbeddedConversationItem(
"\nGPTie: " + str(response_text) + "<|endofstatement|>\n", 0
)
)
# Embeddings case!
elif (
id in self.conversation_threads
and not from_ask_command
and not from_edit_command
and self.pinecone_service
):
conversation_id = id
# Create an embedding and timestamp for the prompt
response_text = (
"\nGPTie: " + str(response_text) + "<|endofstatement|>\n"
)
response_text = response_text.encode("ascii", "ignore").decode()
# Print the current timestamp
timestamp = int(
str(datetime.datetime.now().timestamp()).replace(".", "")
)
self.conversation_threads[conversation_id].history.append(
EmbeddedConversationItem(response_text, timestamp)
)
# Create and upsert the embedding for the conversation id, prompt, timestamp
embedding = await self.pinecone_service.upsert_conversation_embedding(
self.model,
conversation_id,
response_text,
timestamp,
custom_api_key=custom_api_key,
)
# Cleanse again
response_text = self.cleanse_response(response_text)
# escape any other mentions like @here or @everyone
response_text = discord.utils.escape_mentions(response_text)
# If we don't have a response message, we are not doing a redo, send as a new message(s)
if not response_message:
if len(response_text) > self.TEXT_CUTOFF:
if not from_context:
paginator = None
await self.paginate_and_send(response_text, ctx)
else:
embed_pages = await self.paginate_embed(response_text, codex, prompt, instruction)
view=ConversationView(ctx, self, ctx.channel.id, model, from_ask_command, from_edit_command, custom_api_key=custom_api_key)
paginator = pages.Paginator(pages=embed_pages, timeout=None, custom_view=view)
response_message = await paginator.respond(ctx.interaction)
else:
paginator = None
if not from_context:
response_message = await ctx.reply(
response_text,
view=ConversationView(
ctx,
self,
ctx.channel.id,
model,
custom_api_key=custom_api_key,
),
)
elif from_edit_command:
response_message = await ctx.respond(
response_text,
view=ConversationView(
ctx,
self,
ctx.channel.id,
model,
from_edit_command=from_edit_command,
custom_api_key=custom_api_key
),
)
else:
response_message = await ctx.respond(
response_text,
view=ConversationView(
ctx,
self,
ctx.channel.id,
model,
from_ask_command=from_ask_command,
custom_api_key=custom_api_key
),
)
if response_message:
# Get the actual message object of response_message in case it's an WebhookMessage
actual_response_message = (
response_message
if not from_context
else await ctx.fetch_message(response_message.id)
)
self.redo_users[ctx.author.id] = RedoUser(
prompt=new_prompt,
instruction=instruction,
ctx=ctx,
message=ctx,
response=actual_response_message,
codex=codex,
paginator=paginator
)
self.redo_users[ctx.author.id].add_interaction(
actual_response_message.id
)
# We are doing a redo, edit the message.
else:
paginator = self.redo_users.get(ctx.author.id).paginator
if isinstance(paginator, pages.Paginator):
embed_pages = await self.paginate_embed(response_text, codex, prompt, instruction)
view=ConversationView(ctx, self, ctx.channel.id, model, from_ask_command, from_edit_command, custom_api_key=custom_api_key)
await paginator.update(pages=embed_pages, custom_view=view)
elif len(response_text) > self.TEXT_CUTOFF:
if not from_context:
await response_message.channel.send("Over 2000 characters", delete_after=5)
else:
await response_message.edit(content=response_text)
await self.send_debug_message(
self.generate_debug_message(prompt, response), self.debug_channel
)
if ctx.author.id in self.awaiting_responses:
self.awaiting_responses.remove(ctx.author.id)
if not from_ask_command and not from_edit_command:
if ctx.channel.id in self.awaiting_thread_responses:
self.awaiting_thread_responses.remove(ctx.channel.id)
# Error catching for AIOHTTP Errors
except aiohttp.ClientResponseError as e:
message = (
f"The API returned an invalid response: **{e.status}: {e.message}**"
)
if from_context:
await ctx.send_followup(message)
else:
await ctx.reply(message)
self.remove_awaiting(
ctx.author.id, ctx.channel.id, from_ask_command, from_edit_command
)
# Error catching for OpenAI model value errors
except ValueError as e:
if from_context:
await ctx.send_followup(e)
else:
await ctx.reply(e)
self.remove_awaiting(
ctx.author.id, ctx.channel.id, from_ask_command, from_edit_command
)
# General catch case for everything
except Exception:
message = "Something went wrong, please try again later. This may be due to upstream issues on the API, or rate limiting."
await ctx.send_followup(message) if from_context else await ctx.reply(
message
)
self.remove_awaiting(
ctx.author.id, ctx.channel.id, from_ask_command, from_edit_command
)
traceback.print_exc()
try:
await self.end_conversation(ctx)
except:
pass
return
# COMMANDS # COMMANDS
async def help_command(self, ctx): async def help_command(self, ctx):
@ -1309,13 +715,14 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
user_api_key = None user_api_key = None
if USER_INPUT_API_KEYS: if USER_INPUT_API_KEYS:
user_api_key = await GPT3ComCon.get_user_api_key(user.id, ctx) user_api_key = await TextService.get_user_api_key(user.id, ctx, USER_KEY_DB)
if not user_api_key: if not user_api_key:
return return
await ctx.defer() await ctx.defer()
await self.encapsulated_send( await TextService.encapsulated_send(
self,
user.id, user.id,
prompt, prompt,
ctx, ctx,
@ -1349,7 +756,8 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
await ctx.defer() await ctx.defer()
await self.encapsulated_send( await TextService.encapsulated_send(
self,
user.id, user.id,
prompt=input, prompt=input,
ctx=ctx, ctx=ctx,
@ -1511,7 +919,8 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
self.conversation_threads[thread.id].count += 1 self.conversation_threads[thread.id].count += 1
await self.encapsulated_send( await TextService.encapsulated_send(
self,
thread.id, thread.id,
opener opener
if thread.id not in self.conversation_threads or self.pinecone_service if thread.id not in self.conversation_threads or self.pinecone_service
@ -1650,198 +1059,3 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
"alert_channel": self.get_moderated_alert_channel(guild_id), "alert_channel": self.get_moderated_alert_channel(guild_id),
} }
MOD_DB.commit() MOD_DB.commit()
# VIEWS AND MODALS
class ConversationView(discord.ui.View):
def __init__(
self,
ctx,
converser_cog,
id,
model,
from_ask_command=False,
from_edit_command=False,
custom_api_key=None,
):
super().__init__(timeout=3600) # 1 hour interval to redo.
self.converser_cog = converser_cog
self.ctx = ctx
self.model = model
self.from_ask_command = from_ask_command
self.from_edit_command = from_edit_command
self.custom_api_key = custom_api_key
self.add_item(
RedoButton(
self.converser_cog,
model=model,
from_ask_command=from_ask_command,
from_edit_command=from_edit_command,
custom_api_key=self.custom_api_key,
)
)
if id in self.converser_cog.conversation_threads:
self.add_item(EndConvoButton(self.converser_cog))
async def on_timeout(self):
# Remove the button from the view/message
self.clear_items()
# Send a message to the user saying the view has timed out
if self.message:
await self.message.edit(
view=None,
)
else:
await self.ctx.edit(
view=None,
)
class EndConvoButton(discord.ui.Button["ConversationView"]):
def __init__(self, converser_cog):
super().__init__(style=discord.ButtonStyle.danger, label="End Conversation", custom_id="conversation_end")
self.converser_cog = converser_cog
async def callback(self, interaction: discord.Interaction):
# Get the user
user_id = interaction.user.id
if (
user_id in self.converser_cog.conversation_thread_owners
and self.converser_cog.conversation_thread_owners[user_id]
== interaction.channel.id
):
try:
await self.converser_cog.end_conversation(
interaction, opener_user_id=interaction.user.id
)
except Exception as e:
print(e)
traceback.print_exc()
await interaction.response.send_message(
e, ephemeral=True, delete_after=30
)
pass
else:
await interaction.response.send_message(
"This is not your conversation to end!", ephemeral=True, delete_after=10
)
class RedoButton(discord.ui.Button["ConversationView"]):
def __init__(self, converser_cog, model, from_ask_command, from_edit_command, custom_api_key):
super().__init__(style=discord.ButtonStyle.danger, label="Retry", custom_id="conversation_redo")
self.converser_cog = converser_cog
self.model = model
self.from_ask_command = from_ask_command
self.from_edit_command = from_edit_command
self.custom_api_key = custom_api_key
async def callback(self, interaction: discord.Interaction):
# Get the user
user_id = interaction.user.id
if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[
user_id
].in_interaction(interaction.message.id):
# Get the message and the prompt and call encapsulated_send
prompt = self.converser_cog.redo_users[user_id].prompt
instruction = self.converser_cog.redo_users[user_id].instruction
ctx = self.converser_cog.redo_users[user_id].ctx
response_message = self.converser_cog.redo_users[user_id].response
codex = self.converser_cog.redo_users[user_id].codex
msg = await interaction.response.send_message(
"Retrying your original request...", ephemeral=True, delete_after=15
)
await self.converser_cog.encapsulated_send(
id=user_id,
prompt=prompt,
instruction=instruction,
ctx=ctx,
model=self.model,
response_message=response_message,
codex=codex,
custom_api_key=self.custom_api_key,
redo_request=True,
from_ask_command=self.from_ask_command,
from_edit_command=self.from_edit_command,
)
else:
await interaction.response.send_message(
"You can only redo the most recent prompt that you sent yourself.",
ephemeral=True,
delete_after=10,
)
class SetupModal(discord.ui.Modal):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.add_item(
discord.ui.InputText(
label="OpenAI API Key",
placeholder="sk--......",
)
)
async def callback(self, interaction: discord.Interaction):
user = interaction.user
api_key = self.children[0].value
# Validate that api_key is indeed in this format
if not re.match(r"sk-[a-zA-Z0-9]{32}", api_key):
await interaction.response.send_message(
"Your API key looks invalid, please check that it is correct before proceeding. Please run the /setup command to set your key.",
ephemeral=True,
delete_after=100,
)
else:
# We can save the key for the user to the database.
# Make a test request using the api key to ensure that it is valid.
try:
await Model.send_test_request(api_key)
await interaction.response.send_message(
"Your API key was successfully validated.",
ephemeral=True,
delete_after=10,
)
except aiohttp.ClientResponseError as e:
await interaction.response.send_message(
f"The API returned an invalid response: **{e.status}: {e.message}**",
ephemeral=True,
delete_after=30,
)
return
except Exception as e:
await interaction.response.send_message(
f"Your API key looks invalid, the API returned: {e}. Please check that your API key is correct before proceeding",
ephemeral=True,
delete_after=30,
)
return
# Save the key to the database
try:
USER_KEY_DB[user.id] = api_key
USER_KEY_DB.commit()
await interaction.followup.send(
"Your API key was successfully saved.",
ephemeral=True,
delete_after=10,
)
except Exception as e:
traceback.print_exc()
await interaction.followup.send(
"There was an error saving your API key.",
ephemeral=True,
delete_after=30,
)
return
pass

@ -8,22 +8,22 @@ import pinecone
from pycord.multicog import apply_multicog from pycord.multicog import apply_multicog
import os import os
from models.pinecone_service_model import PineconeService from services.pinecone_service import PineconeService
if sys.platform == "win32": if sys.platform == "win32":
separator = "\\" separator = "\\"
else: else:
separator = "/" separator = "/"
from cogs.draw_image_generation import DrawDallEService from cogs.image_service_cog import DrawDallEService
from cogs.gpt_3_commands_and_converser import GPT3ComCon from cogs.text_service_cog import GPT3ComCon
from cogs.image_prompt_optimizer import ImgPromptOptimizer from cogs.prompt_optimizer_cog import ImgPromptOptimizer
from cogs.commands import Commands from cogs.commands import Commands
from models.deletion_service_model import Deletion from services.deletion_service import Deletion
from models.message_model import Message from services.message_queue_service import Message
from models.openai_model import Model from models.openai_model import Model
from models.usage_service_model import UsageService from services.usage_service import UsageService
from models.env_service_model import EnvService from services.environment_service import EnvService
__version__ = "6.0" __version__ = "6.0"

@ -3,9 +3,9 @@ import os
import re import re
import discord import discord
from models.usage_service_model import UsageService from services.usage_service import UsageService
from models.openai_model import Model from models.openai_model import Model
from models.env_service_model import EnvService from services.environment_service import EnvService
usage_service = UsageService(Path(os.environ.get("DATA_DIR", os.getcwd()))) usage_service = UsageService(Path(os.environ.get("DATA_DIR", os.getcwd())))
model = Model(usage_service) model = Model(usage_service)

@ -1,6 +1,6 @@
import discord import discord
from models.env_service_model import EnvService from services.environment_service import EnvService
from typing import Callable from typing import Callable
ADMIN_ROLES = EnvService.get_admin_roles() ADMIN_ROLES = EnvService.get_admin_roles()

@ -1,452 +1,373 @@
import asyncio import asyncio
import os import tempfile
import tempfile import traceback
import traceback from io import BytesIO
from io import BytesIO
import aiohttp
import aiohttp import discord
import discord from PIL import Image
from PIL import Image
from pycord.multicog import add_to_group from models.user_model import RedoUser
# We don't use the converser cog here because we want to be able to redo for the last images and text prompts at the same time class ImageService:
from sqlitedict import SqliteDict
def __init__(self):
from cogs.gpt_3_commands_and_converser import GPT3ComCon pass
from models.env_service_model import EnvService
from models.user_model import RedoUser @staticmethod
async def encapsulated_send(
redo_users = {} image_service_cog,
users_to_interactions = {} user_id,
ALLOWED_GUILDS = EnvService.get_allowed_guilds() prompt,
ctx,
USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys() response_message=None,
USER_KEY_DB = None vary=None,
if USER_INPUT_API_KEYS: draw_from_optimizer=None,
USER_KEY_DB = SqliteDict("user_key_db.sqlite") custom_api_key=None,
):
await asyncio.sleep(0)
class DrawDallEService(discord.Cog, name="DrawDallEService"): # send the prompt to the model
def __init__( from_context = isinstance(ctx, discord.ApplicationContext)
self, bot, usage_service, model, message_queue, deletion_queue, converser_cog
): try:
super().__init__() file, image_urls = await image_service_cog.model.send_image_request(
self.bot = bot ctx,
self.usage_service = usage_service prompt,
self.model = model vary=vary if not draw_from_optimizer else None,
self.message_queue = message_queue custom_api_key=custom_api_key,
self.deletion_queue = deletion_queue )
self.converser_cog = converser_cog
print("Draw service initialized") # Error catching for API errors
except aiohttp.ClientResponseError as e:
async def encapsulated_send( message = (
self, f"The API returned an invalid response: **{e.status}: {e.message}**"
user_id, )
prompt, await ctx.channel.send(message) if not from_context else await ctx.respond(
ctx, message
response_message=None, )
vary=None, return
draw_from_optimizer=None,
custom_api_key=None, except ValueError as e:
): message = f"Error: {e}. Please try again with a different prompt."
await asyncio.sleep(0) await ctx.channel.send(message) if not from_context else await ctx.respond(
# send the prompt to the model message
from_context = isinstance(ctx, discord.ApplicationContext) )
try: return
file, image_urls = await self.model.send_image_request(
ctx, # Start building an embed to send to the user with the results of the image generation
prompt, embed = discord.Embed(
vary=vary if not draw_from_optimizer else None, title="Image Generation Results"
custom_api_key=custom_api_key, if not vary
) else "Image Generation Results (Varying)"
if not draw_from_optimizer
# Error catching for API errors else "Image Generation Results (Drawing from Optimizer)",
except aiohttp.ClientResponseError as e: description=f"{prompt}",
message = ( color=0xC730C7,
f"The API returned an invalid response: **{e.status}: {e.message}**" )
)
await ctx.channel.send(message) if not from_context else await ctx.respond( # Add the image file to the embed
message embed.set_image(url=f"attachment://{file.filename}")
)
return if not response_message: # Original generation case
# Start an interaction with the user, we also want to send data embed=embed, file=file, view=SaveView(image_urls, image_service_cog, image_service_cog.converser_cog)
except ValueError as e: result_message = (
message = f"Error: {e}. Please try again with a different prompt." await ctx.channel.send(
await ctx.channel.send(message) if not from_context else await ctx.respond( embed=embed,
message file=file,
) )
if not from_context
return else await ctx.respond(embed=embed, file=file)
)
# Start building an embed to send to the user with the results of the image generation
embed = discord.Embed( await result_message.edit(
title="Image Generation Results" view=SaveView(
if not vary ctx,
else "Image Generation Results (Varying)" image_urls,
if not draw_from_optimizer image_service_cog,
else "Image Generation Results (Drawing from Optimizer)", image_service_cog.converser_cog,
description=f"{prompt}", result_message,
color=0xC730C7, custom_api_key=custom_api_key,
) )
)
# Add the image file to the embed
embed.set_image(url=f"attachment://{file.filename}") image_service_cog.converser_cog.users_to_interactions[user_id] = []
image_service_cog.converser_cog.users_to_interactions[user_id].append(result_message.id)
if not response_message: # Original generation case
# Start an interaction with the user, we also want to send data embed=embed, file=file, view=SaveView(image_urls, self, self.converser_cog) # Get the actual result message object
result_message = ( if from_context:
await ctx.channel.send( result_message = await ctx.fetch_message(result_message.id)
embed=embed,
file=file, image_service_cog.redo_users[user_id] = RedoUser(
) prompt=prompt,
if not from_context message=ctx,
else await ctx.respond(embed=embed, file=file) ctx=ctx,
) response=response_message,
instruction=None,
await result_message.edit( codex=False,
view=SaveView( paginator=None
ctx, )
image_urls,
self, else:
self.converser_cog, if not vary: # Editing case
result_message, message = await response_message.edit(
custom_api_key=custom_api_key, embed=embed,
) file=file,
) )
await message.edit(
self.converser_cog.users_to_interactions[user_id] = [] view=SaveView(
self.converser_cog.users_to_interactions[user_id].append(result_message.id) ctx,
image_urls,
# Get the actual result message object image_service_cog,
if from_context: image_service_cog.converser_cog,
result_message = await ctx.fetch_message(result_message.id) message,
custom_api_key=custom_api_key,
redo_users[user_id] = RedoUser( )
prompt=prompt, )
message=ctx, else: # Varying case
ctx=ctx, if not draw_from_optimizer:
response=response_message, result_message = await response_message.edit_original_response(
instruction=None, content="Image variation completed!",
codex=False, embed=embed,
paginator=None file=file,
) )
await result_message.edit(
else: view=SaveView(
if not vary: # Editing case ctx,
message = await response_message.edit( image_urls,
embed=embed, image_service_cog,
file=file, image_service_cog.converser_cog,
) result_message,
await message.edit( True,
view=SaveView( custom_api_key=custom_api_key,
ctx, )
image_urls, )
self,
self.converser_cog, else:
message, result_message = await response_message.edit_original_response(
custom_api_key=custom_api_key, content="I've drawn the optimized prompt!",
) embed=embed,
) file=file,
else: # Varying case )
if not draw_from_optimizer: await result_message.edit(
result_message = await response_message.edit_original_response( view=SaveView(
content="Image variation completed!", ctx,
embed=embed, image_urls,
file=file, image_service_cog,
) image_service_cog.converser_cog,
await result_message.edit( result_message,
view=SaveView( custom_api_key=custom_api_key,
ctx, )
image_urls, )
self,
self.converser_cog, image_service_cog.redo_users[user_id] = RedoUser(
result_message, prompt=prompt,
True, message=ctx,
custom_api_key=custom_api_key, ctx=ctx,
) response=response_message,
) instruction=None,
codex=False,
else: paginator=None,
result_message = await response_message.edit_original_response( )
content="I've drawn the optimized prompt!",
embed=embed, image_service_cog.converser_cog.users_to_interactions[user_id].append(
file=file, response_message.id
) )
await result_message.edit( image_service_cog.converser_cog.users_to_interactions[user_id].append(
view=SaveView( result_message.id
ctx, )
image_urls,
self,
self.converser_cog, class SaveView(discord.ui.View):
result_message, def __init__(
custom_api_key=custom_api_key, self,
) ctx,
) image_urls,
cog,
redo_users[user_id] = RedoUser( converser_cog,
prompt=prompt, message,
message=ctx, no_retry=False,
ctx=ctx, only_save=None,
response=response_message, custom_api_key=None,
instruction=None, ):
codex=False, super().__init__(
paginator=None, timeout=3600 if not only_save else None
) ) # 1 hour timeout for Retry, Save
self.ctx = ctx
self.converser_cog.users_to_interactions[user_id].append( self.image_urls = image_urls
response_message.id self.cog = cog
) self.no_retry = no_retry
self.converser_cog.users_to_interactions[user_id].append( self.converser_cog = converser_cog
result_message.id self.message = message
) self.custom_api_key = custom_api_key
for x in range(1, len(image_urls) + 1):
async def draw_command(self, ctx: discord.ApplicationContext, prompt: str): self.add_item(SaveButton(x, image_urls[x - 1]))
user_api_key = None if not only_save:
if USER_INPUT_API_KEYS: if not no_retry:
user_api_key = await GPT3ComCon.get_user_api_key(ctx.user.id, ctx) self.add_item(
if not user_api_key: RedoButton(
return self.cog,
converser_cog=self.converser_cog,
await ctx.defer() custom_api_key=self.custom_api_key,
)
user = ctx.user )
for x in range(1, len(image_urls) + 1):
if user == self.bot.user: self.add_item(
return VaryButton(
x,
try: image_urls[x - 1],
asyncio.ensure_future( self.cog,
self.encapsulated_send( converser_cog=self.converser_cog,
user.id, prompt, ctx, custom_api_key=user_api_key custom_api_key=self.custom_api_key,
) )
) )
except Exception as e: # On the timeout event, override it and we want to clear the items.
print(e) async def on_timeout(self):
traceback.print_exc() # Save all the SaveButton items, then clear all the items, then add back the SaveButton items, then
await ctx.respond("Something went wrong. Please try again later.") # update the message
await ctx.send_followup(e) self.clear_items()
async def local_size_command(self, ctx: discord.ApplicationContext): # Create a new view with the same params as this one, but pass only_save=True
await ctx.defer() new_view = SaveView(
# Get the size of the dall-e images folder that we have on the current system. self.ctx,
self.image_urls,
image_path = self.model.IMAGE_SAVE_PATH self.cog,
total_size = 0 self.converser_cog,
for dirpath, dirnames, filenames in os.walk(image_path): self.message,
for f in filenames: self.no_retry,
fp = os.path.join(dirpath, f) only_save=True,
total_size += os.path.getsize(fp) )
# Format the size to be in MB and send. # Set the view of the message to the new view
total_size = total_size / 1000000 await self.ctx.edit(view=new_view)
await ctx.respond(f"The size of the local images folder is {total_size} MB.")
async def clear_local_command(self, ctx): class VaryButton(discord.ui.Button):
await ctx.defer() def __init__(self, number, image_url, cog, converser_cog, custom_api_key):
super().__init__(style=discord.ButtonStyle.blurple, label="Vary " + str(number))
# Delete all the local images in the images folder. self.number = number
image_path = self.model.IMAGE_SAVE_PATH self.image_url = image_url
for dirpath, dirnames, filenames in os.walk(image_path): self.cog = cog
for f in filenames: self.converser_cog = converser_cog
try: self.custom_api_key = custom_api_key
fp = os.path.join(dirpath, f)
os.remove(fp) async def callback(self, interaction: discord.Interaction):
except Exception as e: user_id = interaction.user.id
print(e) interaction_id = interaction.message.id
await ctx.respond("Local images cleared.") if interaction_id not in self.converser_cog.users_to_interactions[user_id]:
if len(self.converser_cog.users_to_interactions[user_id]) >= 2:
interaction_id2 = interaction.id
class SaveView(discord.ui.View): if (
def __init__( interaction_id2
self, not in self.converser_cog.users_to_interactions[user_id]
ctx, ):
image_urls, await interaction.response.send_message(
cog, content="You can not vary images in someone else's chain!",
converser_cog, ephemeral=True,
message, )
no_retry=False, else:
only_save=None, await interaction.response.send_message(
custom_api_key=None, content="You can only vary for images that you generated yourself!",
): ephemeral=True,
super().__init__( )
timeout=3600 if not only_save else None return
) # 1 hour timeout for Retry, Save
self.ctx = ctx if user_id in self.cog.redo_users:
self.image_urls = image_urls response_message = await interaction.response.send_message(
self.cog = cog content="Varying image number " + str(self.number) + "..."
self.no_retry = no_retry )
self.converser_cog = converser_cog self.converser_cog.users_to_interactions[user_id].append(
self.message = message response_message.message.id
self.custom_api_key = custom_api_key )
for x in range(1, len(image_urls) + 1): self.converser_cog.users_to_interactions[user_id].append(
self.add_item(SaveButton(x, image_urls[x - 1])) response_message.id
if not only_save: )
if not no_retry: prompt = self.cog.redo_users[user_id].prompt
self.add_item(
RedoButton( asyncio.ensure_future(
self.cog, ImageService.encapsulated_send(
converser_cog=self.converser_cog, self.cog,
custom_api_key=self.custom_api_key, user_id,
) prompt,
) interaction.message,
for x in range(1, len(image_urls) + 1): response_message=response_message,
self.add_item( vary=self.image_url,
VaryButton( custom_api_key=self.custom_api_key,
x, )
image_urls[x - 1], )
self.cog,
converser_cog=self.converser_cog,
custom_api_key=self.custom_api_key, class SaveButton(discord.ui.Button["SaveView"]):
) def __init__(self, number: int, image_url: str):
) super().__init__(style=discord.ButtonStyle.gray, label="Save " + str(number))
self.number = number
# On the timeout event, override it and we want to clear the items. self.image_url = image_url
async def on_timeout(self):
# Save all the SaveButton items, then clear all the items, then add back the SaveButton items, then async def callback(self, interaction: discord.Interaction):
# update the message # If the image url doesn't start with "http", then we need to read the file from the URI, and then send the
self.clear_items() # file to the user as an attachment.
try:
# Create a new view with the same params as this one, but pass only_save=True if not self.image_url.startswith("http"):
new_view = SaveView( with open(self.image_url, "rb") as f:
self.ctx, image = Image.open(BytesIO(f.read()))
self.image_urls, temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
self.cog, image.save(temp_file.name)
self.converser_cog,
self.message, await interaction.response.send_message(
self.no_retry, content="Here is your image for download (open original and save)",
only_save=True, file=discord.File(temp_file.name),
) ephemeral=True,
)
# Set the view of the message to the new view else:
await self.ctx.edit(view=new_view) await interaction.response.send_message(
f"You can directly download this image from {self.image_url}",
ephemeral=True,
class VaryButton(discord.ui.Button): )
def __init__(self, number, image_url, cog, converser_cog, custom_api_key): except Exception as e:
super().__init__(style=discord.ButtonStyle.blurple, label="Vary " + str(number)) await interaction.response.send_message(f"Error: {e}", ephemeral=True)
self.number = number traceback.print_exc()
self.image_url = image_url
self.cog = cog
self.converser_cog = converser_cog class RedoButton(discord.ui.Button["SaveView"]):
self.custom_api_key = custom_api_key def __init__(self, cog, converser_cog, custom_api_key):
super().__init__(style=discord.ButtonStyle.danger, label="Retry")
async def callback(self, interaction: discord.Interaction): self.cog = cog
user_id = interaction.user.id self.converser_cog = converser_cog
interaction_id = interaction.message.id self.custom_api_key = custom_api_key
if interaction_id not in self.converser_cog.users_to_interactions[user_id]: async def callback(self, interaction: discord.Interaction):
if len(self.converser_cog.users_to_interactions[user_id]) >= 2: user_id = interaction.user.id
interaction_id2 = interaction.id interaction_id = interaction.message.id
if (
interaction_id2 if interaction_id not in self.converser_cog.users_to_interactions[user_id]:
not in self.converser_cog.users_to_interactions[user_id] await interaction.response.send_message(
): content="You can only retry for prompts that you generated yourself!",
await interaction.response.send_message( ephemeral=True,
content="You can not vary images in someone else's chain!", )
ephemeral=True, return
)
else: # We have passed the intial check of if the interaction belongs to the user
await interaction.response.send_message( if user_id in self.cog.redo_users:
content="You can only vary for images that you generated yourself!", # Get the message and the prompt and call encapsulated_send
ephemeral=True, ctx = self.cog.redo_users[user_id].ctx
) prompt = self.cog.redo_users[user_id].prompt
return response_message = self.cog.redo_users[user_id].response
message = await interaction.response.send_message(
if user_id in redo_users: f"Regenerating the image for your original prompt, check the original message.",
response_message = await interaction.response.send_message( ephemeral=True,
content="Varying image number " + str(self.number) + "..." )
) self.converser_cog.users_to_interactions[user_id].append(message.id)
self.converser_cog.users_to_interactions[user_id].append(
response_message.message.id asyncio.ensure_future(
) ImageService.encapsulated_send(
self.converser_cog.users_to_interactions[user_id].append( self.cog,
response_message.id user_id,
) prompt,
prompt = redo_users[user_id].prompt ctx,
response_message,
asyncio.ensure_future( custom_api_key=self.custom_api_key,
self.cog.encapsulated_send( )
user_id, )
prompt,
interaction.message,
response_message=response_message,
vary=self.image_url,
custom_api_key=self.custom_api_key,
)
)
class SaveButton(discord.ui.Button["SaveView"]):
def __init__(self, number: int, image_url: str):
super().__init__(style=discord.ButtonStyle.gray, label="Save " + str(number))
self.number = number
self.image_url = image_url
async def callback(self, interaction: discord.Interaction):
# If the image url doesn't start with "http", then we need to read the file from the URI, and then send the
# file to the user as an attachment.
try:
if not self.image_url.startswith("http"):
with open(self.image_url, "rb") as f:
image = Image.open(BytesIO(f.read()))
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
image.save(temp_file.name)
await interaction.response.send_message(
content="Here is your image for download (open original and save)",
file=discord.File(temp_file.name),
ephemeral=True,
)
else:
await interaction.response.send_message(
f"You can directly download this image from {self.image_url}",
ephemeral=True,
)
except Exception as e:
await interaction.response.send_message(f"Error: {e}", ephemeral=True)
traceback.print_exc()
class RedoButton(discord.ui.Button["SaveView"]):
def __init__(self, cog, converser_cog, custom_api_key):
super().__init__(style=discord.ButtonStyle.danger, label="Retry")
self.cog = cog
self.converser_cog = converser_cog
self.custom_api_key = custom_api_key
async def callback(self, interaction: discord.Interaction):
user_id = interaction.user.id
interaction_id = interaction.message.id
if interaction_id not in self.converser_cog.users_to_interactions[user_id]:
await interaction.response.send_message(
content="You can only retry for prompts that you generated yourself!",
ephemeral=True,
)
return
# We have passed the intial check of if the interaction belongs to the user
if user_id in redo_users:
# Get the message and the prompt and call encapsulated_send
ctx = redo_users[user_id].ctx
prompt = redo_users[user_id].prompt
response_message = redo_users[user_id].response
message = await interaction.response.send_message(
f"Regenerating the image for your original prompt, check the original message.",
ephemeral=True,
)
self.converser_cog.users_to_interactions[user_id].append(message.id)
asyncio.ensure_future(
self.cog.encapsulated_send(
user_id,
prompt,
ctx,
response_message,
custom_api_key=self.custom_api_key,
)
)

@ -7,7 +7,7 @@ from pathlib import Path
import discord import discord
from models.openai_model import Model from models.openai_model import Model
from models.usage_service_model import UsageService from services.usage_service import UsageService
usage_service = UsageService(Path(os.environ.get("DATA_DIR", os.getcwd()))) usage_service = UsageService(Path(os.environ.get("DATA_DIR", os.getcwd())))
model = Model(usage_service) model = Model(usage_service)

@ -0,0 +1,820 @@
import datetime
import re
import traceback
import aiohttp
import discord
from discord.ext import pages
from services.deletion_service import Deletion
from models.openai_model import Model
from models.user_model import EmbeddedConversationItem, RedoUser
class TextService:
def __init__(self):
pass
@staticmethod
async def encapsulated_send(
converser_cog,
id,
prompt,
ctx,
response_message=None,
temp_override=None,
top_p_override=None,
frequency_penalty_override=None,
presence_penalty_override=None,
from_ask_command=False,
instruction=None,
from_edit_command=False,
codex=False,
model=None,
custom_api_key=None,
edited_request=False,
redo_request=False,
):
new_prompt = (
prompt + "\nGPTie: "
if not from_ask_command and not from_edit_command
else prompt
)
from_context = isinstance(ctx, discord.ApplicationContext)
if not instruction:
tokens = converser_cog.usage_service.count_tokens(new_prompt)
else:
tokens = converser_cog.usage_service.count_tokens(
new_prompt
) + converser_cog.usage_service.count_tokens(instruction)
try:
# Pinecone is enabled, we will create embeddings for this conversation.
if converser_cog.pinecone_service and ctx.channel.id in converser_cog.conversation_threads:
# Delete "GPTie: <|endofstatement|>" from the user's conversation history if it exists
# check if the text attribute for any object inside converser_cog.conversation_threads[converation_id].history
# contains ""GPTie: <|endofstatement|>"", if so, delete
for item in converser_cog.conversation_threads[ctx.channel.id].history:
if item.text.strip() == "GPTie:<|endofstatement|>":
converser_cog.conversation_threads[ctx.channel.id].history.remove(item)
# The conversation_id is the id of the thread
conversation_id = ctx.channel.id
# Create an embedding and timestamp for the prompt
new_prompt = prompt.encode("ascii", "ignore").decode()
prompt_less_author = f"{new_prompt} <|endofstatement|>\n"
user_displayname = ctx.author.display_name
new_prompt = (
f"\n'{user_displayname}': {new_prompt} <|endofstatement|>\n"
)
new_prompt = new_prompt.encode("ascii", "ignore").decode()
timestamp = int(
str(datetime.datetime.now().timestamp()).replace(".", "")
)
new_prompt_item = EmbeddedConversationItem(new_prompt, timestamp)
if not redo_request:
converser_cog.conversation_threads[conversation_id].history.append(
new_prompt_item
)
if edited_request:
new_prompt = "".join(
[
item.text
for item in converser_cog.conversation_threads[
ctx.channel.id
].history
]
)
converser_cog.redo_users[ctx.author.id].prompt = new_prompt
else:
# Create and upsert the embedding for the conversation id, prompt, timestamp
await converser_cog.pinecone_service.upsert_conversation_embedding(
converser_cog.model,
conversation_id,
new_prompt,
timestamp,
custom_api_key=custom_api_key,
)
embedding_prompt_less_author = await converser_cog.model.send_embedding_request(
prompt_less_author, custom_api_key=custom_api_key
) # Use the version of the prompt without the author's name for better clarity on retrieval.
# Now, build the new prompt by getting the X most similar with pinecone
similar_prompts = converser_cog.pinecone_service.get_n_similar(
conversation_id,
embedding_prompt_less_author,
n=converser_cog.model.num_conversation_lookback,
)
# When we are in embeddings mode, only the pre-text is contained in converser_cog.conversation_threads[message.channel.id].history, so we
# can use that as a base to build our new prompt
prompt_with_history = [
converser_cog.conversation_threads[ctx.channel.id].history[0]
]
# Append the similar prompts to the prompt with history
prompt_with_history += [
EmbeddedConversationItem(prompt, timestamp)
for prompt, timestamp in similar_prompts
]
# iterate UP TO the last X prompts in the history
for i in range(
1,
min(
len(converser_cog.conversation_threads[ctx.channel.id].history),
converser_cog.model.num_static_conversation_items,
),
):
prompt_with_history.append(
converser_cog.conversation_threads[ctx.channel.id].history[-i]
)
# remove duplicates from prompt_with_history and set the conversation history
prompt_with_history = list(dict.fromkeys(prompt_with_history))
converser_cog.conversation_threads[
ctx.channel.id
].history = prompt_with_history
# Sort the prompt_with_history by increasing timestamp if pinecone is enabled
if converser_cog.pinecone_service:
prompt_with_history.sort(key=lambda x: x.timestamp)
# Ensure that the last prompt in this list is the prompt we just sent (new_prompt_item)
if prompt_with_history[-1] != new_prompt_item:
try:
prompt_with_history.remove(new_prompt_item)
except ValueError:
pass
prompt_with_history.append(new_prompt_item)
prompt_with_history = "".join(
[item.text for item in prompt_with_history]
)
new_prompt = prompt_with_history + "\nGPTie: "
tokens = converser_cog.usage_service.count_tokens(new_prompt)
# No pinecone, we do conversation summarization for long term memory instead
elif (
id in converser_cog.conversation_threads
and tokens > converser_cog.model.summarize_threshold
and not from_ask_command
and not from_edit_command
and not converser_cog.pinecone_service # This should only happen if we are not doing summarizations.
):
# We don't need to worry about the differences between interactions and messages in this block,
# because if we are in this block, we can only be using a message object for ctx
if converser_cog.model.summarize_conversations:
await ctx.reply(
"I'm currently summarizing our current conversation so we can keep chatting, "
"give me one moment!"
)
await converser_cog.summarize_conversation(ctx, new_prompt)
# Check again if the prompt is about to go past the token limit
new_prompt = (
"".join(
[
item.text
for item in converser_cog.conversation_threads[id].history
]
)
+ "\nGPTie: "
)
tokens = converser_cog.usage_service.count_tokens(new_prompt)
if (
tokens > converser_cog.model.summarize_threshold - 150
): # 150 is a buffer for the second stage
await ctx.reply(
"I tried to summarize our current conversation so we could keep chatting, "
"but it still went over the token "
"limit. Please try again later."
)
await converser_cog.end_conversation(ctx)
return
else:
await ctx.reply("The conversation context limit has been reached.")
await converser_cog.end_conversation(ctx)
return
# Send the request to the model
if from_edit_command:
response = await converser_cog.model.send_edit_request(
input=new_prompt,
instruction=instruction,
temp_override=temp_override,
top_p_override=top_p_override,
codex=codex,
custom_api_key=custom_api_key,
)
else:
response = await converser_cog.model.send_request(
new_prompt,
tokens=tokens,
temp_override=temp_override,
top_p_override=top_p_override,
frequency_penalty_override=frequency_penalty_override,
presence_penalty_override=presence_penalty_override,
model=model,
custom_api_key=custom_api_key,
)
# Clean the request response
response_text = converser_cog.cleanse_response(str(response["choices"][0]["text"]))
if from_ask_command:
# Append the prompt to the beginning of the response, in italics, then a new line
response_text = response_text.strip()
response_text = f"***{prompt}***\n\n{response_text}"
elif from_edit_command:
if codex:
response_text = response_text.strip()
response_text = f"***Prompt: {prompt}***\n***Instruction: {instruction}***\n\n```\n{response_text}\n```"
else:
response_text = response_text.strip()
response_text = f"***Prompt: {prompt}***\n***Instruction: {instruction}***\n\n{response_text}\n"
# If gpt3 tries writing a user mention try to replace it with their name
response_text = await converser_cog.mention_to_username(ctx, response_text)
# If the user is conversing, add the GPT response to their conversation history.
if (
id in converser_cog.conversation_threads
and not from_ask_command
and not converser_cog.pinecone_service
):
if not redo_request:
converser_cog.conversation_threads[id].history.append(
EmbeddedConversationItem(
"\nGPTie: " + str(response_text) + "<|endofstatement|>\n", 0
)
)
# Embeddings case!
elif (
id in converser_cog.conversation_threads
and not from_ask_command
and not from_edit_command
and converser_cog.pinecone_service
):
conversation_id = id
# Create an embedding and timestamp for the prompt
response_text = (
"\nGPTie: " + str(response_text) + "<|endofstatement|>\n"
)
response_text = response_text.encode("ascii", "ignore").decode()
# Print the current timestamp
timestamp = int(
str(datetime.datetime.now().timestamp()).replace(".", "")
)
converser_cog.conversation_threads[conversation_id].history.append(
EmbeddedConversationItem(response_text, timestamp)
)
# Create and upsert the embedding for the conversation id, prompt, timestamp
embedding = await converser_cog.pinecone_service.upsert_conversation_embedding(
converser_cog.model,
conversation_id,
response_text,
timestamp,
custom_api_key=custom_api_key,
)
# Cleanse again
response_text = converser_cog.cleanse_response(response_text)
# escape any other mentions like @here or @everyone
response_text = discord.utils.escape_mentions(response_text)
# If we don't have a response message, we are not doing a redo, send as a new message(s)
if not response_message:
if len(response_text) > converser_cog.TEXT_CUTOFF:
if not from_context:
paginator = None
await converser_cog.paginate_and_send(response_text, ctx)
else:
embed_pages = await converser_cog.paginate_embed(response_text, codex, prompt, instruction)
view=ConversationView(ctx, converser_cog, ctx.channel.id, model, from_ask_command, from_edit_command, custom_api_key=custom_api_key)
paginator = pages.Paginator(pages=embed_pages, timeout=None, custom_view=view)
response_message = await paginator.respond(ctx.interaction)
else:
paginator = None
if not from_context:
response_message = await ctx.reply(
response_text,
view=ConversationView(
ctx,
converser_cog,
ctx.channel.id,
model,
custom_api_key=custom_api_key,
),
)
elif from_edit_command:
response_message = await ctx.respond(
response_text,
view=ConversationView(
ctx,
converser_cog,
ctx.channel.id,
model,
from_edit_command=from_edit_command,
custom_api_key=custom_api_key
),
)
else:
response_message = await ctx.respond(
response_text,
view=ConversationView(
ctx,
converser_cog,
ctx.channel.id,
model,
from_ask_command=from_ask_command,
custom_api_key=custom_api_key
),
)
if response_message:
# Get the actual message object of response_message in case it's an WebhookMessage
actual_response_message = (
response_message
if not from_context
else await ctx.fetch_message(response_message.id)
)
converser_cog.redo_users[ctx.author.id] = RedoUser(
prompt=new_prompt,
instruction=instruction,
ctx=ctx,
message=ctx,
response=actual_response_message,
codex=codex,
paginator=paginator
)
converser_cog.redo_users[ctx.author.id].add_interaction(
actual_response_message.id
)
# We are doing a redo, edit the message.
else:
paginator = converser_cog.redo_users.get(ctx.author.id).paginator
if isinstance(paginator, pages.Paginator):
embed_pages = await converser_cog.paginate_embed(response_text, codex, prompt, instruction)
view=ConversationView(ctx, converser_cog, ctx.channel.id, model, from_ask_command, from_edit_command, custom_api_key=custom_api_key)
await paginator.update(pages=embed_pages, custom_view=view)
elif len(response_text) > converser_cog.TEXT_CUTOFF:
if not from_context:
await response_message.channel.send("Over 2000 characters", delete_after=5)
else:
await response_message.edit(content=response_text)
await converser_cog.send_debug_message(
converser_cog.generate_debug_message(prompt, response), converser_cog.debug_channel
)
if ctx.author.id in converser_cog.awaiting_responses:
converser_cog.awaiting_responses.remove(ctx.author.id)
if not from_ask_command and not from_edit_command:
if ctx.channel.id in converser_cog.awaiting_thread_responses:
converser_cog.awaiting_thread_responses.remove(ctx.channel.id)
# Error catching for AIOHTTP Errors
except aiohttp.ClientResponseError as e:
message = (
f"The API returned an invalid response: **{e.status}: {e.message}**"
)
if from_context:
await ctx.send_followup(message)
else:
await ctx.reply(message)
converser_cog.remove_awaiting(
ctx.author.id, ctx.channel.id, from_ask_command, from_edit_command
)
# Error catching for OpenAI model value errors
except ValueError as e:
if from_context:
await ctx.send_followup(e)
else:
await ctx.reply(e)
converser_cog.remove_awaiting(
ctx.author.id, ctx.channel.id, from_ask_command, from_edit_command
)
# General catch case for everything
except Exception:
message = "Something went wrong, please try again later. This may be due to upstream issues on the API, or rate limiting."
await ctx.send_followup(message) if from_context else await ctx.reply(
message
)
converser_cog.remove_awaiting(
ctx.author.id, ctx.channel.id, from_ask_command, from_edit_command
)
traceback.print_exc()
try:
await converser_cog.end_conversation(ctx)
except:
pass
return
@staticmethod
async def process_conversation_message(converser_cog, message, USER_INPUT_API_KEYS, USER_KEY_DB):
content = message.content.strip()
conversing = converser_cog.check_conversing(
message.author.id, message.channel.id, content
)
# If the user is conversing and they want to end it, end it immediately before we continue any further.
if conversing and message.content.lower() in converser_cog.END_PROMPTS:
await converser_cog.end_conversation(message)
return
if conversing:
user_api_key = None
if USER_INPUT_API_KEYS:
user_api_key = await TextService.get_user_api_key(
message.author.id, message, USER_KEY_DB
)
if not user_api_key:
return
prompt = await converser_cog.mention_to_username(message, content)
await converser_cog.check_conversation_limit(message)
# If the user is in a conversation thread
if message.channel.id in converser_cog.conversation_threads:
# Since this is async, we don't want to allow the user to send another prompt while a conversation
# prompt is processing, that'll mess up the conversation history!
if message.author.id in converser_cog.awaiting_responses:
message = await message.reply(
"You are already waiting for a response from GPT3. Please wait for it to respond before sending another message."
)
# get the current date, add 10 seconds to it, and then turn it into a timestamp.
# we need to use our deletion service because this isn't an interaction, it's a regular message.
deletion_time = datetime.datetime.now() + datetime.timedelta(
seconds=10
)
deletion_time = deletion_time.timestamp()
deletion_message = Deletion(message, deletion_time)
await converser_cog.deletion_queue.put(deletion_message)
return
if message.channel.id in converser_cog.awaiting_thread_responses:
message = await message.reply(
"This thread is already waiting for a response from GPT3. Please wait for it to respond before sending another message."
)
# get the current date, add 10 seconds to it, and then turn it into a timestamp.
# we need to use our deletion service because this isn't an interaction, it's a regular message.
deletion_time = datetime.datetime.now() + datetime.timedelta(
seconds=10
)
deletion_time = deletion_time.timestamp()
deletion_message = Deletion(message, deletion_time)
await converser_cog.deletion_queue.put(deletion_message)
return
converser_cog.awaiting_responses.append(message.author.id)
converser_cog.awaiting_thread_responses.append(message.channel.id)
if not converser_cog.pinecone_service:
converser_cog.conversation_threads[message.channel.id].history.append(
EmbeddedConversationItem(
f"\n'{message.author.display_name}': {prompt} <|endofstatement|>\n",
0,
)
)
# increment the conversation counter for the user
converser_cog.conversation_threads[message.channel.id].count += 1
# Send the request to the model
# If conversing, the prompt to send is the history, otherwise, it's just the prompt
if (
converser_cog.pinecone_service
or message.channel.id not in converser_cog.conversation_threads
):
primary_prompt = prompt
else:
primary_prompt = "".join(
[
item.text
for item in converser_cog.conversation_threads[
message.channel.id
].history
]
)
# set conversation overrides
overrides = converser_cog.conversation_threads[message.channel.id].get_overrides()
await TextService.encapsulated_send(
converser_cog,
message.channel.id,
primary_prompt,
message,
temp_override=overrides["temperature"],
top_p_override=overrides["top_p"],
frequency_penalty_override=overrides["frequency_penalty"],
presence_penalty_override=overrides["presence_penalty"],
model=converser_cog.conversation_threads[message.channel.id].model,
custom_api_key=user_api_key,
)
return True
@staticmethod
async def get_user_api_key(user_id, ctx, USER_KEY_DB):
user_api_key = None if user_id not in USER_KEY_DB else USER_KEY_DB[user_id]
if user_api_key is None or user_api_key == "":
modal = SetupModal(title="API Key Setup",user_key_db=USER_KEY_DB)
if isinstance(ctx, discord.ApplicationContext):
await ctx.send_modal(modal)
await ctx.send_followup(
"You must set up your API key before using this command."
)
else:
await ctx.reply(
"You must set up your API key before typing in a GPT3 powered channel, type `/setup` to enter your API key."
)
return user_api_key
@staticmethod
async def process_conversation_edit(converser_cog, after, original_message):
if after.author.id in converser_cog.redo_users:
if after.id == original_message[after.author.id]:
response_message = converser_cog.redo_users[after.author.id].response
ctx = converser_cog.redo_users[after.author.id].ctx
await response_message.edit(content="Redoing prompt 🔄...")
edited_content = await converser_cog.mention_to_username(after, after.content)
if after.channel.id in converser_cog.conversation_threads:
# Remove the last two elements from the history array and add the new <username>: prompt
converser_cog.conversation_threads[
after.channel.id
].history = converser_cog.conversation_threads[after.channel.id].history[:-2]
pinecone_dont_reinsert = None
if not converser_cog.pinecone_service:
converser_cog.conversation_threads[after.channel.id].history.append(
EmbeddedConversationItem(
f"\n{after.author.display_name}: {after.content}<|endofstatement|>\n",
0,
)
)
converser_cog.conversation_threads[after.channel.id].count += 1
overrides = converser_cog.conversation_threads[after.channel.id].get_overrides()
await TextService.encapsulated_send(
converser_cog,
id=after.channel.id,
prompt=edited_content,
ctx=ctx,
response_message=response_message,
temp_override=overrides["temperature"],
top_p_override=overrides["top_p"],
frequency_penalty_override=overrides["frequency_penalty"],
presence_penalty_override=overrides["presence_penalty"],
model=converser_cog.conversation_threads[after.channel.id].model,
edited_request=True,
)
if not converser_cog.pinecone_service:
converser_cog.redo_users[after.author.id].prompt = edited_content
"""
Conversation interaction buttons
"""
class ConversationView(discord.ui.View):
def __init__(
self,
ctx,
converser_cog,
id,
model,
from_ask_command=False,
from_edit_command=False,
custom_api_key=None,
):
super().__init__(timeout=3600) # 1 hour interval to redo.
self.converser_cog = converser_cog
self.ctx = ctx
self.model = model
self.from_ask_command = from_ask_command
self.from_edit_command = from_edit_command
self.custom_api_key = custom_api_key
self.add_item(
RedoButton(
self.converser_cog,
model=model,
from_ask_command=from_ask_command,
from_edit_command=from_edit_command,
custom_api_key=self.custom_api_key,
)
)
if id in self.converser_cog.conversation_threads:
self.add_item(EndConvoButton(self.converser_cog))
async def on_timeout(self):
# Remove the button from the view/message
self.clear_items()
# Send a message to the user saying the view has timed out
if self.message:
await self.message.edit(
view=None,
)
else:
await self.ctx.edit(
view=None,
)
class EndConvoButton(discord.ui.Button["ConversationView"]):
def __init__(self, converser_cog):
super().__init__(style=discord.ButtonStyle.danger, label="End Conversation", custom_id="conversation_end")
self.converser_cog = converser_cog
async def callback(self, interaction: discord.Interaction):
# Get the user
user_id = interaction.user.id
if (
user_id in self.converser_cog.conversation_thread_owners
and self.converser_cog.conversation_thread_owners[user_id]
== interaction.channel.id
):
try:
await self.converser_cog.end_conversation(
interaction, opener_user_id=interaction.user.id
)
except Exception as e:
print(e)
traceback.print_exc()
await interaction.response.send_message(
e, ephemeral=True, delete_after=30
)
pass
else:
await interaction.response.send_message(
"This is not your conversation to end!", ephemeral=True, delete_after=10
)
class RedoButton(discord.ui.Button["ConversationView"]):
def __init__(self, converser_cog, model, from_ask_command, from_edit_command, custom_api_key):
super().__init__(style=discord.ButtonStyle.danger, label="Retry", custom_id="conversation_redo")
self.converser_cog = converser_cog
self.model = model
self.from_ask_command = from_ask_command
self.from_edit_command = from_edit_command
self.custom_api_key = custom_api_key
async def callback(self, interaction: discord.Interaction):
# Get the user
user_id = interaction.user.id
if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[
user_id
].in_interaction(interaction.message.id):
# Get the message and the prompt and call encapsulated_send
prompt = self.converser_cog.redo_users[user_id].prompt
instruction = self.converser_cog.redo_users[user_id].instruction
ctx = self.converser_cog.redo_users[user_id].ctx
response_message = self.converser_cog.redo_users[user_id].response
codex = self.converser_cog.redo_users[user_id].codex
msg = await interaction.response.send_message(
"Retrying your original request...", ephemeral=True, delete_after=15
)
await TextService.encapsulated_send(
self.converser_cog,
id=user_id,
prompt=prompt,
instruction=instruction,
ctx=ctx,
model=self.model,
response_message=response_message,
codex=codex,
custom_api_key=self.custom_api_key,
redo_request=True,
from_ask_command=self.from_ask_command,
from_edit_command=self.from_edit_command,
)
else:
await interaction.response.send_message(
"You can only redo the most recent prompt that you sent yourself.",
ephemeral=True,
delete_after=10,
)
"""
The setup modal when using user input API keys
"""
class SetupModal(discord.ui.Modal):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# Get the argument named "user_key_db" and save it as USER_KEY_DB
self.USER_KEY_DB = kwargs.pop("user_key_db")
self.add_item(
discord.ui.InputText(
label="OpenAI API Key",
placeholder="sk--......",
)
)
async def callback(self, interaction: discord.Interaction):
user = interaction.user
api_key = self.children[0].value
# Validate that api_key is indeed in this format
if not re.match(r"sk-[a-zA-Z0-9]{32}", api_key):
await interaction.response.send_message(
"Your API key looks invalid, please check that it is correct before proceeding. Please run the /setup command to set your key.",
ephemeral=True,
delete_after=100,
)
else:
# We can save the key for the user to the database.
# Make a test request using the api key to ensure that it is valid.
try:
await Model.send_test_request(api_key)
await interaction.response.send_message(
"Your API key was successfully validated.",
ephemeral=True,
delete_after=10,
)
except aiohttp.ClientResponseError as e:
await interaction.response.send_message(
f"The API returned an invalid response: **{e.status}: {e.message}**",
ephemeral=True,
delete_after=30,
)
return
except Exception as e:
await interaction.response.send_message(
f"Your API key looks invalid, the API returned: {e}. Please check that your API key is correct before proceeding",
ephemeral=True,
delete_after=30,
)
return
# Save the key to the database
try:
self.USER_KEY_DB[user.id] = api_key
self.USER_KEY_DB.commit()
await interaction.followup.send(
"Your API key was successfully saved.",
ephemeral=True,
delete_after=10,
)
except Exception as e:
traceback.print_exc()
await interaction.followup.send(
"There was an error saving your API key.",
ephemeral=True,
delete_after=30,
)
return
pass
Loading…
Cancel
Save