a big refactor

Kaveen Kumarasinghe 1 year ago
parent a6f6dcb74d
commit ba3ea3231b

@ -1,7 +1,7 @@
import discord
from pycord.multicog import add_to_group
from models.env_service_model import EnvService
from services.environment_service import EnvService
from models.check_model import Check
from models.autocomplete_model import Settings_autocompleter, File_autocompleter

@ -0,0 +1,101 @@
import asyncio
import os
import tempfile
import traceback
from io import BytesIO
import aiohttp
import discord
from PIL import Image
# We don't use the converser cog here because we want to be able to redo for the last images and text prompts at the same time
from sqlitedict import SqliteDict
from cogs.text_service_cog import GPT3ComCon
from services.environment_service import EnvService
from models.user_model import RedoUser
from services.image_service import ImageService
from services.text_service import TextService
users_to_interactions = {}
ALLOWED_GUILDS = EnvService.get_allowed_guilds()
USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys()
USER_KEY_DB = None
if USER_INPUT_API_KEYS:
USER_KEY_DB = SqliteDict("user_key_db.sqlite")
class DrawDallEService(discord.Cog, name="DrawDallEService"):
def __init__(
self, bot, usage_service, model, message_queue, deletion_queue, converser_cog
):
super().__init__()
self.bot = bot
self.usage_service = usage_service
self.model = model
self.message_queue = message_queue
self.deletion_queue = deletion_queue
self.converser_cog = converser_cog
print("Draw service initialized")
self.redo_users = {}
async def draw_command(self, ctx: discord.ApplicationContext, prompt: str):
user_api_key = None
if USER_INPUT_API_KEYS:
user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx, USER_KEY_DB)
if not user_api_key:
return
await ctx.defer()
user = ctx.user
if user == self.bot.user:
return
try:
asyncio.ensure_future(
ImageService.encapsulated_send(
self,
user.id, prompt, ctx, custom_api_key=user_api_key
)
)
except Exception as e:
print(e)
traceback.print_exc()
await ctx.respond("Something went wrong. Please try again later.")
await ctx.send_followup(e)
async def local_size_command(self, ctx: discord.ApplicationContext):
await ctx.defer()
# Get the size of the dall-e images folder that we have on the current system.
image_path = self.model.IMAGE_SAVE_PATH
total_size = 0
for dirpath, dirnames, filenames in os.walk(image_path):
for f in filenames:
fp = os.path.join(dirpath, f)
total_size += os.path.getsize(fp)
# Format the size to be in MB and send.
total_size = total_size / 1000000
await ctx.respond(f"The size of the local images folder is {total_size} MB.")
async def clear_local_command(self, ctx):
await ctx.defer()
# Delete all the local images in the images folder.
image_path = self.model.IMAGE_SAVE_PATH
for dirpath, dirnames, filenames in os.walk(image_path):
for f in filenames:
try:
fp = os.path.join(dirpath, f)
os.remove(fp)
except Exception as e:
print(e)
await ctx.respond("Local images cleared.")

@ -4,10 +4,11 @@ import traceback
import discord
from sqlitedict import SqliteDict
from cogs.gpt_3_commands_and_converser import GPT3ComCon
from models.env_service_model import EnvService
from services.environment_service import EnvService
from models.user_model import RedoUser
from pycord.multicog import add_to_group
from services.image_service import ImageService
from services.text_service import TextService
ALLOWED_GUILDS = EnvService.get_allowed_guilds()
USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys()
@ -55,7 +56,7 @@ class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"):
async def optimize_command(self, ctx: discord.ApplicationContext, prompt: str):
user_api_key = None
if USER_INPUT_API_KEYS:
user_api_key = await GPT3ComCon.get_user_api_key(ctx.user.id, ctx)
user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx)
if not user_api_key:
return
@ -77,7 +78,7 @@ class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"):
try:
response = await self.model.send_request(
final_prompt,
tokens=70,
tokens=60,
top_p_override=1.0,
temp_override=0.9,
presence_penalty_override=0.5,
@ -217,7 +218,8 @@ class DrawButton(discord.ui.Button["OptimizeView"]):
prompt = re.sub(r"Optimized Prompt: ?", "", prompt)
# Call the image service cog to draw the image
await self.image_service_cog.encapsulated_send(
await ImageService.encapsulated_send(
self.image_service_cog,
user_id,
prompt,
interaction,
@ -255,7 +257,8 @@ class RedoButton(discord.ui.Button["OptimizeView"]):
msg = await interaction.response.send_message(
"Redoing your original request...", ephemeral=True, delete_after=20
)
await self.converser_cog.encapsulated_send(
await TextService.encapsulated_send(
self.converser_cog,
id=user_id,
prompt=prompt,
ctx=ctx,

@ -1,6 +1,5 @@
import asyncio
import datetime
import json
import re
import traceback
import sys
@ -10,22 +9,17 @@ from pathlib import Path
import aiofiles
import json
import aiohttp
import discord
from discord.ext import pages
from pycord.multicog import add_to_group
from models.deletion_service_model import Deletion
from models.env_service_model import EnvService
from models.message_model import Message
from models.moderations_service_model import Moderation
from models.openai_model import Model
from models.user_model import RedoUser, Thread, EmbeddedConversationItem
from models.check_model import Check
from models.autocomplete_model import Settings_autocompleter, File_autocompleter
from services.environment_service import EnvService
from services.message_queue_service import Message
from services.moderations_service import Moderation
from models.user_model import Thread, EmbeddedConversationItem
from collections import defaultdict
from sqlitedict import SqliteDict
from services.text_service import SetupModal, TextService
original_message = {}
ALLOWED_GUILDS = EnvService.get_allowed_guilds()
if sys.platform == "win32":
@ -169,21 +163,6 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
self.message_queue = message_queue
self.conversation_thread_owners = {}
@staticmethod
async def get_user_api_key(user_id, ctx):
user_api_key = None if user_id not in USER_KEY_DB else USER_KEY_DB[user_id]
if user_api_key is None or user_api_key == "":
modal = SetupModal(title="API Key Setup")
if isinstance(ctx, discord.ApplicationContext):
await ctx.send_modal(modal)
await ctx.send_followup(
"You must set up your API key before using this command."
)
else:
await ctx.reply(
"You must set up your API key before typing in a GPT3 powered channel, type `/setup` to enter your API key."
)
return user_api_key
async def load_file(self, file, ctx):
try:
@ -544,49 +523,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
Moderation(after, timestamp)
)
if after.author.id in self.redo_users:
if after.id == original_message[after.author.id]:
response_message = self.redo_users[after.author.id].response
ctx = self.redo_users[after.author.id].ctx
await response_message.edit(content="Redoing prompt 🔄...")
edited_content = await self.mention_to_username(after, after.content)
if after.channel.id in self.conversation_threads:
# Remove the last two elements from the history array and add the new <username>: prompt
self.conversation_threads[
after.channel.id
].history = self.conversation_threads[after.channel.id].history[:-2]
pinecone_dont_reinsert = None
if not self.pinecone_service:
self.conversation_threads[after.channel.id].history.append(
EmbeddedConversationItem(
f"\n{after.author.display_name}: {after.content}<|endofstatement|>\n",
0,
)
)
self.conversation_threads[after.channel.id].count += 1
overrides = self.conversation_threads[after.channel.id].get_overrides()
await self.encapsulated_send(
id=after.channel.id,
prompt=edited_content,
ctx=ctx,
response_message=response_message,
temp_override=overrides["temperature"],
top_p_override=overrides["top_p"],
frequency_penalty_override=overrides["frequency_penalty"],
presence_penalty_override=overrides["presence_penalty"],
model=self.conversation_threads[after.channel.id].model,
edited_request=True,
)
if not self.pinecone_service:
self.redo_users[after.author.id].prompt = edited_content
await TextService.process_conversation_edit(self, after, original_message)
async def check_and_launch_moderations(self, guild_id, alert_channel_override=None):
# Create the moderations service.
print("Checking and attempting to launch moderations service...")
@ -630,114 +567,10 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
Moderation(message, timestamp)
)
conversing = self.check_conversing(
message.author.id, message.channel.id, content
)
# If the user is conversing and they want to end it, end it immediately before we continue any further.
if conversing and message.content.lower() in self.END_PROMPTS:
await self.end_conversation(message)
return
if conversing:
user_api_key = None
if USER_INPUT_API_KEYS:
user_api_key = await GPT3ComCon.get_user_api_key(
message.author.id, message
)
if not user_api_key:
return
prompt = await self.mention_to_username(message, content)
await self.check_conversation_limit(message)
# If the user is in a conversation thread
if message.channel.id in self.conversation_threads:
# Since this is async, we don't want to allow the user to send another prompt while a conversation
# prompt is processing, that'll mess up the conversation history!
if message.author.id in self.awaiting_responses:
message = await message.reply(
"You are already waiting for a response from GPT3. Please wait for it to respond before sending another message."
)
# get the current date, add 10 seconds to it, and then turn it into a timestamp.
# we need to use our deletion service because this isn't an interaction, it's a regular message.
deletion_time = datetime.datetime.now() + datetime.timedelta(
seconds=10
)
deletion_time = deletion_time.timestamp()
deletion_message = Deletion(message, deletion_time)
await self.deletion_queue.put(deletion_message)
return
if message.channel.id in self.awaiting_thread_responses:
message = await message.reply(
"This thread is already waiting for a response from GPT3. Please wait for it to respond before sending another message."
)
# get the current date, add 10 seconds to it, and then turn it into a timestamp.
# we need to use our deletion service because this isn't an interaction, it's a regular message.
deletion_time = datetime.datetime.now() + datetime.timedelta(
seconds=10
)
deletion_time = deletion_time.timestamp()
deletion_message = Deletion(message, deletion_time)
await self.deletion_queue.put(deletion_message)
return
self.awaiting_responses.append(message.author.id)
self.awaiting_thread_responses.append(message.channel.id)
original_message[message.author.id] = message.id
if not self.pinecone_service:
self.conversation_threads[message.channel.id].history.append(
EmbeddedConversationItem(
f"\n'{message.author.display_name}': {prompt} <|endofstatement|>\n",
0,
)
)
# increment the conversation counter for the user
self.conversation_threads[message.channel.id].count += 1
# Send the request to the model
# If conversing, the prompt to send is the history, otherwise, it's just the prompt
if (
self.pinecone_service
or message.channel.id not in self.conversation_threads
):
primary_prompt = prompt
else:
primary_prompt = "".join(
[
item.text
for item in self.conversation_threads[
message.channel.id
].history
]
)
# set conversation overrides
overrides = self.conversation_threads[message.channel.id].get_overrides()
await self.encapsulated_send(
message.channel.id,
primary_prompt,
message,
temp_override=overrides["temperature"],
top_p_override=overrides["top_p"],
frequency_penalty_override=overrides["frequency_penalty"],
presence_penalty_override=overrides["presence_penalty"],
model=self.conversation_threads[message.channel.id].model,
custom_api_key=user_api_key,
)
# Process the message if the user is in a conversation
if await TextService.process_conversation_message(self, message, USER_INPUT_API_KEYS, USER_KEY_DB):
original_message[message.author.id] = message.id
def cleanse_response(self, response_text):
response_text = response_text.replace("GPTie:\n", "")
@ -767,433 +600,6 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
pass
return message
# ctx can be of type AppContext(interaction) or Message
async def encapsulated_send(
self,
id,
prompt,
ctx,
response_message=None,
temp_override=None,
top_p_override=None,
frequency_penalty_override=None,
presence_penalty_override=None,
from_ask_command=False,
instruction=None,
from_edit_command=False,
codex=False,
model=None,
custom_api_key=None,
edited_request=False,
redo_request=False,
):
new_prompt = (
prompt + "\nGPTie: "
if not from_ask_command and not from_edit_command
else prompt
)
from_context = isinstance(ctx, discord.ApplicationContext)
if not instruction:
tokens = self.usage_service.count_tokens(new_prompt)
else:
tokens = self.usage_service.count_tokens(
new_prompt
) + self.usage_service.count_tokens(instruction)
try:
# Pinecone is enabled, we will create embeddings for this conversation.
if self.pinecone_service and ctx.channel.id in self.conversation_threads:
# Delete "GPTie: <|endofstatement|>" from the user's conversation history if it exists
# check if the text attribute for any object inside self.conversation_threads[converation_id].history
# contains ""GPTie: <|endofstatement|>"", if so, delete
for item in self.conversation_threads[ctx.channel.id].history:
if item.text.strip() == "GPTie:<|endofstatement|>":
self.conversation_threads[ctx.channel.id].history.remove(item)
# The conversation_id is the id of the thread
conversation_id = ctx.channel.id
# Create an embedding and timestamp for the prompt
new_prompt = prompt.encode("ascii", "ignore").decode()
prompt_less_author = f"{new_prompt} <|endofstatement|>\n"
user_displayname = ctx.author.display_name
new_prompt = (
f"\n'{user_displayname}': {new_prompt} <|endofstatement|>\n"
)
new_prompt = new_prompt.encode("ascii", "ignore").decode()
timestamp = int(
str(datetime.datetime.now().timestamp()).replace(".", "")
)
new_prompt_item = EmbeddedConversationItem(new_prompt, timestamp)
if not redo_request:
self.conversation_threads[conversation_id].history.append(
new_prompt_item
)
if edited_request:
new_prompt = "".join(
[
item.text
for item in self.conversation_threads[
ctx.channel.id
].history
]
)
self.redo_users[ctx.author.id].prompt = new_prompt
else:
# Create and upsert the embedding for the conversation id, prompt, timestamp
await self.pinecone_service.upsert_conversation_embedding(
self.model,
conversation_id,
new_prompt,
timestamp,
custom_api_key=custom_api_key,
)
embedding_prompt_less_author = await self.model.send_embedding_request(
prompt_less_author, custom_api_key=custom_api_key
) # Use the version of the prompt without the author's name for better clarity on retrieval.
# Now, build the new prompt by getting the X most similar with pinecone
similar_prompts = self.pinecone_service.get_n_similar(
conversation_id,
embedding_prompt_less_author,
n=self.model.num_conversation_lookback,
)
# When we are in embeddings mode, only the pre-text is contained in self.conversation_threads[message.channel.id].history, so we
# can use that as a base to build our new prompt
prompt_with_history = [
self.conversation_threads[ctx.channel.id].history[0]
]
# Append the similar prompts to the prompt with history
prompt_with_history += [
EmbeddedConversationItem(prompt, timestamp)
for prompt, timestamp in similar_prompts
]
# iterate UP TO the last X prompts in the history
for i in range(
1,
min(
len(self.conversation_threads[ctx.channel.id].history),
self.model.num_static_conversation_items,
),
):
prompt_with_history.append(
self.conversation_threads[ctx.channel.id].history[-i]
)
# remove duplicates from prompt_with_history and set the conversation history
prompt_with_history = list(dict.fromkeys(prompt_with_history))
self.conversation_threads[
ctx.channel.id
].history = prompt_with_history
# Sort the prompt_with_history by increasing timestamp if pinecone is enabled
if self.pinecone_service:
prompt_with_history.sort(key=lambda x: x.timestamp)
# Ensure that the last prompt in this list is the prompt we just sent (new_prompt_item)
if prompt_with_history[-1] != new_prompt_item:
try:
prompt_with_history.remove(new_prompt_item)
except ValueError:
pass
prompt_with_history.append(new_prompt_item)
prompt_with_history = "".join(
[item.text for item in prompt_with_history]
)
new_prompt = prompt_with_history + "\nGPTie: "
tokens = self.usage_service.count_tokens(new_prompt)
# No pinecone, we do conversation summarization for long term memory instead
elif (
id in self.conversation_threads
and tokens > self.model.summarize_threshold
and not from_ask_command
and not from_edit_command
and not self.pinecone_service # This should only happen if we are not doing summarizations.
):
# We don't need to worry about the differences between interactions and messages in this block,
# because if we are in this block, we can only be using a message object for ctx
if self.model.summarize_conversations:
await ctx.reply(
"I'm currently summarizing our current conversation so we can keep chatting, "
"give me one moment!"
)
await self.summarize_conversation(ctx, new_prompt)
# Check again if the prompt is about to go past the token limit
new_prompt = (
"".join(
[
item.text
for item in self.conversation_threads[id].history
]
)
+ "\nGPTie: "
)
tokens = self.usage_service.count_tokens(new_prompt)
if (
tokens > self.model.summarize_threshold - 150
): # 150 is a buffer for the second stage
await ctx.reply(
"I tried to summarize our current conversation so we could keep chatting, "
"but it still went over the token "
"limit. Please try again later."
)
await self.end_conversation(ctx)
return
else:
await ctx.reply("The conversation context limit has been reached.")
await self.end_conversation(ctx)
return
# Send the request to the model
if from_edit_command:
response = await self.model.send_edit_request(
input=new_prompt,
instruction=instruction,
temp_override=temp_override,
top_p_override=top_p_override,
codex=codex,
custom_api_key=custom_api_key,
)
else:
response = await self.model.send_request(
new_prompt,
tokens=tokens,
temp_override=temp_override,
top_p_override=top_p_override,
frequency_penalty_override=frequency_penalty_override,
presence_penalty_override=presence_penalty_override,
model=model,
custom_api_key=custom_api_key,
)
# Clean the request response
response_text = self.cleanse_response(str(response["choices"][0]["text"]))
if from_ask_command:
# Append the prompt to the beginning of the response, in italics, then a new line
response_text = response_text.strip()
response_text = f"***{prompt}***\n\n{response_text}"
elif from_edit_command:
if codex:
response_text = response_text.strip()
response_text = f"***Prompt: {prompt}***\n***Instruction: {instruction}***\n\n```\n{response_text}\n```"
else:
response_text = response_text.strip()
response_text = f"***Prompt: {prompt}***\n***Instruction: {instruction}***\n\n{response_text}\n"
# If gpt3 tries writing a user mention try to replace it with their name
response_text = await self.mention_to_username(ctx, response_text)
# If the user is conversing, add the GPT response to their conversation history.
if (
id in self.conversation_threads
and not from_ask_command
and not self.pinecone_service
):
if not redo_request:
self.conversation_threads[id].history.append(
EmbeddedConversationItem(
"\nGPTie: " + str(response_text) + "<|endofstatement|>\n", 0
)
)
# Embeddings case!
elif (
id in self.conversation_threads
and not from_ask_command
and not from_edit_command
and self.pinecone_service
):
conversation_id = id
# Create an embedding and timestamp for the prompt
response_text = (
"\nGPTie: " + str(response_text) + "<|endofstatement|>\n"
)
response_text = response_text.encode("ascii", "ignore").decode()
# Print the current timestamp
timestamp = int(
str(datetime.datetime.now().timestamp()).replace(".", "")
)
self.conversation_threads[conversation_id].history.append(
EmbeddedConversationItem(response_text, timestamp)
)
# Create and upsert the embedding for the conversation id, prompt, timestamp
embedding = await self.pinecone_service.upsert_conversation_embedding(
self.model,
conversation_id,
response_text,
timestamp,
custom_api_key=custom_api_key,
)
# Cleanse again
response_text = self.cleanse_response(response_text)
# escape any other mentions like @here or @everyone
response_text = discord.utils.escape_mentions(response_text)
# If we don't have a response message, we are not doing a redo, send as a new message(s)
if not response_message:
if len(response_text) > self.TEXT_CUTOFF:
if not from_context:
paginator = None
await self.paginate_and_send(response_text, ctx)
else:
embed_pages = await self.paginate_embed(response_text, codex, prompt, instruction)
view=ConversationView(ctx, self, ctx.channel.id, model, from_ask_command, from_edit_command, custom_api_key=custom_api_key)
paginator = pages.Paginator(pages=embed_pages, timeout=None, custom_view=view)
response_message = await paginator.respond(ctx.interaction)
else:
paginator = None
if not from_context:
response_message = await ctx.reply(
response_text,
view=ConversationView(
ctx,
self,
ctx.channel.id,
model,
custom_api_key=custom_api_key,
),
)
elif from_edit_command:
response_message = await ctx.respond(
response_text,
view=ConversationView(
ctx,
self,
ctx.channel.id,
model,
from_edit_command=from_edit_command,
custom_api_key=custom_api_key
),
)
else:
response_message = await ctx.respond(
response_text,
view=ConversationView(
ctx,
self,
ctx.channel.id,
model,
from_ask_command=from_ask_command,
custom_api_key=custom_api_key
),
)
if response_message:
# Get the actual message object of response_message in case it's an WebhookMessage
actual_response_message = (
response_message
if not from_context
else await ctx.fetch_message(response_message.id)
)
self.redo_users[ctx.author.id] = RedoUser(
prompt=new_prompt,
instruction=instruction,
ctx=ctx,
message=ctx,
response=actual_response_message,
codex=codex,
paginator=paginator
)
self.redo_users[ctx.author.id].add_interaction(
actual_response_message.id
)
# We are doing a redo, edit the message.
else:
paginator = self.redo_users.get(ctx.author.id).paginator
if isinstance(paginator, pages.Paginator):
embed_pages = await self.paginate_embed(response_text, codex, prompt, instruction)
view=ConversationView(ctx, self, ctx.channel.id, model, from_ask_command, from_edit_command, custom_api_key=custom_api_key)
await paginator.update(pages=embed_pages, custom_view=view)
elif len(response_text) > self.TEXT_CUTOFF:
if not from_context:
await response_message.channel.send("Over 2000 characters", delete_after=5)
else:
await response_message.edit(content=response_text)
await self.send_debug_message(
self.generate_debug_message(prompt, response), self.debug_channel
)
if ctx.author.id in self.awaiting_responses:
self.awaiting_responses.remove(ctx.author.id)
if not from_ask_command and not from_edit_command:
if ctx.channel.id in self.awaiting_thread_responses:
self.awaiting_thread_responses.remove(ctx.channel.id)
# Error catching for AIOHTTP Errors
except aiohttp.ClientResponseError as e:
message = (
f"The API returned an invalid response: **{e.status}: {e.message}**"
)
if from_context:
await ctx.send_followup(message)
else:
await ctx.reply(message)
self.remove_awaiting(
ctx.author.id, ctx.channel.id, from_ask_command, from_edit_command
)
# Error catching for OpenAI model value errors
except ValueError as e:
if from_context:
await ctx.send_followup(e)
else:
await ctx.reply(e)
self.remove_awaiting(
ctx.author.id, ctx.channel.id, from_ask_command, from_edit_command
)
# General catch case for everything
except Exception:
message = "Something went wrong, please try again later. This may be due to upstream issues on the API, or rate limiting."
await ctx.send_followup(message) if from_context else await ctx.reply(
message
)
self.remove_awaiting(
ctx.author.id, ctx.channel.id, from_ask_command, from_edit_command
)
traceback.print_exc()
try:
await self.end_conversation(ctx)
except:
pass
return
# COMMANDS
async def help_command(self, ctx):
@ -1309,13 +715,14 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
user_api_key = None
if USER_INPUT_API_KEYS:
user_api_key = await GPT3ComCon.get_user_api_key(user.id, ctx)
user_api_key = await TextService.get_user_api_key(user.id, ctx, USER_KEY_DB)
if not user_api_key:
return
await ctx.defer()
await self.encapsulated_send(
await TextService.encapsulated_send(
self,
user.id,
prompt,
ctx,
@ -1349,7 +756,8 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
await ctx.defer()
await self.encapsulated_send(
await TextService.encapsulated_send(
self,
user.id,
prompt=input,
ctx=ctx,
@ -1511,7 +919,8 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
self.conversation_threads[thread.id].count += 1
await self.encapsulated_send(
await TextService.encapsulated_send(
self,
thread.id,
opener
if thread.id not in self.conversation_threads or self.pinecone_service
@ -1650,198 +1059,3 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
"alert_channel": self.get_moderated_alert_channel(guild_id),
}
MOD_DB.commit()
# VIEWS AND MODALS
class ConversationView(discord.ui.View):
def __init__(
self,
ctx,
converser_cog,
id,
model,
from_ask_command=False,
from_edit_command=False,
custom_api_key=None,
):
super().__init__(timeout=3600) # 1 hour interval to redo.
self.converser_cog = converser_cog
self.ctx = ctx
self.model = model
self.from_ask_command = from_ask_command
self.from_edit_command = from_edit_command
self.custom_api_key = custom_api_key
self.add_item(
RedoButton(
self.converser_cog,
model=model,
from_ask_command=from_ask_command,
from_edit_command=from_edit_command,
custom_api_key=self.custom_api_key,
)
)
if id in self.converser_cog.conversation_threads:
self.add_item(EndConvoButton(self.converser_cog))
async def on_timeout(self):
# Remove the button from the view/message
self.clear_items()
# Send a message to the user saying the view has timed out
if self.message:
await self.message.edit(
view=None,
)
else:
await self.ctx.edit(
view=None,
)
class EndConvoButton(discord.ui.Button["ConversationView"]):
def __init__(self, converser_cog):
super().__init__(style=discord.ButtonStyle.danger, label="End Conversation", custom_id="conversation_end")
self.converser_cog = converser_cog
async def callback(self, interaction: discord.Interaction):
# Get the user
user_id = interaction.user.id
if (
user_id in self.converser_cog.conversation_thread_owners
and self.converser_cog.conversation_thread_owners[user_id]
== interaction.channel.id
):
try:
await self.converser_cog.end_conversation(
interaction, opener_user_id=interaction.user.id
)
except Exception as e:
print(e)
traceback.print_exc()
await interaction.response.send_message(
e, ephemeral=True, delete_after=30
)
pass
else:
await interaction.response.send_message(
"This is not your conversation to end!", ephemeral=True, delete_after=10
)
class RedoButton(discord.ui.Button["ConversationView"]):
def __init__(self, converser_cog, model, from_ask_command, from_edit_command, custom_api_key):
super().__init__(style=discord.ButtonStyle.danger, label="Retry", custom_id="conversation_redo")
self.converser_cog = converser_cog
self.model = model
self.from_ask_command = from_ask_command
self.from_edit_command = from_edit_command
self.custom_api_key = custom_api_key
async def callback(self, interaction: discord.Interaction):
# Get the user
user_id = interaction.user.id
if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[
user_id
].in_interaction(interaction.message.id):
# Get the message and the prompt and call encapsulated_send
prompt = self.converser_cog.redo_users[user_id].prompt
instruction = self.converser_cog.redo_users[user_id].instruction
ctx = self.converser_cog.redo_users[user_id].ctx
response_message = self.converser_cog.redo_users[user_id].response
codex = self.converser_cog.redo_users[user_id].codex
msg = await interaction.response.send_message(
"Retrying your original request...", ephemeral=True, delete_after=15
)
await self.converser_cog.encapsulated_send(
id=user_id,
prompt=prompt,
instruction=instruction,
ctx=ctx,
model=self.model,
response_message=response_message,
codex=codex,
custom_api_key=self.custom_api_key,
redo_request=True,
from_ask_command=self.from_ask_command,
from_edit_command=self.from_edit_command,
)
else:
await interaction.response.send_message(
"You can only redo the most recent prompt that you sent yourself.",
ephemeral=True,
delete_after=10,
)
class SetupModal(discord.ui.Modal):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.add_item(
discord.ui.InputText(
label="OpenAI API Key",
placeholder="sk--......",
)
)
async def callback(self, interaction: discord.Interaction):
user = interaction.user
api_key = self.children[0].value
# Validate that api_key is indeed in this format
if not re.match(r"sk-[a-zA-Z0-9]{32}", api_key):
await interaction.response.send_message(
"Your API key looks invalid, please check that it is correct before proceeding. Please run the /setup command to set your key.",
ephemeral=True,
delete_after=100,
)
else:
# We can save the key for the user to the database.
# Make a test request using the api key to ensure that it is valid.
try:
await Model.send_test_request(api_key)
await interaction.response.send_message(
"Your API key was successfully validated.",
ephemeral=True,
delete_after=10,
)
except aiohttp.ClientResponseError as e:
await interaction.response.send_message(
f"The API returned an invalid response: **{e.status}: {e.message}**",
ephemeral=True,
delete_after=30,
)
return
except Exception as e:
await interaction.response.send_message(
f"Your API key looks invalid, the API returned: {e}. Please check that your API key is correct before proceeding",
ephemeral=True,
delete_after=30,
)
return
# Save the key to the database
try:
USER_KEY_DB[user.id] = api_key
USER_KEY_DB.commit()
await interaction.followup.send(
"Your API key was successfully saved.",
ephemeral=True,
delete_after=10,
)
except Exception as e:
traceback.print_exc()
await interaction.followup.send(
"There was an error saving your API key.",
ephemeral=True,
delete_after=30,
)
return
pass

@ -8,22 +8,22 @@ import pinecone
from pycord.multicog import apply_multicog
import os
from models.pinecone_service_model import PineconeService
from services.pinecone_service import PineconeService
if sys.platform == "win32":
separator = "\\"
else:
separator = "/"
from cogs.draw_image_generation import DrawDallEService
from cogs.gpt_3_commands_and_converser import GPT3ComCon
from cogs.image_prompt_optimizer import ImgPromptOptimizer
from cogs.image_service_cog import DrawDallEService
from cogs.text_service_cog import GPT3ComCon
from cogs.prompt_optimizer_cog import ImgPromptOptimizer
from cogs.commands import Commands
from models.deletion_service_model import Deletion
from models.message_model import Message
from services.deletion_service import Deletion
from services.message_queue_service import Message
from models.openai_model import Model
from models.usage_service_model import UsageService
from models.env_service_model import EnvService
from services.usage_service import UsageService
from services.environment_service import EnvService
__version__ = "6.0"

@ -3,9 +3,9 @@ import os
import re
import discord
from models.usage_service_model import UsageService
from services.usage_service import UsageService
from models.openai_model import Model
from models.env_service_model import EnvService
from services.environment_service import EnvService
usage_service = UsageService(Path(os.environ.get("DATA_DIR", os.getcwd())))
model = Model(usage_service)

@ -1,6 +1,6 @@
import discord
from models.env_service_model import EnvService
from services.environment_service import EnvService
from typing import Callable
ADMIN_ROLES = EnvService.get_admin_roles()

@ -1,452 +1,373 @@
import asyncio
import os
import tempfile
import traceback
from io import BytesIO
import aiohttp
import discord
from PIL import Image
from pycord.multicog import add_to_group
# We don't use the converser cog here because we want to be able to redo for the last images and text prompts at the same time
from sqlitedict import SqliteDict
from cogs.gpt_3_commands_and_converser import GPT3ComCon
from models.env_service_model import EnvService
from models.user_model import RedoUser
redo_users = {}
users_to_interactions = {}
ALLOWED_GUILDS = EnvService.get_allowed_guilds()
USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys()
USER_KEY_DB = None
if USER_INPUT_API_KEYS:
USER_KEY_DB = SqliteDict("user_key_db.sqlite")
class DrawDallEService(discord.Cog, name="DrawDallEService"):
def __init__(
self, bot, usage_service, model, message_queue, deletion_queue, converser_cog
):
super().__init__()
self.bot = bot
self.usage_service = usage_service
self.model = model
self.message_queue = message_queue
self.deletion_queue = deletion_queue
self.converser_cog = converser_cog
print("Draw service initialized")
async def encapsulated_send(
self,
user_id,
prompt,
ctx,
response_message=None,
vary=None,
draw_from_optimizer=None,
custom_api_key=None,
):
await asyncio.sleep(0)
# send the prompt to the model
from_context = isinstance(ctx, discord.ApplicationContext)
try:
file, image_urls = await self.model.send_image_request(
ctx,
prompt,
vary=vary if not draw_from_optimizer else None,
custom_api_key=custom_api_key,
)
# Error catching for API errors
except aiohttp.ClientResponseError as e:
message = (
f"The API returned an invalid response: **{e.status}: {e.message}**"
)
await ctx.channel.send(message) if not from_context else await ctx.respond(
message
)
return
except ValueError as e:
message = f"Error: {e}. Please try again with a different prompt."
await ctx.channel.send(message) if not from_context else await ctx.respond(
message
)
return
# Start building an embed to send to the user with the results of the image generation
embed = discord.Embed(
title="Image Generation Results"
if not vary
else "Image Generation Results (Varying)"
if not draw_from_optimizer
else "Image Generation Results (Drawing from Optimizer)",
description=f"{prompt}",
color=0xC730C7,
)
# Add the image file to the embed
embed.set_image(url=f"attachment://{file.filename}")
if not response_message: # Original generation case
# Start an interaction with the user, we also want to send data embed=embed, file=file, view=SaveView(image_urls, self, self.converser_cog)
result_message = (
await ctx.channel.send(
embed=embed,
file=file,
)
if not from_context
else await ctx.respond(embed=embed, file=file)
)
await result_message.edit(
view=SaveView(
ctx,
image_urls,
self,
self.converser_cog,
result_message,
custom_api_key=custom_api_key,
)
)
self.converser_cog.users_to_interactions[user_id] = []
self.converser_cog.users_to_interactions[user_id].append(result_message.id)
# Get the actual result message object
if from_context:
result_message = await ctx.fetch_message(result_message.id)
redo_users[user_id] = RedoUser(
prompt=prompt,
message=ctx,
ctx=ctx,
response=response_message,
instruction=None,
codex=False,
paginator=None
)
else:
if not vary: # Editing case
message = await response_message.edit(
embed=embed,
file=file,
)
await message.edit(
view=SaveView(
ctx,
image_urls,
self,
self.converser_cog,
message,
custom_api_key=custom_api_key,
)
)
else: # Varying case
if not draw_from_optimizer:
result_message = await response_message.edit_original_response(
content="Image variation completed!",
embed=embed,
file=file,
)
await result_message.edit(
view=SaveView(
ctx,
image_urls,
self,
self.converser_cog,
result_message,
True,
custom_api_key=custom_api_key,
)
)
else:
result_message = await response_message.edit_original_response(
content="I've drawn the optimized prompt!",
embed=embed,
file=file,
)
await result_message.edit(
view=SaveView(
ctx,
image_urls,
self,
self.converser_cog,
result_message,
custom_api_key=custom_api_key,
)
)
redo_users[user_id] = RedoUser(
prompt=prompt,
message=ctx,
ctx=ctx,
response=response_message,
instruction=None,
codex=False,
paginator=None,
)
self.converser_cog.users_to_interactions[user_id].append(
response_message.id
)
self.converser_cog.users_to_interactions[user_id].append(
result_message.id
)
async def draw_command(self, ctx: discord.ApplicationContext, prompt: str):
user_api_key = None
if USER_INPUT_API_KEYS:
user_api_key = await GPT3ComCon.get_user_api_key(ctx.user.id, ctx)
if not user_api_key:
return
await ctx.defer()
user = ctx.user
if user == self.bot.user:
return
try:
asyncio.ensure_future(
self.encapsulated_send(
user.id, prompt, ctx, custom_api_key=user_api_key
)
)
except Exception as e:
print(e)
traceback.print_exc()
await ctx.respond("Something went wrong. Please try again later.")
await ctx.send_followup(e)
async def local_size_command(self, ctx: discord.ApplicationContext):
await ctx.defer()
# Get the size of the dall-e images folder that we have on the current system.
image_path = self.model.IMAGE_SAVE_PATH
total_size = 0
for dirpath, dirnames, filenames in os.walk(image_path):
for f in filenames:
fp = os.path.join(dirpath, f)
total_size += os.path.getsize(fp)
# Format the size to be in MB and send.
total_size = total_size / 1000000
await ctx.respond(f"The size of the local images folder is {total_size} MB.")
async def clear_local_command(self, ctx):
await ctx.defer()
# Delete all the local images in the images folder.
image_path = self.model.IMAGE_SAVE_PATH
for dirpath, dirnames, filenames in os.walk(image_path):
for f in filenames:
try:
fp = os.path.join(dirpath, f)
os.remove(fp)
except Exception as e:
print(e)
await ctx.respond("Local images cleared.")
class SaveView(discord.ui.View):
def __init__(
self,
ctx,
image_urls,
cog,
converser_cog,
message,
no_retry=False,
only_save=None,
custom_api_key=None,
):
super().__init__(
timeout=3600 if not only_save else None
) # 1 hour timeout for Retry, Save
self.ctx = ctx
self.image_urls = image_urls
self.cog = cog
self.no_retry = no_retry
self.converser_cog = converser_cog
self.message = message
self.custom_api_key = custom_api_key
for x in range(1, len(image_urls) + 1):
self.add_item(SaveButton(x, image_urls[x - 1]))
if not only_save:
if not no_retry:
self.add_item(
RedoButton(
self.cog,
converser_cog=self.converser_cog,
custom_api_key=self.custom_api_key,
)
)
for x in range(1, len(image_urls) + 1):
self.add_item(
VaryButton(
x,
image_urls[x - 1],
self.cog,
converser_cog=self.converser_cog,
custom_api_key=self.custom_api_key,
)
)
# On the timeout event, override it and we want to clear the items.
async def on_timeout(self):
# Save all the SaveButton items, then clear all the items, then add back the SaveButton items, then
# update the message
self.clear_items()
# Create a new view with the same params as this one, but pass only_save=True
new_view = SaveView(
self.ctx,
self.image_urls,
self.cog,
self.converser_cog,
self.message,
self.no_retry,
only_save=True,
)
# Set the view of the message to the new view
await self.ctx.edit(view=new_view)
class VaryButton(discord.ui.Button):
def __init__(self, number, image_url, cog, converser_cog, custom_api_key):
super().__init__(style=discord.ButtonStyle.blurple, label="Vary " + str(number))
self.number = number
self.image_url = image_url
self.cog = cog
self.converser_cog = converser_cog
self.custom_api_key = custom_api_key
async def callback(self, interaction: discord.Interaction):
user_id = interaction.user.id
interaction_id = interaction.message.id
if interaction_id not in self.converser_cog.users_to_interactions[user_id]:
if len(self.converser_cog.users_to_interactions[user_id]) >= 2:
interaction_id2 = interaction.id
if (
interaction_id2
not in self.converser_cog.users_to_interactions[user_id]
):
await interaction.response.send_message(
content="You can not vary images in someone else's chain!",
ephemeral=True,
)
else:
await interaction.response.send_message(
content="You can only vary for images that you generated yourself!",
ephemeral=True,
)
return
if user_id in redo_users:
response_message = await interaction.response.send_message(
content="Varying image number " + str(self.number) + "..."
)
self.converser_cog.users_to_interactions[user_id].append(
response_message.message.id
)
self.converser_cog.users_to_interactions[user_id].append(
response_message.id
)
prompt = redo_users[user_id].prompt
asyncio.ensure_future(
self.cog.encapsulated_send(
user_id,
prompt,
interaction.message,
response_message=response_message,
vary=self.image_url,
custom_api_key=self.custom_api_key,
)
)
class SaveButton(discord.ui.Button["SaveView"]):
def __init__(self, number: int, image_url: str):
super().__init__(style=discord.ButtonStyle.gray, label="Save " + str(number))
self.number = number
self.image_url = image_url
async def callback(self, interaction: discord.Interaction):
# If the image url doesn't start with "http", then we need to read the file from the URI, and then send the
# file to the user as an attachment.
try:
if not self.image_url.startswith("http"):
with open(self.image_url, "rb") as f:
image = Image.open(BytesIO(f.read()))
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
image.save(temp_file.name)
await interaction.response.send_message(
content="Here is your image for download (open original and save)",
file=discord.File(temp_file.name),
ephemeral=True,
)
else:
await interaction.response.send_message(
f"You can directly download this image from {self.image_url}",
ephemeral=True,
)
except Exception as e:
await interaction.response.send_message(f"Error: {e}", ephemeral=True)
traceback.print_exc()
class RedoButton(discord.ui.Button["SaveView"]):
def __init__(self, cog, converser_cog, custom_api_key):
super().__init__(style=discord.ButtonStyle.danger, label="Retry")
self.cog = cog
self.converser_cog = converser_cog
self.custom_api_key = custom_api_key
async def callback(self, interaction: discord.Interaction):
user_id = interaction.user.id
interaction_id = interaction.message.id
if interaction_id not in self.converser_cog.users_to_interactions[user_id]:
await interaction.response.send_message(
content="You can only retry for prompts that you generated yourself!",
ephemeral=True,
)
return
# We have passed the intial check of if the interaction belongs to the user
if user_id in redo_users:
# Get the message and the prompt and call encapsulated_send
ctx = redo_users[user_id].ctx
prompt = redo_users[user_id].prompt
response_message = redo_users[user_id].response
message = await interaction.response.send_message(
f"Regenerating the image for your original prompt, check the original message.",
ephemeral=True,
)
self.converser_cog.users_to_interactions[user_id].append(message.id)
asyncio.ensure_future(
self.cog.encapsulated_send(
user_id,
prompt,
ctx,
response_message,
custom_api_key=self.custom_api_key,
)
)
import asyncio
import tempfile
import traceback
from io import BytesIO
import aiohttp
import discord
from PIL import Image
from models.user_model import RedoUser
class ImageService:
def __init__(self):
pass
@staticmethod
async def encapsulated_send(
image_service_cog,
user_id,
prompt,
ctx,
response_message=None,
vary=None,
draw_from_optimizer=None,
custom_api_key=None,
):
await asyncio.sleep(0)
# send the prompt to the model
from_context = isinstance(ctx, discord.ApplicationContext)
try:
file, image_urls = await image_service_cog.model.send_image_request(
ctx,
prompt,
vary=vary if not draw_from_optimizer else None,
custom_api_key=custom_api_key,
)
# Error catching for API errors
except aiohttp.ClientResponseError as e:
message = (
f"The API returned an invalid response: **{e.status}: {e.message}**"
)
await ctx.channel.send(message) if not from_context else await ctx.respond(
message
)
return
except ValueError as e:
message = f"Error: {e}. Please try again with a different prompt."
await ctx.channel.send(message) if not from_context else await ctx.respond(
message
)
return
# Start building an embed to send to the user with the results of the image generation
embed = discord.Embed(
title="Image Generation Results"
if not vary
else "Image Generation Results (Varying)"
if not draw_from_optimizer
else "Image Generation Results (Drawing from Optimizer)",
description=f"{prompt}",
color=0xC730C7,
)
# Add the image file to the embed
embed.set_image(url=f"attachment://{file.filename}")
if not response_message: # Original generation case
# Start an interaction with the user, we also want to send data embed=embed, file=file, view=SaveView(image_urls, image_service_cog, image_service_cog.converser_cog)
result_message = (
await ctx.channel.send(
embed=embed,
file=file,
)
if not from_context
else await ctx.respond(embed=embed, file=file)
)
await result_message.edit(
view=SaveView(
ctx,
image_urls,
image_service_cog,
image_service_cog.converser_cog,
result_message,
custom_api_key=custom_api_key,
)
)
image_service_cog.converser_cog.users_to_interactions[user_id] = []
image_service_cog.converser_cog.users_to_interactions[user_id].append(result_message.id)
# Get the actual result message object
if from_context:
result_message = await ctx.fetch_message(result_message.id)
image_service_cog.redo_users[user_id] = RedoUser(
prompt=prompt,
message=ctx,
ctx=ctx,
response=response_message,
instruction=None,
codex=False,
paginator=None
)
else:
if not vary: # Editing case
message = await response_message.edit(
embed=embed,
file=file,
)
await message.edit(
view=SaveView(
ctx,
image_urls,
image_service_cog,
image_service_cog.converser_cog,
message,
custom_api_key=custom_api_key,
)
)
else: # Varying case
if not draw_from_optimizer:
result_message = await response_message.edit_original_response(
content="Image variation completed!",
embed=embed,
file=file,
)
await result_message.edit(
view=SaveView(
ctx,
image_urls,
image_service_cog,
image_service_cog.converser_cog,
result_message,
True,
custom_api_key=custom_api_key,
)
)
else:
result_message = await response_message.edit_original_response(
content="I've drawn the optimized prompt!",
embed=embed,
file=file,
)
await result_message.edit(
view=SaveView(
ctx,
image_urls,
image_service_cog,
image_service_cog.converser_cog,
result_message,
custom_api_key=custom_api_key,
)
)
image_service_cog.redo_users[user_id] = RedoUser(
prompt=prompt,
message=ctx,
ctx=ctx,
response=response_message,
instruction=None,
codex=False,
paginator=None,
)
image_service_cog.converser_cog.users_to_interactions[user_id].append(
response_message.id
)
image_service_cog.converser_cog.users_to_interactions[user_id].append(
result_message.id
)
class SaveView(discord.ui.View):
def __init__(
self,
ctx,
image_urls,
cog,
converser_cog,
message,
no_retry=False,
only_save=None,
custom_api_key=None,
):
super().__init__(
timeout=3600 if not only_save else None
) # 1 hour timeout for Retry, Save
self.ctx = ctx
self.image_urls = image_urls
self.cog = cog
self.no_retry = no_retry
self.converser_cog = converser_cog
self.message = message
self.custom_api_key = custom_api_key
for x in range(1, len(image_urls) + 1):
self.add_item(SaveButton(x, image_urls[x - 1]))
if not only_save:
if not no_retry:
self.add_item(
RedoButton(
self.cog,
converser_cog=self.converser_cog,
custom_api_key=self.custom_api_key,
)
)
for x in range(1, len(image_urls) + 1):
self.add_item(
VaryButton(
x,
image_urls[x - 1],
self.cog,
converser_cog=self.converser_cog,
custom_api_key=self.custom_api_key,
)
)
# On the timeout event, override it and we want to clear the items.
async def on_timeout(self):
# Save all the SaveButton items, then clear all the items, then add back the SaveButton items, then
# update the message
self.clear_items()
# Create a new view with the same params as this one, but pass only_save=True
new_view = SaveView(
self.ctx,
self.image_urls,
self.cog,
self.converser_cog,
self.message,
self.no_retry,
only_save=True,
)
# Set the view of the message to the new view
await self.ctx.edit(view=new_view)
class VaryButton(discord.ui.Button):
def __init__(self, number, image_url, cog, converser_cog, custom_api_key):
super().__init__(style=discord.ButtonStyle.blurple, label="Vary " + str(number))
self.number = number
self.image_url = image_url
self.cog = cog
self.converser_cog = converser_cog
self.custom_api_key = custom_api_key
async def callback(self, interaction: discord.Interaction):
user_id = interaction.user.id
interaction_id = interaction.message.id
if interaction_id not in self.converser_cog.users_to_interactions[user_id]:
if len(self.converser_cog.users_to_interactions[user_id]) >= 2:
interaction_id2 = interaction.id
if (
interaction_id2
not in self.converser_cog.users_to_interactions[user_id]
):
await interaction.response.send_message(
content="You can not vary images in someone else's chain!",
ephemeral=True,
)
else:
await interaction.response.send_message(
content="You can only vary for images that you generated yourself!",
ephemeral=True,
)
return
if user_id in self.cog.redo_users:
response_message = await interaction.response.send_message(
content="Varying image number " + str(self.number) + "..."
)
self.converser_cog.users_to_interactions[user_id].append(
response_message.message.id
)
self.converser_cog.users_to_interactions[user_id].append(
response_message.id
)
prompt = self.cog.redo_users[user_id].prompt
asyncio.ensure_future(
ImageService.encapsulated_send(
self.cog,
user_id,
prompt,
interaction.message,
response_message=response_message,
vary=self.image_url,
custom_api_key=self.custom_api_key,
)
)
class SaveButton(discord.ui.Button["SaveView"]):
def __init__(self, number: int, image_url: str):
super().__init__(style=discord.ButtonStyle.gray, label="Save " + str(number))
self.number = number
self.image_url = image_url
async def callback(self, interaction: discord.Interaction):
# If the image url doesn't start with "http", then we need to read the file from the URI, and then send the
# file to the user as an attachment.
try:
if not self.image_url.startswith("http"):
with open(self.image_url, "rb") as f:
image = Image.open(BytesIO(f.read()))
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
image.save(temp_file.name)
await interaction.response.send_message(
content="Here is your image for download (open original and save)",
file=discord.File(temp_file.name),
ephemeral=True,
)
else:
await interaction.response.send_message(
f"You can directly download this image from {self.image_url}",
ephemeral=True,
)
except Exception as e:
await interaction.response.send_message(f"Error: {e}", ephemeral=True)
traceback.print_exc()
class RedoButton(discord.ui.Button["SaveView"]):
def __init__(self, cog, converser_cog, custom_api_key):
super().__init__(style=discord.ButtonStyle.danger, label="Retry")
self.cog = cog
self.converser_cog = converser_cog
self.custom_api_key = custom_api_key
async def callback(self, interaction: discord.Interaction):
user_id = interaction.user.id
interaction_id = interaction.message.id
if interaction_id not in self.converser_cog.users_to_interactions[user_id]:
await interaction.response.send_message(
content="You can only retry for prompts that you generated yourself!",
ephemeral=True,
)
return
# We have passed the intial check of if the interaction belongs to the user
if user_id in self.cog.redo_users:
# Get the message and the prompt and call encapsulated_send
ctx = self.cog.redo_users[user_id].ctx
prompt = self.cog.redo_users[user_id].prompt
response_message = self.cog.redo_users[user_id].response
message = await interaction.response.send_message(
f"Regenerating the image for your original prompt, check the original message.",
ephemeral=True,
)
self.converser_cog.users_to_interactions[user_id].append(message.id)
asyncio.ensure_future(
ImageService.encapsulated_send(
self.cog,
user_id,
prompt,
ctx,
response_message,
custom_api_key=self.custom_api_key,
)
)

@ -7,7 +7,7 @@ from pathlib import Path
import discord
from models.openai_model import Model
from models.usage_service_model import UsageService
from services.usage_service import UsageService
usage_service = UsageService(Path(os.environ.get("DATA_DIR", os.getcwd())))
model = Model(usage_service)

@ -0,0 +1,820 @@
import datetime
import re
import traceback
import aiohttp
import discord
from discord.ext import pages
from services.deletion_service import Deletion
from models.openai_model import Model
from models.user_model import EmbeddedConversationItem, RedoUser
class TextService:
def __init__(self):
pass
@staticmethod
async def encapsulated_send(
converser_cog,
id,
prompt,
ctx,
response_message=None,
temp_override=None,
top_p_override=None,
frequency_penalty_override=None,
presence_penalty_override=None,
from_ask_command=False,
instruction=None,
from_edit_command=False,
codex=False,
model=None,
custom_api_key=None,
edited_request=False,
redo_request=False,
):
new_prompt = (
prompt + "\nGPTie: "
if not from_ask_command and not from_edit_command
else prompt
)
from_context = isinstance(ctx, discord.ApplicationContext)
if not instruction:
tokens = converser_cog.usage_service.count_tokens(new_prompt)
else:
tokens = converser_cog.usage_service.count_tokens(
new_prompt
) + converser_cog.usage_service.count_tokens(instruction)
try:
# Pinecone is enabled, we will create embeddings for this conversation.
if converser_cog.pinecone_service and ctx.channel.id in converser_cog.conversation_threads:
# Delete "GPTie: <|endofstatement|>" from the user's conversation history if it exists
# check if the text attribute for any object inside converser_cog.conversation_threads[converation_id].history
# contains ""GPTie: <|endofstatement|>"", if so, delete
for item in converser_cog.conversation_threads[ctx.channel.id].history:
if item.text.strip() == "GPTie:<|endofstatement|>":
converser_cog.conversation_threads[ctx.channel.id].history.remove(item)
# The conversation_id is the id of the thread
conversation_id = ctx.channel.id
# Create an embedding and timestamp for the prompt
new_prompt = prompt.encode("ascii", "ignore").decode()
prompt_less_author = f"{new_prompt} <|endofstatement|>\n"
user_displayname = ctx.author.display_name
new_prompt = (
f"\n'{user_displayname}': {new_prompt} <|endofstatement|>\n"
)
new_prompt = new_prompt.encode("ascii", "ignore").decode()
timestamp = int(
str(datetime.datetime.now().timestamp()).replace(".", "")
)
new_prompt_item = EmbeddedConversationItem(new_prompt, timestamp)
if not redo_request:
converser_cog.conversation_threads[conversation_id].history.append(
new_prompt_item
)
if edited_request:
new_prompt = "".join(
[
item.text
for item in converser_cog.conversation_threads[
ctx.channel.id
].history
]
)
converser_cog.redo_users[ctx.author.id].prompt = new_prompt
else:
# Create and upsert the embedding for the conversation id, prompt, timestamp
await converser_cog.pinecone_service.upsert_conversation_embedding(
converser_cog.model,
conversation_id,
new_prompt,
timestamp,
custom_api_key=custom_api_key,
)
embedding_prompt_less_author = await converser_cog.model.send_embedding_request(
prompt_less_author, custom_api_key=custom_api_key
) # Use the version of the prompt without the author's name for better clarity on retrieval.
# Now, build the new prompt by getting the X most similar with pinecone
similar_prompts = converser_cog.pinecone_service.get_n_similar(
conversation_id,
embedding_prompt_less_author,
n=converser_cog.model.num_conversation_lookback,
)
# When we are in embeddings mode, only the pre-text is contained in converser_cog.conversation_threads[message.channel.id].history, so we
# can use that as a base to build our new prompt
prompt_with_history = [
converser_cog.conversation_threads[ctx.channel.id].history[0]
]
# Append the similar prompts to the prompt with history
prompt_with_history += [
EmbeddedConversationItem(prompt, timestamp)
for prompt, timestamp in similar_prompts
]
# iterate UP TO the last X prompts in the history
for i in range(
1,
min(
len(converser_cog.conversation_threads[ctx.channel.id].history),
converser_cog.model.num_static_conversation_items,
),
):
prompt_with_history.append(
converser_cog.conversation_threads[ctx.channel.id].history[-i]
)
# remove duplicates from prompt_with_history and set the conversation history
prompt_with_history = list(dict.fromkeys(prompt_with_history))
converser_cog.conversation_threads[
ctx.channel.id
].history = prompt_with_history
# Sort the prompt_with_history by increasing timestamp if pinecone is enabled
if converser_cog.pinecone_service:
prompt_with_history.sort(key=lambda x: x.timestamp)
# Ensure that the last prompt in this list is the prompt we just sent (new_prompt_item)
if prompt_with_history[-1] != new_prompt_item:
try:
prompt_with_history.remove(new_prompt_item)
except ValueError:
pass
prompt_with_history.append(new_prompt_item)
prompt_with_history = "".join(
[item.text for item in prompt_with_history]
)
new_prompt = prompt_with_history + "\nGPTie: "
tokens = converser_cog.usage_service.count_tokens(new_prompt)
# No pinecone, we do conversation summarization for long term memory instead
elif (
id in converser_cog.conversation_threads
and tokens > converser_cog.model.summarize_threshold
and not from_ask_command
and not from_edit_command
and not converser_cog.pinecone_service # This should only happen if we are not doing summarizations.
):
# We don't need to worry about the differences between interactions and messages in this block,
# because if we are in this block, we can only be using a message object for ctx
if converser_cog.model.summarize_conversations:
await ctx.reply(
"I'm currently summarizing our current conversation so we can keep chatting, "
"give me one moment!"
)
await converser_cog.summarize_conversation(ctx, new_prompt)
# Check again if the prompt is about to go past the token limit
new_prompt = (
"".join(
[
item.text
for item in converser_cog.conversation_threads[id].history
]
)
+ "\nGPTie: "
)
tokens = converser_cog.usage_service.count_tokens(new_prompt)
if (
tokens > converser_cog.model.summarize_threshold - 150
): # 150 is a buffer for the second stage
await ctx.reply(
"I tried to summarize our current conversation so we could keep chatting, "
"but it still went over the token "
"limit. Please try again later."
)
await converser_cog.end_conversation(ctx)
return
else:
await ctx.reply("The conversation context limit has been reached.")
await converser_cog.end_conversation(ctx)
return
# Send the request to the model
if from_edit_command:
response = await converser_cog.model.send_edit_request(
input=new_prompt,
instruction=instruction,
temp_override=temp_override,
top_p_override=top_p_override,
codex=codex,
custom_api_key=custom_api_key,
)
else:
response = await converser_cog.model.send_request(
new_prompt,
tokens=tokens,
temp_override=temp_override,
top_p_override=top_p_override,
frequency_penalty_override=frequency_penalty_override,
presence_penalty_override=presence_penalty_override,
model=model,
custom_api_key=custom_api_key,
)
# Clean the request response
response_text = converser_cog.cleanse_response(str(response["choices"][0]["text"]))
if from_ask_command:
# Append the prompt to the beginning of the response, in italics, then a new line
response_text = response_text.strip()
response_text = f"***{prompt}***\n\n{response_text}"
elif from_edit_command:
if codex:
response_text = response_text.strip()
response_text = f"***Prompt: {prompt}***\n***Instruction: {instruction}***\n\n```\n{response_text}\n```"
else:
response_text = response_text.strip()
response_text = f"***Prompt: {prompt}***\n***Instruction: {instruction}***\n\n{response_text}\n"
# If gpt3 tries writing a user mention try to replace it with their name
response_text = await converser_cog.mention_to_username(ctx, response_text)
# If the user is conversing, add the GPT response to their conversation history.
if (
id in converser_cog.conversation_threads
and not from_ask_command
and not converser_cog.pinecone_service
):
if not redo_request:
converser_cog.conversation_threads[id].history.append(
EmbeddedConversationItem(
"\nGPTie: " + str(response_text) + "<|endofstatement|>\n", 0
)
)
# Embeddings case!
elif (
id in converser_cog.conversation_threads
and not from_ask_command
and not from_edit_command
and converser_cog.pinecone_service
):
conversation_id = id
# Create an embedding and timestamp for the prompt
response_text = (
"\nGPTie: " + str(response_text) + "<|endofstatement|>\n"
)
response_text = response_text.encode("ascii", "ignore").decode()
# Print the current timestamp
timestamp = int(
str(datetime.datetime.now().timestamp()).replace(".", "")
)
converser_cog.conversation_threads[conversation_id].history.append(
EmbeddedConversationItem(response_text, timestamp)
)
# Create and upsert the embedding for the conversation id, prompt, timestamp
embedding = await converser_cog.pinecone_service.upsert_conversation_embedding(
converser_cog.model,
conversation_id,
response_text,
timestamp,
custom_api_key=custom_api_key,
)
# Cleanse again
response_text = converser_cog.cleanse_response(response_text)
# escape any other mentions like @here or @everyone
response_text = discord.utils.escape_mentions(response_text)
# If we don't have a response message, we are not doing a redo, send as a new message(s)
if not response_message:
if len(response_text) > converser_cog.TEXT_CUTOFF:
if not from_context:
paginator = None
await converser_cog.paginate_and_send(response_text, ctx)
else:
embed_pages = await converser_cog.paginate_embed(response_text, codex, prompt, instruction)
view=ConversationView(ctx, converser_cog, ctx.channel.id, model, from_ask_command, from_edit_command, custom_api_key=custom_api_key)
paginator = pages.Paginator(pages=embed_pages, timeout=None, custom_view=view)
response_message = await paginator.respond(ctx.interaction)
else:
paginator = None
if not from_context:
response_message = await ctx.reply(
response_text,
view=ConversationView(
ctx,
converser_cog,
ctx.channel.id,
model,
custom_api_key=custom_api_key,
),
)
elif from_edit_command:
response_message = await ctx.respond(
response_text,
view=ConversationView(
ctx,
converser_cog,
ctx.channel.id,
model,
from_edit_command=from_edit_command,
custom_api_key=custom_api_key
),
)
else:
response_message = await ctx.respond(
response_text,
view=ConversationView(
ctx,
converser_cog,
ctx.channel.id,
model,
from_ask_command=from_ask_command,
custom_api_key=custom_api_key
),
)
if response_message:
# Get the actual message object of response_message in case it's an WebhookMessage
actual_response_message = (
response_message
if not from_context
else await ctx.fetch_message(response_message.id)
)
converser_cog.redo_users[ctx.author.id] = RedoUser(
prompt=new_prompt,
instruction=instruction,
ctx=ctx,
message=ctx,
response=actual_response_message,
codex=codex,
paginator=paginator
)
converser_cog.redo_users[ctx.author.id].add_interaction(
actual_response_message.id
)
# We are doing a redo, edit the message.
else:
paginator = converser_cog.redo_users.get(ctx.author.id).paginator
if isinstance(paginator, pages.Paginator):
embed_pages = await converser_cog.paginate_embed(response_text, codex, prompt, instruction)
view=ConversationView(ctx, converser_cog, ctx.channel.id, model, from_ask_command, from_edit_command, custom_api_key=custom_api_key)
await paginator.update(pages=embed_pages, custom_view=view)
elif len(response_text) > converser_cog.TEXT_CUTOFF:
if not from_context:
await response_message.channel.send("Over 2000 characters", delete_after=5)
else:
await response_message.edit(content=response_text)
await converser_cog.send_debug_message(
converser_cog.generate_debug_message(prompt, response), converser_cog.debug_channel
)
if ctx.author.id in converser_cog.awaiting_responses:
converser_cog.awaiting_responses.remove(ctx.author.id)
if not from_ask_command and not from_edit_command:
if ctx.channel.id in converser_cog.awaiting_thread_responses:
converser_cog.awaiting_thread_responses.remove(ctx.channel.id)
# Error catching for AIOHTTP Errors
except aiohttp.ClientResponseError as e:
message = (
f"The API returned an invalid response: **{e.status}: {e.message}**"
)
if from_context:
await ctx.send_followup(message)
else:
await ctx.reply(message)
converser_cog.remove_awaiting(
ctx.author.id, ctx.channel.id, from_ask_command, from_edit_command
)
# Error catching for OpenAI model value errors
except ValueError as e:
if from_context:
await ctx.send_followup(e)
else:
await ctx.reply(e)
converser_cog.remove_awaiting(
ctx.author.id, ctx.channel.id, from_ask_command, from_edit_command
)
# General catch case for everything
except Exception:
message = "Something went wrong, please try again later. This may be due to upstream issues on the API, or rate limiting."
await ctx.send_followup(message) if from_context else await ctx.reply(
message
)
converser_cog.remove_awaiting(
ctx.author.id, ctx.channel.id, from_ask_command, from_edit_command
)
traceback.print_exc()
try:
await converser_cog.end_conversation(ctx)
except:
pass
return
@staticmethod
async def process_conversation_message(converser_cog, message, USER_INPUT_API_KEYS, USER_KEY_DB):
content = message.content.strip()
conversing = converser_cog.check_conversing(
message.author.id, message.channel.id, content
)
# If the user is conversing and they want to end it, end it immediately before we continue any further.
if conversing and message.content.lower() in converser_cog.END_PROMPTS:
await converser_cog.end_conversation(message)
return
if conversing:
user_api_key = None
if USER_INPUT_API_KEYS:
user_api_key = await TextService.get_user_api_key(
message.author.id, message, USER_KEY_DB
)
if not user_api_key:
return
prompt = await converser_cog.mention_to_username(message, content)
await converser_cog.check_conversation_limit(message)
# If the user is in a conversation thread
if message.channel.id in converser_cog.conversation_threads:
# Since this is async, we don't want to allow the user to send another prompt while a conversation
# prompt is processing, that'll mess up the conversation history!
if message.author.id in converser_cog.awaiting_responses:
message = await message.reply(
"You are already waiting for a response from GPT3. Please wait for it to respond before sending another message."
)
# get the current date, add 10 seconds to it, and then turn it into a timestamp.
# we need to use our deletion service because this isn't an interaction, it's a regular message.
deletion_time = datetime.datetime.now() + datetime.timedelta(
seconds=10
)
deletion_time = deletion_time.timestamp()
deletion_message = Deletion(message, deletion_time)
await converser_cog.deletion_queue.put(deletion_message)
return
if message.channel.id in converser_cog.awaiting_thread_responses:
message = await message.reply(
"This thread is already waiting for a response from GPT3. Please wait for it to respond before sending another message."
)
# get the current date, add 10 seconds to it, and then turn it into a timestamp.
# we need to use our deletion service because this isn't an interaction, it's a regular message.
deletion_time = datetime.datetime.now() + datetime.timedelta(
seconds=10
)
deletion_time = deletion_time.timestamp()
deletion_message = Deletion(message, deletion_time)
await converser_cog.deletion_queue.put(deletion_message)
return
converser_cog.awaiting_responses.append(message.author.id)
converser_cog.awaiting_thread_responses.append(message.channel.id)
if not converser_cog.pinecone_service:
converser_cog.conversation_threads[message.channel.id].history.append(
EmbeddedConversationItem(
f"\n'{message.author.display_name}': {prompt} <|endofstatement|>\n",
0,
)
)
# increment the conversation counter for the user
converser_cog.conversation_threads[message.channel.id].count += 1
# Send the request to the model
# If conversing, the prompt to send is the history, otherwise, it's just the prompt
if (
converser_cog.pinecone_service
or message.channel.id not in converser_cog.conversation_threads
):
primary_prompt = prompt
else:
primary_prompt = "".join(
[
item.text
for item in converser_cog.conversation_threads[
message.channel.id
].history
]
)
# set conversation overrides
overrides = converser_cog.conversation_threads[message.channel.id].get_overrides()
await TextService.encapsulated_send(
converser_cog,
message.channel.id,
primary_prompt,
message,
temp_override=overrides["temperature"],
top_p_override=overrides["top_p"],
frequency_penalty_override=overrides["frequency_penalty"],
presence_penalty_override=overrides["presence_penalty"],
model=converser_cog.conversation_threads[message.channel.id].model,
custom_api_key=user_api_key,
)
return True
@staticmethod
async def get_user_api_key(user_id, ctx, USER_KEY_DB):
user_api_key = None if user_id not in USER_KEY_DB else USER_KEY_DB[user_id]
if user_api_key is None or user_api_key == "":
modal = SetupModal(title="API Key Setup",user_key_db=USER_KEY_DB)
if isinstance(ctx, discord.ApplicationContext):
await ctx.send_modal(modal)
await ctx.send_followup(
"You must set up your API key before using this command."
)
else:
await ctx.reply(
"You must set up your API key before typing in a GPT3 powered channel, type `/setup` to enter your API key."
)
return user_api_key
@staticmethod
async def process_conversation_edit(converser_cog, after, original_message):
if after.author.id in converser_cog.redo_users:
if after.id == original_message[after.author.id]:
response_message = converser_cog.redo_users[after.author.id].response
ctx = converser_cog.redo_users[after.author.id].ctx
await response_message.edit(content="Redoing prompt 🔄...")
edited_content = await converser_cog.mention_to_username(after, after.content)
if after.channel.id in converser_cog.conversation_threads:
# Remove the last two elements from the history array and add the new <username>: prompt
converser_cog.conversation_threads[
after.channel.id
].history = converser_cog.conversation_threads[after.channel.id].history[:-2]
pinecone_dont_reinsert = None
if not converser_cog.pinecone_service:
converser_cog.conversation_threads[after.channel.id].history.append(
EmbeddedConversationItem(
f"\n{after.author.display_name}: {after.content}<|endofstatement|>\n",
0,
)
)
converser_cog.conversation_threads[after.channel.id].count += 1
overrides = converser_cog.conversation_threads[after.channel.id].get_overrides()
await TextService.encapsulated_send(
converser_cog,
id=after.channel.id,
prompt=edited_content,
ctx=ctx,
response_message=response_message,
temp_override=overrides["temperature"],
top_p_override=overrides["top_p"],
frequency_penalty_override=overrides["frequency_penalty"],
presence_penalty_override=overrides["presence_penalty"],
model=converser_cog.conversation_threads[after.channel.id].model,
edited_request=True,
)
if not converser_cog.pinecone_service:
converser_cog.redo_users[after.author.id].prompt = edited_content
"""
Conversation interaction buttons
"""
class ConversationView(discord.ui.View):
def __init__(
self,
ctx,
converser_cog,
id,
model,
from_ask_command=False,
from_edit_command=False,
custom_api_key=None,
):
super().__init__(timeout=3600) # 1 hour interval to redo.
self.converser_cog = converser_cog
self.ctx = ctx
self.model = model
self.from_ask_command = from_ask_command
self.from_edit_command = from_edit_command
self.custom_api_key = custom_api_key
self.add_item(
RedoButton(
self.converser_cog,
model=model,
from_ask_command=from_ask_command,
from_edit_command=from_edit_command,
custom_api_key=self.custom_api_key,
)
)
if id in self.converser_cog.conversation_threads:
self.add_item(EndConvoButton(self.converser_cog))
async def on_timeout(self):
# Remove the button from the view/message
self.clear_items()
# Send a message to the user saying the view has timed out
if self.message:
await self.message.edit(
view=None,
)
else:
await self.ctx.edit(
view=None,
)
class EndConvoButton(discord.ui.Button["ConversationView"]):
def __init__(self, converser_cog):
super().__init__(style=discord.ButtonStyle.danger, label="End Conversation", custom_id="conversation_end")
self.converser_cog = converser_cog
async def callback(self, interaction: discord.Interaction):
# Get the user
user_id = interaction.user.id
if (
user_id in self.converser_cog.conversation_thread_owners
and self.converser_cog.conversation_thread_owners[user_id]
== interaction.channel.id
):
try:
await self.converser_cog.end_conversation(
interaction, opener_user_id=interaction.user.id
)
except Exception as e:
print(e)
traceback.print_exc()
await interaction.response.send_message(
e, ephemeral=True, delete_after=30
)
pass
else:
await interaction.response.send_message(
"This is not your conversation to end!", ephemeral=True, delete_after=10
)
class RedoButton(discord.ui.Button["ConversationView"]):
def __init__(self, converser_cog, model, from_ask_command, from_edit_command, custom_api_key):
super().__init__(style=discord.ButtonStyle.danger, label="Retry", custom_id="conversation_redo")
self.converser_cog = converser_cog
self.model = model
self.from_ask_command = from_ask_command
self.from_edit_command = from_edit_command
self.custom_api_key = custom_api_key
async def callback(self, interaction: discord.Interaction):
# Get the user
user_id = interaction.user.id
if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[
user_id
].in_interaction(interaction.message.id):
# Get the message and the prompt and call encapsulated_send
prompt = self.converser_cog.redo_users[user_id].prompt
instruction = self.converser_cog.redo_users[user_id].instruction
ctx = self.converser_cog.redo_users[user_id].ctx
response_message = self.converser_cog.redo_users[user_id].response
codex = self.converser_cog.redo_users[user_id].codex
msg = await interaction.response.send_message(
"Retrying your original request...", ephemeral=True, delete_after=15
)
await TextService.encapsulated_send(
self.converser_cog,
id=user_id,
prompt=prompt,
instruction=instruction,
ctx=ctx,
model=self.model,
response_message=response_message,
codex=codex,
custom_api_key=self.custom_api_key,
redo_request=True,
from_ask_command=self.from_ask_command,
from_edit_command=self.from_edit_command,
)
else:
await interaction.response.send_message(
"You can only redo the most recent prompt that you sent yourself.",
ephemeral=True,
delete_after=10,
)
"""
The setup modal when using user input API keys
"""
class SetupModal(discord.ui.Modal):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# Get the argument named "user_key_db" and save it as USER_KEY_DB
self.USER_KEY_DB = kwargs.pop("user_key_db")
self.add_item(
discord.ui.InputText(
label="OpenAI API Key",
placeholder="sk--......",
)
)
async def callback(self, interaction: discord.Interaction):
user = interaction.user
api_key = self.children[0].value
# Validate that api_key is indeed in this format
if not re.match(r"sk-[a-zA-Z0-9]{32}", api_key):
await interaction.response.send_message(
"Your API key looks invalid, please check that it is correct before proceeding. Please run the /setup command to set your key.",
ephemeral=True,
delete_after=100,
)
else:
# We can save the key for the user to the database.
# Make a test request using the api key to ensure that it is valid.
try:
await Model.send_test_request(api_key)
await interaction.response.send_message(
"Your API key was successfully validated.",
ephemeral=True,
delete_after=10,
)
except aiohttp.ClientResponseError as e:
await interaction.response.send_message(
f"The API returned an invalid response: **{e.status}: {e.message}**",
ephemeral=True,
delete_after=30,
)
return
except Exception as e:
await interaction.response.send_message(
f"Your API key looks invalid, the API returned: {e}. Please check that your API key is correct before proceeding",
ephemeral=True,
delete_after=30,
)
return
# Save the key to the database
try:
self.USER_KEY_DB[user.id] = api_key
self.USER_KEY_DB.commit()
await interaction.followup.send(
"Your API key was successfully saved.",
ephemeral=True,
delete_after=10,
)
except Exception as e:
traceback.print_exc()
await interaction.followup.send(
"There was an error saving your API key.",
ephemeral=True,
delete_after=30,
)
return
pass
Loading…
Cancel
Save