From ba3ea3231b6e7fb3e8ec922ba815e4367143e70e Mon Sep 17 00:00:00 2001 From: Kaveen Kumarasinghe Date: Sat, 14 Jan 2023 06:55:47 -0500 Subject: [PATCH] a big refactor --- cogs/commands.py | 2 +- cogs/image_service_cog.py | 101 +++ ...t_optimizer.py => prompt_optimizer_cog.py} | 17 +- ...s_and_converser.py => text_service_cog.py} | 822 +---------------- gpt3discord.py | 16 +- models/autocomplete_model.py | 4 +- models/check_model.py | 2 +- services/__init__.py | 0 .../deletion_service.py | 0 .../environment_service.py | 0 .../image_service.py | 825 ++++++++---------- .../message_queue_service.py | 0 .../moderations_service.py | 2 +- .../pinecone_service.py | 0 services/text_service.py | 820 +++++++++++++++++ .../usage_service.py | 0 16 files changed, 1335 insertions(+), 1276 deletions(-) create mode 100644 cogs/image_service_cog.py rename cogs/{image_prompt_optimizer.py => prompt_optimizer_cog.py} (93%) rename cogs/{gpt_3_commands_and_converser.py => text_service_cog.py} (52%) create mode 100644 services/__init__.py rename models/deletion_service_model.py => services/deletion_service.py (100%) rename models/env_service_model.py => services/environment_service.py (100%) rename cogs/draw_image_generation.py => services/image_service.py (71%) rename models/message_model.py => services/message_queue_service.py (100%) rename models/moderations_service_model.py => services/moderations_service.py (97%) rename models/pinecone_service_model.py => services/pinecone_service.py (100%) create mode 100644 services/text_service.py rename models/usage_service_model.py => services/usage_service.py (100%) diff --git a/cogs/commands.py b/cogs/commands.py index c164eb8..861c11a 100644 --- a/cogs/commands.py +++ b/cogs/commands.py @@ -1,7 +1,7 @@ import discord from pycord.multicog import add_to_group -from models.env_service_model import EnvService +from services.environment_service import EnvService from models.check_model import Check from models.autocomplete_model import Settings_autocompleter, File_autocompleter diff --git a/cogs/image_service_cog.py b/cogs/image_service_cog.py new file mode 100644 index 0000000..d14ca9f --- /dev/null +++ b/cogs/image_service_cog.py @@ -0,0 +1,101 @@ +import asyncio +import os +import tempfile +import traceback +from io import BytesIO + +import aiohttp +import discord +from PIL import Image + +# We don't use the converser cog here because we want to be able to redo for the last images and text prompts at the same time +from sqlitedict import SqliteDict + +from cogs.text_service_cog import GPT3ComCon +from services.environment_service import EnvService +from models.user_model import RedoUser +from services.image_service import ImageService +from services.text_service import TextService + +users_to_interactions = {} +ALLOWED_GUILDS = EnvService.get_allowed_guilds() + +USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys() +USER_KEY_DB = None +if USER_INPUT_API_KEYS: + USER_KEY_DB = SqliteDict("user_key_db.sqlite") + + +class DrawDallEService(discord.Cog, name="DrawDallEService"): + def __init__( + self, bot, usage_service, model, message_queue, deletion_queue, converser_cog + ): + super().__init__() + self.bot = bot + self.usage_service = usage_service + self.model = model + self.message_queue = message_queue + self.deletion_queue = deletion_queue + self.converser_cog = converser_cog + print("Draw service initialized") + self.redo_users = {} + + + async def draw_command(self, ctx: discord.ApplicationContext, prompt: str): + user_api_key = None + if USER_INPUT_API_KEYS: + user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx, USER_KEY_DB) + if not user_api_key: + return + + await ctx.defer() + + user = ctx.user + + if user == self.bot.user: + return + + try: + asyncio.ensure_future( + ImageService.encapsulated_send( + self, + user.id, prompt, ctx, custom_api_key=user_api_key + ) + ) + + except Exception as e: + print(e) + traceback.print_exc() + await ctx.respond("Something went wrong. Please try again later.") + await ctx.send_followup(e) + + async def local_size_command(self, ctx: discord.ApplicationContext): + await ctx.defer() + # Get the size of the dall-e images folder that we have on the current system. + + image_path = self.model.IMAGE_SAVE_PATH + total_size = 0 + for dirpath, dirnames, filenames in os.walk(image_path): + for f in filenames: + fp = os.path.join(dirpath, f) + total_size += os.path.getsize(fp) + + # Format the size to be in MB and send. + total_size = total_size / 1000000 + await ctx.respond(f"The size of the local images folder is {total_size} MB.") + + async def clear_local_command(self, ctx): + await ctx.defer() + + # Delete all the local images in the images folder. + image_path = self.model.IMAGE_SAVE_PATH + for dirpath, dirnames, filenames in os.walk(image_path): + for f in filenames: + try: + fp = os.path.join(dirpath, f) + os.remove(fp) + except Exception as e: + print(e) + + await ctx.respond("Local images cleared.") + diff --git a/cogs/image_prompt_optimizer.py b/cogs/prompt_optimizer_cog.py similarity index 93% rename from cogs/image_prompt_optimizer.py rename to cogs/prompt_optimizer_cog.py index e36ef80..687a039 100644 --- a/cogs/image_prompt_optimizer.py +++ b/cogs/prompt_optimizer_cog.py @@ -4,10 +4,11 @@ import traceback import discord from sqlitedict import SqliteDict -from cogs.gpt_3_commands_and_converser import GPT3ComCon -from models.env_service_model import EnvService +from services.environment_service import EnvService from models.user_model import RedoUser -from pycord.multicog import add_to_group +from services.image_service import ImageService + +from services.text_service import TextService ALLOWED_GUILDS = EnvService.get_allowed_guilds() USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys() @@ -55,7 +56,7 @@ class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"): async def optimize_command(self, ctx: discord.ApplicationContext, prompt: str): user_api_key = None if USER_INPUT_API_KEYS: - user_api_key = await GPT3ComCon.get_user_api_key(ctx.user.id, ctx) + user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx) if not user_api_key: return @@ -77,7 +78,7 @@ class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"): try: response = await self.model.send_request( final_prompt, - tokens=70, + tokens=60, top_p_override=1.0, temp_override=0.9, presence_penalty_override=0.5, @@ -217,7 +218,8 @@ class DrawButton(discord.ui.Button["OptimizeView"]): prompt = re.sub(r"Optimized Prompt: ?", "", prompt) # Call the image service cog to draw the image - await self.image_service_cog.encapsulated_send( + await ImageService.encapsulated_send( + self.image_service_cog, user_id, prompt, interaction, @@ -255,7 +257,8 @@ class RedoButton(discord.ui.Button["OptimizeView"]): msg = await interaction.response.send_message( "Redoing your original request...", ephemeral=True, delete_after=20 ) - await self.converser_cog.encapsulated_send( + await TextService.encapsulated_send( + self.converser_cog, id=user_id, prompt=prompt, ctx=ctx, diff --git a/cogs/gpt_3_commands_and_converser.py b/cogs/text_service_cog.py similarity index 52% rename from cogs/gpt_3_commands_and_converser.py rename to cogs/text_service_cog.py index 8b962ee..bcba0be 100644 --- a/cogs/gpt_3_commands_and_converser.py +++ b/cogs/text_service_cog.py @@ -1,6 +1,5 @@ import asyncio import datetime -import json import re import traceback import sys @@ -10,22 +9,17 @@ from pathlib import Path import aiofiles import json -import aiohttp import discord -from discord.ext import pages -from pycord.multicog import add_to_group - -from models.deletion_service_model import Deletion -from models.env_service_model import EnvService -from models.message_model import Message -from models.moderations_service_model import Moderation -from models.openai_model import Model -from models.user_model import RedoUser, Thread, EmbeddedConversationItem -from models.check_model import Check -from models.autocomplete_model import Settings_autocompleter, File_autocompleter + +from services.environment_service import EnvService +from services.message_queue_service import Message +from services.moderations_service import Moderation +from models.user_model import Thread, EmbeddedConversationItem from collections import defaultdict from sqlitedict import SqliteDict +from services.text_service import SetupModal, TextService + original_message = {} ALLOWED_GUILDS = EnvService.get_allowed_guilds() if sys.platform == "win32": @@ -169,21 +163,6 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): self.message_queue = message_queue self.conversation_thread_owners = {} - @staticmethod - async def get_user_api_key(user_id, ctx): - user_api_key = None if user_id not in USER_KEY_DB else USER_KEY_DB[user_id] - if user_api_key is None or user_api_key == "": - modal = SetupModal(title="API Key Setup") - if isinstance(ctx, discord.ApplicationContext): - await ctx.send_modal(modal) - await ctx.send_followup( - "You must set up your API key before using this command." - ) - else: - await ctx.reply( - "You must set up your API key before typing in a GPT3 powered channel, type `/setup` to enter your API key." - ) - return user_api_key async def load_file(self, file, ctx): try: @@ -544,49 +523,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): Moderation(after, timestamp) ) - if after.author.id in self.redo_users: - if after.id == original_message[after.author.id]: - response_message = self.redo_users[after.author.id].response - ctx = self.redo_users[after.author.id].ctx - await response_message.edit(content="Redoing prompt 🔄...") - - edited_content = await self.mention_to_username(after, after.content) - - if after.channel.id in self.conversation_threads: - # Remove the last two elements from the history array and add the new : prompt - self.conversation_threads[ - after.channel.id - ].history = self.conversation_threads[after.channel.id].history[:-2] - - pinecone_dont_reinsert = None - if not self.pinecone_service: - self.conversation_threads[after.channel.id].history.append( - EmbeddedConversationItem( - f"\n{after.author.display_name}: {after.content}<|endofstatement|>\n", - 0, - ) - ) - - self.conversation_threads[after.channel.id].count += 1 - - overrides = self.conversation_threads[after.channel.id].get_overrides() - - await self.encapsulated_send( - id=after.channel.id, - prompt=edited_content, - ctx=ctx, - response_message=response_message, - temp_override=overrides["temperature"], - top_p_override=overrides["top_p"], - frequency_penalty_override=overrides["frequency_penalty"], - presence_penalty_override=overrides["presence_penalty"], - model=self.conversation_threads[after.channel.id].model, - edited_request=True, - ) - - if not self.pinecone_service: - self.redo_users[after.author.id].prompt = edited_content - + await TextService.process_conversation_edit(self, after, original_message) async def check_and_launch_moderations(self, guild_id, alert_channel_override=None): # Create the moderations service. print("Checking and attempting to launch moderations service...") @@ -630,114 +567,10 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): Moderation(message, timestamp) ) - conversing = self.check_conversing( - message.author.id, message.channel.id, content - ) - - # If the user is conversing and they want to end it, end it immediately before we continue any further. - if conversing and message.content.lower() in self.END_PROMPTS: - await self.end_conversation(message) - return - - if conversing: - user_api_key = None - if USER_INPUT_API_KEYS: - user_api_key = await GPT3ComCon.get_user_api_key( - message.author.id, message - ) - if not user_api_key: - return - - prompt = await self.mention_to_username(message, content) - - await self.check_conversation_limit(message) - - # If the user is in a conversation thread - if message.channel.id in self.conversation_threads: - - # Since this is async, we don't want to allow the user to send another prompt while a conversation - # prompt is processing, that'll mess up the conversation history! - if message.author.id in self.awaiting_responses: - message = await message.reply( - "You are already waiting for a response from GPT3. Please wait for it to respond before sending another message." - ) - - # get the current date, add 10 seconds to it, and then turn it into a timestamp. - # we need to use our deletion service because this isn't an interaction, it's a regular message. - deletion_time = datetime.datetime.now() + datetime.timedelta( - seconds=10 - ) - deletion_time = deletion_time.timestamp() - - deletion_message = Deletion(message, deletion_time) - await self.deletion_queue.put(deletion_message) - - return - - if message.channel.id in self.awaiting_thread_responses: - message = await message.reply( - "This thread is already waiting for a response from GPT3. Please wait for it to respond before sending another message." - ) - - # get the current date, add 10 seconds to it, and then turn it into a timestamp. - # we need to use our deletion service because this isn't an interaction, it's a regular message. - deletion_time = datetime.datetime.now() + datetime.timedelta( - seconds=10 - ) - deletion_time = deletion_time.timestamp() - - deletion_message = Deletion(message, deletion_time) - await self.deletion_queue.put(deletion_message) - return - - self.awaiting_responses.append(message.author.id) - self.awaiting_thread_responses.append(message.channel.id) - - original_message[message.author.id] = message.id - - if not self.pinecone_service: - self.conversation_threads[message.channel.id].history.append( - EmbeddedConversationItem( - f"\n'{message.author.display_name}': {prompt} <|endofstatement|>\n", - 0, - ) - ) - - # increment the conversation counter for the user - self.conversation_threads[message.channel.id].count += 1 - - # Send the request to the model - # If conversing, the prompt to send is the history, otherwise, it's just the prompt - if ( - self.pinecone_service - or message.channel.id not in self.conversation_threads - ): - primary_prompt = prompt - else: - primary_prompt = "".join( - [ - item.text - for item in self.conversation_threads[ - message.channel.id - ].history - ] - ) - - # set conversation overrides - overrides = self.conversation_threads[message.channel.id].get_overrides() - - await self.encapsulated_send( - message.channel.id, - primary_prompt, - message, - temp_override=overrides["temperature"], - top_p_override=overrides["top_p"], - frequency_penalty_override=overrides["frequency_penalty"], - presence_penalty_override=overrides["presence_penalty"], - model=self.conversation_threads[message.channel.id].model, - custom_api_key=user_api_key, - ) + # Process the message if the user is in a conversation + if await TextService.process_conversation_message(self, message, USER_INPUT_API_KEYS, USER_KEY_DB): + original_message[message.author.id] = message.id def cleanse_response(self, response_text): response_text = response_text.replace("GPTie:\n", "") @@ -767,433 +600,6 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): pass return message - # ctx can be of type AppContext(interaction) or Message - async def encapsulated_send( - self, - id, - prompt, - ctx, - response_message=None, - temp_override=None, - top_p_override=None, - frequency_penalty_override=None, - presence_penalty_override=None, - from_ask_command=False, - instruction=None, - from_edit_command=False, - codex=False, - model=None, - custom_api_key=None, - edited_request=False, - redo_request=False, - ): - new_prompt = ( - prompt + "\nGPTie: " - if not from_ask_command and not from_edit_command - else prompt - ) - - from_context = isinstance(ctx, discord.ApplicationContext) - - if not instruction: - tokens = self.usage_service.count_tokens(new_prompt) - else: - tokens = self.usage_service.count_tokens( - new_prompt - ) + self.usage_service.count_tokens(instruction) - - try: - - # Pinecone is enabled, we will create embeddings for this conversation. - if self.pinecone_service and ctx.channel.id in self.conversation_threads: - # Delete "GPTie: <|endofstatement|>" from the user's conversation history if it exists - # check if the text attribute for any object inside self.conversation_threads[converation_id].history - # contains ""GPTie: <|endofstatement|>"", if so, delete - for item in self.conversation_threads[ctx.channel.id].history: - if item.text.strip() == "GPTie:<|endofstatement|>": - self.conversation_threads[ctx.channel.id].history.remove(item) - - # The conversation_id is the id of the thread - conversation_id = ctx.channel.id - - # Create an embedding and timestamp for the prompt - new_prompt = prompt.encode("ascii", "ignore").decode() - prompt_less_author = f"{new_prompt} <|endofstatement|>\n" - - user_displayname = ctx.author.display_name - - new_prompt = ( - f"\n'{user_displayname}': {new_prompt} <|endofstatement|>\n" - ) - new_prompt = new_prompt.encode("ascii", "ignore").decode() - - timestamp = int( - str(datetime.datetime.now().timestamp()).replace(".", "") - ) - - new_prompt_item = EmbeddedConversationItem(new_prompt, timestamp) - - if not redo_request: - self.conversation_threads[conversation_id].history.append( - new_prompt_item - ) - - if edited_request: - new_prompt = "".join( - [ - item.text - for item in self.conversation_threads[ - ctx.channel.id - ].history - ] - ) - self.redo_users[ctx.author.id].prompt = new_prompt - else: - # Create and upsert the embedding for the conversation id, prompt, timestamp - await self.pinecone_service.upsert_conversation_embedding( - self.model, - conversation_id, - new_prompt, - timestamp, - custom_api_key=custom_api_key, - ) - - embedding_prompt_less_author = await self.model.send_embedding_request( - prompt_less_author, custom_api_key=custom_api_key - ) # Use the version of the prompt without the author's name for better clarity on retrieval. - - # Now, build the new prompt by getting the X most similar with pinecone - similar_prompts = self.pinecone_service.get_n_similar( - conversation_id, - embedding_prompt_less_author, - n=self.model.num_conversation_lookback, - ) - - # When we are in embeddings mode, only the pre-text is contained in self.conversation_threads[message.channel.id].history, so we - # can use that as a base to build our new prompt - prompt_with_history = [ - self.conversation_threads[ctx.channel.id].history[0] - ] - - # Append the similar prompts to the prompt with history - prompt_with_history += [ - EmbeddedConversationItem(prompt, timestamp) - for prompt, timestamp in similar_prompts - ] - - # iterate UP TO the last X prompts in the history - for i in range( - 1, - min( - len(self.conversation_threads[ctx.channel.id].history), - self.model.num_static_conversation_items, - ), - ): - prompt_with_history.append( - self.conversation_threads[ctx.channel.id].history[-i] - ) - - # remove duplicates from prompt_with_history and set the conversation history - prompt_with_history = list(dict.fromkeys(prompt_with_history)) - self.conversation_threads[ - ctx.channel.id - ].history = prompt_with_history - - # Sort the prompt_with_history by increasing timestamp if pinecone is enabled - if self.pinecone_service: - prompt_with_history.sort(key=lambda x: x.timestamp) - - # Ensure that the last prompt in this list is the prompt we just sent (new_prompt_item) - if prompt_with_history[-1] != new_prompt_item: - try: - prompt_with_history.remove(new_prompt_item) - except ValueError: - pass - prompt_with_history.append(new_prompt_item) - - prompt_with_history = "".join( - [item.text for item in prompt_with_history] - ) - - new_prompt = prompt_with_history + "\nGPTie: " - - tokens = self.usage_service.count_tokens(new_prompt) - - # No pinecone, we do conversation summarization for long term memory instead - elif ( - id in self.conversation_threads - and tokens > self.model.summarize_threshold - and not from_ask_command - and not from_edit_command - and not self.pinecone_service # This should only happen if we are not doing summarizations. - ): - - # We don't need to worry about the differences between interactions and messages in this block, - # because if we are in this block, we can only be using a message object for ctx - if self.model.summarize_conversations: - await ctx.reply( - "I'm currently summarizing our current conversation so we can keep chatting, " - "give me one moment!" - ) - - await self.summarize_conversation(ctx, new_prompt) - - # Check again if the prompt is about to go past the token limit - new_prompt = ( - "".join( - [ - item.text - for item in self.conversation_threads[id].history - ] - ) - + "\nGPTie: " - ) - - tokens = self.usage_service.count_tokens(new_prompt) - - if ( - tokens > self.model.summarize_threshold - 150 - ): # 150 is a buffer for the second stage - await ctx.reply( - "I tried to summarize our current conversation so we could keep chatting, " - "but it still went over the token " - "limit. Please try again later." - ) - - await self.end_conversation(ctx) - return - else: - await ctx.reply("The conversation context limit has been reached.") - await self.end_conversation(ctx) - return - - # Send the request to the model - if from_edit_command: - response = await self.model.send_edit_request( - input=new_prompt, - instruction=instruction, - temp_override=temp_override, - top_p_override=top_p_override, - codex=codex, - custom_api_key=custom_api_key, - ) - else: - response = await self.model.send_request( - new_prompt, - tokens=tokens, - temp_override=temp_override, - top_p_override=top_p_override, - frequency_penalty_override=frequency_penalty_override, - presence_penalty_override=presence_penalty_override, - model=model, - custom_api_key=custom_api_key, - ) - - # Clean the request response - response_text = self.cleanse_response(str(response["choices"][0]["text"])) - - if from_ask_command: - # Append the prompt to the beginning of the response, in italics, then a new line - response_text = response_text.strip() - response_text = f"***{prompt}***\n\n{response_text}" - elif from_edit_command: - if codex: - response_text = response_text.strip() - response_text = f"***Prompt: {prompt}***\n***Instruction: {instruction}***\n\n```\n{response_text}\n```" - else: - response_text = response_text.strip() - response_text = f"***Prompt: {prompt}***\n***Instruction: {instruction}***\n\n{response_text}\n" - - # If gpt3 tries writing a user mention try to replace it with their name - response_text = await self.mention_to_username(ctx, response_text) - - # If the user is conversing, add the GPT response to their conversation history. - if ( - id in self.conversation_threads - and not from_ask_command - and not self.pinecone_service - ): - if not redo_request: - self.conversation_threads[id].history.append( - EmbeddedConversationItem( - "\nGPTie: " + str(response_text) + "<|endofstatement|>\n", 0 - ) - ) - - # Embeddings case! - elif ( - id in self.conversation_threads - and not from_ask_command - and not from_edit_command - and self.pinecone_service - ): - conversation_id = id - - # Create an embedding and timestamp for the prompt - response_text = ( - "\nGPTie: " + str(response_text) + "<|endofstatement|>\n" - ) - - response_text = response_text.encode("ascii", "ignore").decode() - - # Print the current timestamp - timestamp = int( - str(datetime.datetime.now().timestamp()).replace(".", "") - ) - self.conversation_threads[conversation_id].history.append( - EmbeddedConversationItem(response_text, timestamp) - ) - - # Create and upsert the embedding for the conversation id, prompt, timestamp - embedding = await self.pinecone_service.upsert_conversation_embedding( - self.model, - conversation_id, - response_text, - timestamp, - custom_api_key=custom_api_key, - ) - - # Cleanse again - response_text = self.cleanse_response(response_text) - - # escape any other mentions like @here or @everyone - response_text = discord.utils.escape_mentions(response_text) - - - # If we don't have a response message, we are not doing a redo, send as a new message(s) - if not response_message: - if len(response_text) > self.TEXT_CUTOFF: - if not from_context: - paginator = None - await self.paginate_and_send(response_text, ctx) - else: - embed_pages = await self.paginate_embed(response_text, codex, prompt, instruction) - view=ConversationView(ctx, self, ctx.channel.id, model, from_ask_command, from_edit_command, custom_api_key=custom_api_key) - paginator = pages.Paginator(pages=embed_pages, timeout=None, custom_view=view) - response_message = await paginator.respond(ctx.interaction) - else: - paginator = None - if not from_context: - response_message = await ctx.reply( - response_text, - view=ConversationView( - ctx, - self, - ctx.channel.id, - model, - custom_api_key=custom_api_key, - ), - ) - elif from_edit_command: - response_message = await ctx.respond( - response_text, - view=ConversationView( - ctx, - self, - ctx.channel.id, - model, - from_edit_command=from_edit_command, - custom_api_key=custom_api_key - ), - ) - else: - response_message = await ctx.respond( - response_text, - view=ConversationView( - ctx, - self, - ctx.channel.id, - model, - from_ask_command=from_ask_command, - custom_api_key=custom_api_key - ), - ) - - if response_message: - # Get the actual message object of response_message in case it's an WebhookMessage - actual_response_message = ( - response_message - if not from_context - else await ctx.fetch_message(response_message.id) - ) - - self.redo_users[ctx.author.id] = RedoUser( - prompt=new_prompt, - instruction=instruction, - ctx=ctx, - message=ctx, - response=actual_response_message, - codex=codex, - paginator=paginator - ) - self.redo_users[ctx.author.id].add_interaction( - actual_response_message.id - ) - - # We are doing a redo, edit the message. - else: - paginator = self.redo_users.get(ctx.author.id).paginator - if isinstance(paginator, pages.Paginator): - embed_pages = await self.paginate_embed(response_text, codex, prompt, instruction) - view=ConversationView(ctx, self, ctx.channel.id, model, from_ask_command, from_edit_command, custom_api_key=custom_api_key) - await paginator.update(pages=embed_pages, custom_view=view) - elif len(response_text) > self.TEXT_CUTOFF: - if not from_context: - await response_message.channel.send("Over 2000 characters", delete_after=5) - else: - await response_message.edit(content=response_text) - - await self.send_debug_message( - self.generate_debug_message(prompt, response), self.debug_channel - ) - - if ctx.author.id in self.awaiting_responses: - self.awaiting_responses.remove(ctx.author.id) - if not from_ask_command and not from_edit_command: - if ctx.channel.id in self.awaiting_thread_responses: - self.awaiting_thread_responses.remove(ctx.channel.id) - - # Error catching for AIOHTTP Errors - except aiohttp.ClientResponseError as e: - message = ( - f"The API returned an invalid response: **{e.status}: {e.message}**" - ) - if from_context: - await ctx.send_followup(message) - else: - await ctx.reply(message) - self.remove_awaiting( - ctx.author.id, ctx.channel.id, from_ask_command, from_edit_command - ) - - # Error catching for OpenAI model value errors - except ValueError as e: - if from_context: - await ctx.send_followup(e) - else: - await ctx.reply(e) - self.remove_awaiting( - ctx.author.id, ctx.channel.id, from_ask_command, from_edit_command - ) - - # General catch case for everything - except Exception: - - message = "Something went wrong, please try again later. This may be due to upstream issues on the API, or rate limiting." - await ctx.send_followup(message) if from_context else await ctx.reply( - message - ) - self.remove_awaiting( - ctx.author.id, ctx.channel.id, from_ask_command, from_edit_command - ) - traceback.print_exc() - - try: - await self.end_conversation(ctx) - except: - pass - return - # COMMANDS async def help_command(self, ctx): @@ -1309,13 +715,14 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): user_api_key = None if USER_INPUT_API_KEYS: - user_api_key = await GPT3ComCon.get_user_api_key(user.id, ctx) + user_api_key = await TextService.get_user_api_key(user.id, ctx, USER_KEY_DB) if not user_api_key: return await ctx.defer() - await self.encapsulated_send( + await TextService.encapsulated_send( + self, user.id, prompt, ctx, @@ -1349,7 +756,8 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): await ctx.defer() - await self.encapsulated_send( + await TextService.encapsulated_send( + self, user.id, prompt=input, ctx=ctx, @@ -1511,7 +919,8 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): self.conversation_threads[thread.id].count += 1 - await self.encapsulated_send( + await TextService.encapsulated_send( + self, thread.id, opener if thread.id not in self.conversation_threads or self.pinecone_service @@ -1650,198 +1059,3 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): "alert_channel": self.get_moderated_alert_channel(guild_id), } MOD_DB.commit() - -# VIEWS AND MODALS - -class ConversationView(discord.ui.View): - def __init__( - self, - ctx, - converser_cog, - id, - model, - from_ask_command=False, - from_edit_command=False, - custom_api_key=None, - ): - super().__init__(timeout=3600) # 1 hour interval to redo. - self.converser_cog = converser_cog - self.ctx = ctx - self.model = model - self.from_ask_command = from_ask_command - self.from_edit_command = from_edit_command - self.custom_api_key = custom_api_key - self.add_item( - RedoButton( - self.converser_cog, - model=model, - from_ask_command=from_ask_command, - from_edit_command=from_edit_command, - custom_api_key=self.custom_api_key, - ) - ) - - if id in self.converser_cog.conversation_threads: - self.add_item(EndConvoButton(self.converser_cog)) - - async def on_timeout(self): - # Remove the button from the view/message - self.clear_items() - # Send a message to the user saying the view has timed out - if self.message: - await self.message.edit( - view=None, - ) - else: - await self.ctx.edit( - view=None, - ) - - -class EndConvoButton(discord.ui.Button["ConversationView"]): - def __init__(self, converser_cog): - super().__init__(style=discord.ButtonStyle.danger, label="End Conversation", custom_id="conversation_end") - self.converser_cog = converser_cog - - async def callback(self, interaction: discord.Interaction): - - # Get the user - user_id = interaction.user.id - if ( - user_id in self.converser_cog.conversation_thread_owners - and self.converser_cog.conversation_thread_owners[user_id] - == interaction.channel.id - ): - try: - await self.converser_cog.end_conversation( - interaction, opener_user_id=interaction.user.id - ) - except Exception as e: - print(e) - traceback.print_exc() - await interaction.response.send_message( - e, ephemeral=True, delete_after=30 - ) - pass - else: - await interaction.response.send_message( - "This is not your conversation to end!", ephemeral=True, delete_after=10 - ) - - -class RedoButton(discord.ui.Button["ConversationView"]): - def __init__(self, converser_cog, model, from_ask_command, from_edit_command, custom_api_key): - super().__init__(style=discord.ButtonStyle.danger, label="Retry", custom_id="conversation_redo") - self.converser_cog = converser_cog - self.model = model - self.from_ask_command = from_ask_command - self.from_edit_command = from_edit_command - self.custom_api_key = custom_api_key - - async def callback(self, interaction: discord.Interaction): - - # Get the user - user_id = interaction.user.id - if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[ - user_id - ].in_interaction(interaction.message.id): - # Get the message and the prompt and call encapsulated_send - prompt = self.converser_cog.redo_users[user_id].prompt - instruction = self.converser_cog.redo_users[user_id].instruction - ctx = self.converser_cog.redo_users[user_id].ctx - response_message = self.converser_cog.redo_users[user_id].response - codex = self.converser_cog.redo_users[user_id].codex - - msg = await interaction.response.send_message( - "Retrying your original request...", ephemeral=True, delete_after=15 - ) - - await self.converser_cog.encapsulated_send( - id=user_id, - prompt=prompt, - instruction=instruction, - ctx=ctx, - model=self.model, - response_message=response_message, - codex=codex, - custom_api_key=self.custom_api_key, - redo_request=True, - from_ask_command=self.from_ask_command, - from_edit_command=self.from_edit_command, - ) - else: - await interaction.response.send_message( - "You can only redo the most recent prompt that you sent yourself.", - ephemeral=True, - delete_after=10, - ) - - -class SetupModal(discord.ui.Modal): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - self.add_item( - discord.ui.InputText( - label="OpenAI API Key", - placeholder="sk--......", - ) - ) - - async def callback(self, interaction: discord.Interaction): - user = interaction.user - api_key = self.children[0].value - # Validate that api_key is indeed in this format - if not re.match(r"sk-[a-zA-Z0-9]{32}", api_key): - await interaction.response.send_message( - "Your API key looks invalid, please check that it is correct before proceeding. Please run the /setup command to set your key.", - ephemeral=True, - delete_after=100, - ) - else: - # We can save the key for the user to the database. - - # Make a test request using the api key to ensure that it is valid. - try: - await Model.send_test_request(api_key) - await interaction.response.send_message( - "Your API key was successfully validated.", - ephemeral=True, - delete_after=10, - ) - - except aiohttp.ClientResponseError as e: - await interaction.response.send_message( - f"The API returned an invalid response: **{e.status}: {e.message}**", - ephemeral=True, - delete_after=30, - ) - return - - except Exception as e: - await interaction.response.send_message( - f"Your API key looks invalid, the API returned: {e}. Please check that your API key is correct before proceeding", - ephemeral=True, - delete_after=30, - ) - return - - # Save the key to the database - try: - USER_KEY_DB[user.id] = api_key - USER_KEY_DB.commit() - await interaction.followup.send( - "Your API key was successfully saved.", - ephemeral=True, - delete_after=10, - ) - except Exception as e: - traceback.print_exc() - await interaction.followup.send( - "There was an error saving your API key.", - ephemeral=True, - delete_after=30, - ) - return - - pass diff --git a/gpt3discord.py b/gpt3discord.py index b2bdd35..08aafda 100644 --- a/gpt3discord.py +++ b/gpt3discord.py @@ -8,22 +8,22 @@ import pinecone from pycord.multicog import apply_multicog import os -from models.pinecone_service_model import PineconeService +from services.pinecone_service import PineconeService if sys.platform == "win32": separator = "\\" else: separator = "/" -from cogs.draw_image_generation import DrawDallEService -from cogs.gpt_3_commands_and_converser import GPT3ComCon -from cogs.image_prompt_optimizer import ImgPromptOptimizer +from cogs.image_service_cog import DrawDallEService +from cogs.text_service_cog import GPT3ComCon +from cogs.prompt_optimizer_cog import ImgPromptOptimizer from cogs.commands import Commands -from models.deletion_service_model import Deletion -from models.message_model import Message +from services.deletion_service import Deletion +from services.message_queue_service import Message from models.openai_model import Model -from models.usage_service_model import UsageService -from models.env_service_model import EnvService +from services.usage_service import UsageService +from services.environment_service import EnvService __version__ = "6.0" diff --git a/models/autocomplete_model.py b/models/autocomplete_model.py index b78dbeb..62371c1 100644 --- a/models/autocomplete_model.py +++ b/models/autocomplete_model.py @@ -3,9 +3,9 @@ import os import re import discord -from models.usage_service_model import UsageService +from services.usage_service import UsageService from models.openai_model import Model -from models.env_service_model import EnvService +from services.environment_service import EnvService usage_service = UsageService(Path(os.environ.get("DATA_DIR", os.getcwd()))) model = Model(usage_service) diff --git a/models/check_model.py b/models/check_model.py index e3b0e17..43d30d6 100644 --- a/models/check_model.py +++ b/models/check_model.py @@ -1,6 +1,6 @@ import discord -from models.env_service_model import EnvService +from services.environment_service import EnvService from typing import Callable ADMIN_ROLES = EnvService.get_admin_roles() diff --git a/services/__init__.py b/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/deletion_service_model.py b/services/deletion_service.py similarity index 100% rename from models/deletion_service_model.py rename to services/deletion_service.py diff --git a/models/env_service_model.py b/services/environment_service.py similarity index 100% rename from models/env_service_model.py rename to services/environment_service.py diff --git a/cogs/draw_image_generation.py b/services/image_service.py similarity index 71% rename from cogs/draw_image_generation.py rename to services/image_service.py index aaa8eb1..f202250 100644 --- a/cogs/draw_image_generation.py +++ b/services/image_service.py @@ -1,452 +1,373 @@ -import asyncio -import os -import tempfile -import traceback -from io import BytesIO - -import aiohttp -import discord -from PIL import Image -from pycord.multicog import add_to_group - - -# We don't use the converser cog here because we want to be able to redo for the last images and text prompts at the same time -from sqlitedict import SqliteDict - -from cogs.gpt_3_commands_and_converser import GPT3ComCon -from models.env_service_model import EnvService -from models.user_model import RedoUser - -redo_users = {} -users_to_interactions = {} -ALLOWED_GUILDS = EnvService.get_allowed_guilds() - -USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys() -USER_KEY_DB = None -if USER_INPUT_API_KEYS: - USER_KEY_DB = SqliteDict("user_key_db.sqlite") - - -class DrawDallEService(discord.Cog, name="DrawDallEService"): - def __init__( - self, bot, usage_service, model, message_queue, deletion_queue, converser_cog - ): - super().__init__() - self.bot = bot - self.usage_service = usage_service - self.model = model - self.message_queue = message_queue - self.deletion_queue = deletion_queue - self.converser_cog = converser_cog - print("Draw service initialized") - - async def encapsulated_send( - self, - user_id, - prompt, - ctx, - response_message=None, - vary=None, - draw_from_optimizer=None, - custom_api_key=None, - ): - await asyncio.sleep(0) - # send the prompt to the model - from_context = isinstance(ctx, discord.ApplicationContext) - - try: - file, image_urls = await self.model.send_image_request( - ctx, - prompt, - vary=vary if not draw_from_optimizer else None, - custom_api_key=custom_api_key, - ) - - # Error catching for API errors - except aiohttp.ClientResponseError as e: - message = ( - f"The API returned an invalid response: **{e.status}: {e.message}**" - ) - await ctx.channel.send(message) if not from_context else await ctx.respond( - message - ) - return - - except ValueError as e: - message = f"Error: {e}. Please try again with a different prompt." - await ctx.channel.send(message) if not from_context else await ctx.respond( - message - ) - - return - - # Start building an embed to send to the user with the results of the image generation - embed = discord.Embed( - title="Image Generation Results" - if not vary - else "Image Generation Results (Varying)" - if not draw_from_optimizer - else "Image Generation Results (Drawing from Optimizer)", - description=f"{prompt}", - color=0xC730C7, - ) - - # Add the image file to the embed - embed.set_image(url=f"attachment://{file.filename}") - - if not response_message: # Original generation case - # Start an interaction with the user, we also want to send data embed=embed, file=file, view=SaveView(image_urls, self, self.converser_cog) - result_message = ( - await ctx.channel.send( - embed=embed, - file=file, - ) - if not from_context - else await ctx.respond(embed=embed, file=file) - ) - - await result_message.edit( - view=SaveView( - ctx, - image_urls, - self, - self.converser_cog, - result_message, - custom_api_key=custom_api_key, - ) - ) - - self.converser_cog.users_to_interactions[user_id] = [] - self.converser_cog.users_to_interactions[user_id].append(result_message.id) - - # Get the actual result message object - if from_context: - result_message = await ctx.fetch_message(result_message.id) - - redo_users[user_id] = RedoUser( - prompt=prompt, - message=ctx, - ctx=ctx, - response=response_message, - instruction=None, - codex=False, - paginator=None - ) - - else: - if not vary: # Editing case - message = await response_message.edit( - embed=embed, - file=file, - ) - await message.edit( - view=SaveView( - ctx, - image_urls, - self, - self.converser_cog, - message, - custom_api_key=custom_api_key, - ) - ) - else: # Varying case - if not draw_from_optimizer: - result_message = await response_message.edit_original_response( - content="Image variation completed!", - embed=embed, - file=file, - ) - await result_message.edit( - view=SaveView( - ctx, - image_urls, - self, - self.converser_cog, - result_message, - True, - custom_api_key=custom_api_key, - ) - ) - - else: - result_message = await response_message.edit_original_response( - content="I've drawn the optimized prompt!", - embed=embed, - file=file, - ) - await result_message.edit( - view=SaveView( - ctx, - image_urls, - self, - self.converser_cog, - result_message, - custom_api_key=custom_api_key, - ) - ) - - redo_users[user_id] = RedoUser( - prompt=prompt, - message=ctx, - ctx=ctx, - response=response_message, - instruction=None, - codex=False, - paginator=None, - ) - - self.converser_cog.users_to_interactions[user_id].append( - response_message.id - ) - self.converser_cog.users_to_interactions[user_id].append( - result_message.id - ) - - async def draw_command(self, ctx: discord.ApplicationContext, prompt: str): - user_api_key = None - if USER_INPUT_API_KEYS: - user_api_key = await GPT3ComCon.get_user_api_key(ctx.user.id, ctx) - if not user_api_key: - return - - await ctx.defer() - - user = ctx.user - - if user == self.bot.user: - return - - try: - asyncio.ensure_future( - self.encapsulated_send( - user.id, prompt, ctx, custom_api_key=user_api_key - ) - ) - - except Exception as e: - print(e) - traceback.print_exc() - await ctx.respond("Something went wrong. Please try again later.") - await ctx.send_followup(e) - - async def local_size_command(self, ctx: discord.ApplicationContext): - await ctx.defer() - # Get the size of the dall-e images folder that we have on the current system. - - image_path = self.model.IMAGE_SAVE_PATH - total_size = 0 - for dirpath, dirnames, filenames in os.walk(image_path): - for f in filenames: - fp = os.path.join(dirpath, f) - total_size += os.path.getsize(fp) - - # Format the size to be in MB and send. - total_size = total_size / 1000000 - await ctx.respond(f"The size of the local images folder is {total_size} MB.") - - async def clear_local_command(self, ctx): - await ctx.defer() - - # Delete all the local images in the images folder. - image_path = self.model.IMAGE_SAVE_PATH - for dirpath, dirnames, filenames in os.walk(image_path): - for f in filenames: - try: - fp = os.path.join(dirpath, f) - os.remove(fp) - except Exception as e: - print(e) - - await ctx.respond("Local images cleared.") - - -class SaveView(discord.ui.View): - def __init__( - self, - ctx, - image_urls, - cog, - converser_cog, - message, - no_retry=False, - only_save=None, - custom_api_key=None, - ): - super().__init__( - timeout=3600 if not only_save else None - ) # 1 hour timeout for Retry, Save - self.ctx = ctx - self.image_urls = image_urls - self.cog = cog - self.no_retry = no_retry - self.converser_cog = converser_cog - self.message = message - self.custom_api_key = custom_api_key - for x in range(1, len(image_urls) + 1): - self.add_item(SaveButton(x, image_urls[x - 1])) - if not only_save: - if not no_retry: - self.add_item( - RedoButton( - self.cog, - converser_cog=self.converser_cog, - custom_api_key=self.custom_api_key, - ) - ) - for x in range(1, len(image_urls) + 1): - self.add_item( - VaryButton( - x, - image_urls[x - 1], - self.cog, - converser_cog=self.converser_cog, - custom_api_key=self.custom_api_key, - ) - ) - - # On the timeout event, override it and we want to clear the items. - async def on_timeout(self): - # Save all the SaveButton items, then clear all the items, then add back the SaveButton items, then - # update the message - self.clear_items() - - # Create a new view with the same params as this one, but pass only_save=True - new_view = SaveView( - self.ctx, - self.image_urls, - self.cog, - self.converser_cog, - self.message, - self.no_retry, - only_save=True, - ) - - # Set the view of the message to the new view - await self.ctx.edit(view=new_view) - - -class VaryButton(discord.ui.Button): - def __init__(self, number, image_url, cog, converser_cog, custom_api_key): - super().__init__(style=discord.ButtonStyle.blurple, label="Vary " + str(number)) - self.number = number - self.image_url = image_url - self.cog = cog - self.converser_cog = converser_cog - self.custom_api_key = custom_api_key - - async def callback(self, interaction: discord.Interaction): - user_id = interaction.user.id - interaction_id = interaction.message.id - - if interaction_id not in self.converser_cog.users_to_interactions[user_id]: - if len(self.converser_cog.users_to_interactions[user_id]) >= 2: - interaction_id2 = interaction.id - if ( - interaction_id2 - not in self.converser_cog.users_to_interactions[user_id] - ): - await interaction.response.send_message( - content="You can not vary images in someone else's chain!", - ephemeral=True, - ) - else: - await interaction.response.send_message( - content="You can only vary for images that you generated yourself!", - ephemeral=True, - ) - return - - if user_id in redo_users: - response_message = await interaction.response.send_message( - content="Varying image number " + str(self.number) + "..." - ) - self.converser_cog.users_to_interactions[user_id].append( - response_message.message.id - ) - self.converser_cog.users_to_interactions[user_id].append( - response_message.id - ) - prompt = redo_users[user_id].prompt - - asyncio.ensure_future( - self.cog.encapsulated_send( - user_id, - prompt, - interaction.message, - response_message=response_message, - vary=self.image_url, - custom_api_key=self.custom_api_key, - ) - ) - - -class SaveButton(discord.ui.Button["SaveView"]): - def __init__(self, number: int, image_url: str): - super().__init__(style=discord.ButtonStyle.gray, label="Save " + str(number)) - self.number = number - self.image_url = image_url - - async def callback(self, interaction: discord.Interaction): - # If the image url doesn't start with "http", then we need to read the file from the URI, and then send the - # file to the user as an attachment. - try: - if not self.image_url.startswith("http"): - with open(self.image_url, "rb") as f: - image = Image.open(BytesIO(f.read())) - temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) - image.save(temp_file.name) - - await interaction.response.send_message( - content="Here is your image for download (open original and save)", - file=discord.File(temp_file.name), - ephemeral=True, - ) - else: - await interaction.response.send_message( - f"You can directly download this image from {self.image_url}", - ephemeral=True, - ) - except Exception as e: - await interaction.response.send_message(f"Error: {e}", ephemeral=True) - traceback.print_exc() - - -class RedoButton(discord.ui.Button["SaveView"]): - def __init__(self, cog, converser_cog, custom_api_key): - super().__init__(style=discord.ButtonStyle.danger, label="Retry") - self.cog = cog - self.converser_cog = converser_cog - self.custom_api_key = custom_api_key - - async def callback(self, interaction: discord.Interaction): - user_id = interaction.user.id - interaction_id = interaction.message.id - - if interaction_id not in self.converser_cog.users_to_interactions[user_id]: - await interaction.response.send_message( - content="You can only retry for prompts that you generated yourself!", - ephemeral=True, - ) - return - - # We have passed the intial check of if the interaction belongs to the user - if user_id in redo_users: - # Get the message and the prompt and call encapsulated_send - ctx = redo_users[user_id].ctx - prompt = redo_users[user_id].prompt - response_message = redo_users[user_id].response - message = await interaction.response.send_message( - f"Regenerating the image for your original prompt, check the original message.", - ephemeral=True, - ) - self.converser_cog.users_to_interactions[user_id].append(message.id) - - asyncio.ensure_future( - self.cog.encapsulated_send( - user_id, - prompt, - ctx, - response_message, - custom_api_key=self.custom_api_key, - ) - ) +import asyncio +import tempfile +import traceback +from io import BytesIO + +import aiohttp +import discord +from PIL import Image + +from models.user_model import RedoUser + + +class ImageService: + + def __init__(self): + pass + + @staticmethod + async def encapsulated_send( + image_service_cog, + user_id, + prompt, + ctx, + response_message=None, + vary=None, + draw_from_optimizer=None, + custom_api_key=None, + ): + await asyncio.sleep(0) + # send the prompt to the model + from_context = isinstance(ctx, discord.ApplicationContext) + + try: + file, image_urls = await image_service_cog.model.send_image_request( + ctx, + prompt, + vary=vary if not draw_from_optimizer else None, + custom_api_key=custom_api_key, + ) + + # Error catching for API errors + except aiohttp.ClientResponseError as e: + message = ( + f"The API returned an invalid response: **{e.status}: {e.message}**" + ) + await ctx.channel.send(message) if not from_context else await ctx.respond( + message + ) + return + + except ValueError as e: + message = f"Error: {e}. Please try again with a different prompt." + await ctx.channel.send(message) if not from_context else await ctx.respond( + message + ) + + return + + # Start building an embed to send to the user with the results of the image generation + embed = discord.Embed( + title="Image Generation Results" + if not vary + else "Image Generation Results (Varying)" + if not draw_from_optimizer + else "Image Generation Results (Drawing from Optimizer)", + description=f"{prompt}", + color=0xC730C7, + ) + + # Add the image file to the embed + embed.set_image(url=f"attachment://{file.filename}") + + if not response_message: # Original generation case + # Start an interaction with the user, we also want to send data embed=embed, file=file, view=SaveView(image_urls, image_service_cog, image_service_cog.converser_cog) + result_message = ( + await ctx.channel.send( + embed=embed, + file=file, + ) + if not from_context + else await ctx.respond(embed=embed, file=file) + ) + + await result_message.edit( + view=SaveView( + ctx, + image_urls, + image_service_cog, + image_service_cog.converser_cog, + result_message, + custom_api_key=custom_api_key, + ) + ) + + image_service_cog.converser_cog.users_to_interactions[user_id] = [] + image_service_cog.converser_cog.users_to_interactions[user_id].append(result_message.id) + + # Get the actual result message object + if from_context: + result_message = await ctx.fetch_message(result_message.id) + + image_service_cog.redo_users[user_id] = RedoUser( + prompt=prompt, + message=ctx, + ctx=ctx, + response=response_message, + instruction=None, + codex=False, + paginator=None + ) + + else: + if not vary: # Editing case + message = await response_message.edit( + embed=embed, + file=file, + ) + await message.edit( + view=SaveView( + ctx, + image_urls, + image_service_cog, + image_service_cog.converser_cog, + message, + custom_api_key=custom_api_key, + ) + ) + else: # Varying case + if not draw_from_optimizer: + result_message = await response_message.edit_original_response( + content="Image variation completed!", + embed=embed, + file=file, + ) + await result_message.edit( + view=SaveView( + ctx, + image_urls, + image_service_cog, + image_service_cog.converser_cog, + result_message, + True, + custom_api_key=custom_api_key, + ) + ) + + else: + result_message = await response_message.edit_original_response( + content="I've drawn the optimized prompt!", + embed=embed, + file=file, + ) + await result_message.edit( + view=SaveView( + ctx, + image_urls, + image_service_cog, + image_service_cog.converser_cog, + result_message, + custom_api_key=custom_api_key, + ) + ) + + image_service_cog.redo_users[user_id] = RedoUser( + prompt=prompt, + message=ctx, + ctx=ctx, + response=response_message, + instruction=None, + codex=False, + paginator=None, + ) + + image_service_cog.converser_cog.users_to_interactions[user_id].append( + response_message.id + ) + image_service_cog.converser_cog.users_to_interactions[user_id].append( + result_message.id + ) + + +class SaveView(discord.ui.View): + def __init__( + self, + ctx, + image_urls, + cog, + converser_cog, + message, + no_retry=False, + only_save=None, + custom_api_key=None, + ): + super().__init__( + timeout=3600 if not only_save else None + ) # 1 hour timeout for Retry, Save + self.ctx = ctx + self.image_urls = image_urls + self.cog = cog + self.no_retry = no_retry + self.converser_cog = converser_cog + self.message = message + self.custom_api_key = custom_api_key + for x in range(1, len(image_urls) + 1): + self.add_item(SaveButton(x, image_urls[x - 1])) + if not only_save: + if not no_retry: + self.add_item( + RedoButton( + self.cog, + converser_cog=self.converser_cog, + custom_api_key=self.custom_api_key, + ) + ) + for x in range(1, len(image_urls) + 1): + self.add_item( + VaryButton( + x, + image_urls[x - 1], + self.cog, + converser_cog=self.converser_cog, + custom_api_key=self.custom_api_key, + ) + ) + + # On the timeout event, override it and we want to clear the items. + async def on_timeout(self): + # Save all the SaveButton items, then clear all the items, then add back the SaveButton items, then + # update the message + self.clear_items() + + # Create a new view with the same params as this one, but pass only_save=True + new_view = SaveView( + self.ctx, + self.image_urls, + self.cog, + self.converser_cog, + self.message, + self.no_retry, + only_save=True, + ) + + # Set the view of the message to the new view + await self.ctx.edit(view=new_view) + + +class VaryButton(discord.ui.Button): + def __init__(self, number, image_url, cog, converser_cog, custom_api_key): + super().__init__(style=discord.ButtonStyle.blurple, label="Vary " + str(number)) + self.number = number + self.image_url = image_url + self.cog = cog + self.converser_cog = converser_cog + self.custom_api_key = custom_api_key + + async def callback(self, interaction: discord.Interaction): + user_id = interaction.user.id + interaction_id = interaction.message.id + + if interaction_id not in self.converser_cog.users_to_interactions[user_id]: + if len(self.converser_cog.users_to_interactions[user_id]) >= 2: + interaction_id2 = interaction.id + if ( + interaction_id2 + not in self.converser_cog.users_to_interactions[user_id] + ): + await interaction.response.send_message( + content="You can not vary images in someone else's chain!", + ephemeral=True, + ) + else: + await interaction.response.send_message( + content="You can only vary for images that you generated yourself!", + ephemeral=True, + ) + return + + if user_id in self.cog.redo_users: + response_message = await interaction.response.send_message( + content="Varying image number " + str(self.number) + "..." + ) + self.converser_cog.users_to_interactions[user_id].append( + response_message.message.id + ) + self.converser_cog.users_to_interactions[user_id].append( + response_message.id + ) + prompt = self.cog.redo_users[user_id].prompt + + asyncio.ensure_future( + ImageService.encapsulated_send( + self.cog, + user_id, + prompt, + interaction.message, + response_message=response_message, + vary=self.image_url, + custom_api_key=self.custom_api_key, + ) + ) + + +class SaveButton(discord.ui.Button["SaveView"]): + def __init__(self, number: int, image_url: str): + super().__init__(style=discord.ButtonStyle.gray, label="Save " + str(number)) + self.number = number + self.image_url = image_url + + async def callback(self, interaction: discord.Interaction): + # If the image url doesn't start with "http", then we need to read the file from the URI, and then send the + # file to the user as an attachment. + try: + if not self.image_url.startswith("http"): + with open(self.image_url, "rb") as f: + image = Image.open(BytesIO(f.read())) + temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) + image.save(temp_file.name) + + await interaction.response.send_message( + content="Here is your image for download (open original and save)", + file=discord.File(temp_file.name), + ephemeral=True, + ) + else: + await interaction.response.send_message( + f"You can directly download this image from {self.image_url}", + ephemeral=True, + ) + except Exception as e: + await interaction.response.send_message(f"Error: {e}", ephemeral=True) + traceback.print_exc() + + +class RedoButton(discord.ui.Button["SaveView"]): + def __init__(self, cog, converser_cog, custom_api_key): + super().__init__(style=discord.ButtonStyle.danger, label="Retry") + self.cog = cog + self.converser_cog = converser_cog + self.custom_api_key = custom_api_key + + async def callback(self, interaction: discord.Interaction): + user_id = interaction.user.id + interaction_id = interaction.message.id + + if interaction_id not in self.converser_cog.users_to_interactions[user_id]: + await interaction.response.send_message( + content="You can only retry for prompts that you generated yourself!", + ephemeral=True, + ) + return + + # We have passed the intial check of if the interaction belongs to the user + if user_id in self.cog.redo_users: + # Get the message and the prompt and call encapsulated_send + ctx = self.cog.redo_users[user_id].ctx + prompt = self.cog.redo_users[user_id].prompt + response_message = self.cog.redo_users[user_id].response + message = await interaction.response.send_message( + f"Regenerating the image for your original prompt, check the original message.", + ephemeral=True, + ) + self.converser_cog.users_to_interactions[user_id].append(message.id) + + asyncio.ensure_future( + ImageService.encapsulated_send( + self.cog, + user_id, + prompt, + ctx, + response_message, + custom_api_key=self.custom_api_key, + ) + ) diff --git a/models/message_model.py b/services/message_queue_service.py similarity index 100% rename from models/message_model.py rename to services/message_queue_service.py diff --git a/models/moderations_service_model.py b/services/moderations_service.py similarity index 97% rename from models/moderations_service_model.py rename to services/moderations_service.py index 07d2540..cb69490 100644 --- a/models/moderations_service_model.py +++ b/services/moderations_service.py @@ -7,7 +7,7 @@ from pathlib import Path import discord from models.openai_model import Model -from models.usage_service_model import UsageService +from services.usage_service import UsageService usage_service = UsageService(Path(os.environ.get("DATA_DIR", os.getcwd()))) model = Model(usage_service) diff --git a/models/pinecone_service_model.py b/services/pinecone_service.py similarity index 100% rename from models/pinecone_service_model.py rename to services/pinecone_service.py diff --git a/services/text_service.py b/services/text_service.py new file mode 100644 index 0000000..51a27c4 --- /dev/null +++ b/services/text_service.py @@ -0,0 +1,820 @@ +import datetime +import re +import traceback + +import aiohttp +import discord +from discord.ext import pages + +from services.deletion_service import Deletion +from models.openai_model import Model +from models.user_model import EmbeddedConversationItem, RedoUser + + +class TextService: + def __init__(self): + pass + + @staticmethod + async def encapsulated_send( + converser_cog, + id, + prompt, + ctx, + response_message=None, + temp_override=None, + top_p_override=None, + frequency_penalty_override=None, + presence_penalty_override=None, + from_ask_command=False, + instruction=None, + from_edit_command=False, + codex=False, + model=None, + custom_api_key=None, + edited_request=False, + redo_request=False, + ): + new_prompt = ( + prompt + "\nGPTie: " + if not from_ask_command and not from_edit_command + else prompt + ) + + from_context = isinstance(ctx, discord.ApplicationContext) + + if not instruction: + tokens = converser_cog.usage_service.count_tokens(new_prompt) + else: + tokens = converser_cog.usage_service.count_tokens( + new_prompt + ) + converser_cog.usage_service.count_tokens(instruction) + + try: + + # Pinecone is enabled, we will create embeddings for this conversation. + if converser_cog.pinecone_service and ctx.channel.id in converser_cog.conversation_threads: + # Delete "GPTie: <|endofstatement|>" from the user's conversation history if it exists + # check if the text attribute for any object inside converser_cog.conversation_threads[converation_id].history + # contains ""GPTie: <|endofstatement|>"", if so, delete + for item in converser_cog.conversation_threads[ctx.channel.id].history: + if item.text.strip() == "GPTie:<|endofstatement|>": + converser_cog.conversation_threads[ctx.channel.id].history.remove(item) + + # The conversation_id is the id of the thread + conversation_id = ctx.channel.id + + # Create an embedding and timestamp for the prompt + new_prompt = prompt.encode("ascii", "ignore").decode() + prompt_less_author = f"{new_prompt} <|endofstatement|>\n" + + user_displayname = ctx.author.display_name + + new_prompt = ( + f"\n'{user_displayname}': {new_prompt} <|endofstatement|>\n" + ) + new_prompt = new_prompt.encode("ascii", "ignore").decode() + + timestamp = int( + str(datetime.datetime.now().timestamp()).replace(".", "") + ) + + new_prompt_item = EmbeddedConversationItem(new_prompt, timestamp) + + if not redo_request: + converser_cog.conversation_threads[conversation_id].history.append( + new_prompt_item + ) + + if edited_request: + new_prompt = "".join( + [ + item.text + for item in converser_cog.conversation_threads[ + ctx.channel.id + ].history + ] + ) + converser_cog.redo_users[ctx.author.id].prompt = new_prompt + else: + # Create and upsert the embedding for the conversation id, prompt, timestamp + await converser_cog.pinecone_service.upsert_conversation_embedding( + converser_cog.model, + conversation_id, + new_prompt, + timestamp, + custom_api_key=custom_api_key, + ) + + embedding_prompt_less_author = await converser_cog.model.send_embedding_request( + prompt_less_author, custom_api_key=custom_api_key + ) # Use the version of the prompt without the author's name for better clarity on retrieval. + + # Now, build the new prompt by getting the X most similar with pinecone + similar_prompts = converser_cog.pinecone_service.get_n_similar( + conversation_id, + embedding_prompt_less_author, + n=converser_cog.model.num_conversation_lookback, + ) + + # When we are in embeddings mode, only the pre-text is contained in converser_cog.conversation_threads[message.channel.id].history, so we + # can use that as a base to build our new prompt + prompt_with_history = [ + converser_cog.conversation_threads[ctx.channel.id].history[0] + ] + + # Append the similar prompts to the prompt with history + prompt_with_history += [ + EmbeddedConversationItem(prompt, timestamp) + for prompt, timestamp in similar_prompts + ] + + # iterate UP TO the last X prompts in the history + for i in range( + 1, + min( + len(converser_cog.conversation_threads[ctx.channel.id].history), + converser_cog.model.num_static_conversation_items, + ), + ): + prompt_with_history.append( + converser_cog.conversation_threads[ctx.channel.id].history[-i] + ) + + # remove duplicates from prompt_with_history and set the conversation history + prompt_with_history = list(dict.fromkeys(prompt_with_history)) + converser_cog.conversation_threads[ + ctx.channel.id + ].history = prompt_with_history + + # Sort the prompt_with_history by increasing timestamp if pinecone is enabled + if converser_cog.pinecone_service: + prompt_with_history.sort(key=lambda x: x.timestamp) + + # Ensure that the last prompt in this list is the prompt we just sent (new_prompt_item) + if prompt_with_history[-1] != new_prompt_item: + try: + prompt_with_history.remove(new_prompt_item) + except ValueError: + pass + prompt_with_history.append(new_prompt_item) + + prompt_with_history = "".join( + [item.text for item in prompt_with_history] + ) + + new_prompt = prompt_with_history + "\nGPTie: " + + tokens = converser_cog.usage_service.count_tokens(new_prompt) + + # No pinecone, we do conversation summarization for long term memory instead + elif ( + id in converser_cog.conversation_threads + and tokens > converser_cog.model.summarize_threshold + and not from_ask_command + and not from_edit_command + and not converser_cog.pinecone_service # This should only happen if we are not doing summarizations. + ): + + # We don't need to worry about the differences between interactions and messages in this block, + # because if we are in this block, we can only be using a message object for ctx + if converser_cog.model.summarize_conversations: + await ctx.reply( + "I'm currently summarizing our current conversation so we can keep chatting, " + "give me one moment!" + ) + + await converser_cog.summarize_conversation(ctx, new_prompt) + + # Check again if the prompt is about to go past the token limit + new_prompt = ( + "".join( + [ + item.text + for item in converser_cog.conversation_threads[id].history + ] + ) + + "\nGPTie: " + ) + + tokens = converser_cog.usage_service.count_tokens(new_prompt) + + if ( + tokens > converser_cog.model.summarize_threshold - 150 + ): # 150 is a buffer for the second stage + await ctx.reply( + "I tried to summarize our current conversation so we could keep chatting, " + "but it still went over the token " + "limit. Please try again later." + ) + + await converser_cog.end_conversation(ctx) + return + else: + await ctx.reply("The conversation context limit has been reached.") + await converser_cog.end_conversation(ctx) + return + + # Send the request to the model + if from_edit_command: + response = await converser_cog.model.send_edit_request( + input=new_prompt, + instruction=instruction, + temp_override=temp_override, + top_p_override=top_p_override, + codex=codex, + custom_api_key=custom_api_key, + ) + else: + response = await converser_cog.model.send_request( + new_prompt, + tokens=tokens, + temp_override=temp_override, + top_p_override=top_p_override, + frequency_penalty_override=frequency_penalty_override, + presence_penalty_override=presence_penalty_override, + model=model, + custom_api_key=custom_api_key, + ) + + # Clean the request response + response_text = converser_cog.cleanse_response(str(response["choices"][0]["text"])) + + if from_ask_command: + # Append the prompt to the beginning of the response, in italics, then a new line + response_text = response_text.strip() + response_text = f"***{prompt}***\n\n{response_text}" + elif from_edit_command: + if codex: + response_text = response_text.strip() + response_text = f"***Prompt: {prompt}***\n***Instruction: {instruction}***\n\n```\n{response_text}\n```" + else: + response_text = response_text.strip() + response_text = f"***Prompt: {prompt}***\n***Instruction: {instruction}***\n\n{response_text}\n" + + # If gpt3 tries writing a user mention try to replace it with their name + response_text = await converser_cog.mention_to_username(ctx, response_text) + + # If the user is conversing, add the GPT response to their conversation history. + if ( + id in converser_cog.conversation_threads + and not from_ask_command + and not converser_cog.pinecone_service + ): + if not redo_request: + converser_cog.conversation_threads[id].history.append( + EmbeddedConversationItem( + "\nGPTie: " + str(response_text) + "<|endofstatement|>\n", 0 + ) + ) + + # Embeddings case! + elif ( + id in converser_cog.conversation_threads + and not from_ask_command + and not from_edit_command + and converser_cog.pinecone_service + ): + conversation_id = id + + # Create an embedding and timestamp for the prompt + response_text = ( + "\nGPTie: " + str(response_text) + "<|endofstatement|>\n" + ) + + response_text = response_text.encode("ascii", "ignore").decode() + + # Print the current timestamp + timestamp = int( + str(datetime.datetime.now().timestamp()).replace(".", "") + ) + converser_cog.conversation_threads[conversation_id].history.append( + EmbeddedConversationItem(response_text, timestamp) + ) + + # Create and upsert the embedding for the conversation id, prompt, timestamp + embedding = await converser_cog.pinecone_service.upsert_conversation_embedding( + converser_cog.model, + conversation_id, + response_text, + timestamp, + custom_api_key=custom_api_key, + ) + + # Cleanse again + response_text = converser_cog.cleanse_response(response_text) + + # escape any other mentions like @here or @everyone + response_text = discord.utils.escape_mentions(response_text) + + + # If we don't have a response message, we are not doing a redo, send as a new message(s) + if not response_message: + if len(response_text) > converser_cog.TEXT_CUTOFF: + if not from_context: + paginator = None + await converser_cog.paginate_and_send(response_text, ctx) + else: + embed_pages = await converser_cog.paginate_embed(response_text, codex, prompt, instruction) + view=ConversationView(ctx, converser_cog, ctx.channel.id, model, from_ask_command, from_edit_command, custom_api_key=custom_api_key) + paginator = pages.Paginator(pages=embed_pages, timeout=None, custom_view=view) + response_message = await paginator.respond(ctx.interaction) + else: + paginator = None + if not from_context: + response_message = await ctx.reply( + response_text, + view=ConversationView( + ctx, + converser_cog, + ctx.channel.id, + model, + custom_api_key=custom_api_key, + ), + ) + elif from_edit_command: + response_message = await ctx.respond( + response_text, + view=ConversationView( + ctx, + converser_cog, + ctx.channel.id, + model, + from_edit_command=from_edit_command, + custom_api_key=custom_api_key + ), + ) + else: + response_message = await ctx.respond( + response_text, + view=ConversationView( + ctx, + converser_cog, + ctx.channel.id, + model, + from_ask_command=from_ask_command, + custom_api_key=custom_api_key + ), + ) + + if response_message: + # Get the actual message object of response_message in case it's an WebhookMessage + actual_response_message = ( + response_message + if not from_context + else await ctx.fetch_message(response_message.id) + ) + + converser_cog.redo_users[ctx.author.id] = RedoUser( + prompt=new_prompt, + instruction=instruction, + ctx=ctx, + message=ctx, + response=actual_response_message, + codex=codex, + paginator=paginator + ) + converser_cog.redo_users[ctx.author.id].add_interaction( + actual_response_message.id + ) + + # We are doing a redo, edit the message. + else: + paginator = converser_cog.redo_users.get(ctx.author.id).paginator + if isinstance(paginator, pages.Paginator): + embed_pages = await converser_cog.paginate_embed(response_text, codex, prompt, instruction) + view=ConversationView(ctx, converser_cog, ctx.channel.id, model, from_ask_command, from_edit_command, custom_api_key=custom_api_key) + await paginator.update(pages=embed_pages, custom_view=view) + elif len(response_text) > converser_cog.TEXT_CUTOFF: + if not from_context: + await response_message.channel.send("Over 2000 characters", delete_after=5) + else: + await response_message.edit(content=response_text) + + await converser_cog.send_debug_message( + converser_cog.generate_debug_message(prompt, response), converser_cog.debug_channel + ) + + if ctx.author.id in converser_cog.awaiting_responses: + converser_cog.awaiting_responses.remove(ctx.author.id) + if not from_ask_command and not from_edit_command: + if ctx.channel.id in converser_cog.awaiting_thread_responses: + converser_cog.awaiting_thread_responses.remove(ctx.channel.id) + + # Error catching for AIOHTTP Errors + except aiohttp.ClientResponseError as e: + message = ( + f"The API returned an invalid response: **{e.status}: {e.message}**" + ) + if from_context: + await ctx.send_followup(message) + else: + await ctx.reply(message) + converser_cog.remove_awaiting( + ctx.author.id, ctx.channel.id, from_ask_command, from_edit_command + ) + + # Error catching for OpenAI model value errors + except ValueError as e: + if from_context: + await ctx.send_followup(e) + else: + await ctx.reply(e) + converser_cog.remove_awaiting( + ctx.author.id, ctx.channel.id, from_ask_command, from_edit_command + ) + + # General catch case for everything + except Exception: + + message = "Something went wrong, please try again later. This may be due to upstream issues on the API, or rate limiting." + await ctx.send_followup(message) if from_context else await ctx.reply( + message + ) + converser_cog.remove_awaiting( + ctx.author.id, ctx.channel.id, from_ask_command, from_edit_command + ) + traceback.print_exc() + + try: + await converser_cog.end_conversation(ctx) + except: + pass + return + + @staticmethod + async def process_conversation_message(converser_cog, message, USER_INPUT_API_KEYS, USER_KEY_DB): + content = message.content.strip() + conversing = converser_cog.check_conversing( + message.author.id, message.channel.id, content + ) + + # If the user is conversing and they want to end it, end it immediately before we continue any further. + if conversing and message.content.lower() in converser_cog.END_PROMPTS: + await converser_cog.end_conversation(message) + return + + if conversing: + user_api_key = None + if USER_INPUT_API_KEYS: + user_api_key = await TextService.get_user_api_key( + message.author.id, message, USER_KEY_DB + ) + if not user_api_key: + return + + prompt = await converser_cog.mention_to_username(message, content) + + await converser_cog.check_conversation_limit(message) + + # If the user is in a conversation thread + if message.channel.id in converser_cog.conversation_threads: + + # Since this is async, we don't want to allow the user to send another prompt while a conversation + # prompt is processing, that'll mess up the conversation history! + if message.author.id in converser_cog.awaiting_responses: + message = await message.reply( + "You are already waiting for a response from GPT3. Please wait for it to respond before sending another message." + ) + + # get the current date, add 10 seconds to it, and then turn it into a timestamp. + # we need to use our deletion service because this isn't an interaction, it's a regular message. + deletion_time = datetime.datetime.now() + datetime.timedelta( + seconds=10 + ) + deletion_time = deletion_time.timestamp() + + deletion_message = Deletion(message, deletion_time) + await converser_cog.deletion_queue.put(deletion_message) + + return + + if message.channel.id in converser_cog.awaiting_thread_responses: + message = await message.reply( + "This thread is already waiting for a response from GPT3. Please wait for it to respond before sending another message." + ) + + # get the current date, add 10 seconds to it, and then turn it into a timestamp. + # we need to use our deletion service because this isn't an interaction, it's a regular message. + deletion_time = datetime.datetime.now() + datetime.timedelta( + seconds=10 + ) + deletion_time = deletion_time.timestamp() + + deletion_message = Deletion(message, deletion_time) + await converser_cog.deletion_queue.put(deletion_message) + + return + + converser_cog.awaiting_responses.append(message.author.id) + converser_cog.awaiting_thread_responses.append(message.channel.id) + + if not converser_cog.pinecone_service: + converser_cog.conversation_threads[message.channel.id].history.append( + EmbeddedConversationItem( + f"\n'{message.author.display_name}': {prompt} <|endofstatement|>\n", + 0, + ) + ) + + # increment the conversation counter for the user + converser_cog.conversation_threads[message.channel.id].count += 1 + + # Send the request to the model + # If conversing, the prompt to send is the history, otherwise, it's just the prompt + if ( + converser_cog.pinecone_service + or message.channel.id not in converser_cog.conversation_threads + ): + primary_prompt = prompt + else: + primary_prompt = "".join( + [ + item.text + for item in converser_cog.conversation_threads[ + message.channel.id + ].history + ] + ) + + # set conversation overrides + overrides = converser_cog.conversation_threads[message.channel.id].get_overrides() + + await TextService.encapsulated_send( + converser_cog, + message.channel.id, + primary_prompt, + message, + temp_override=overrides["temperature"], + top_p_override=overrides["top_p"], + frequency_penalty_override=overrides["frequency_penalty"], + presence_penalty_override=overrides["presence_penalty"], + model=converser_cog.conversation_threads[message.channel.id].model, + custom_api_key=user_api_key, + ) + return True + + @staticmethod + async def get_user_api_key(user_id, ctx, USER_KEY_DB): + user_api_key = None if user_id not in USER_KEY_DB else USER_KEY_DB[user_id] + if user_api_key is None or user_api_key == "": + modal = SetupModal(title="API Key Setup",user_key_db=USER_KEY_DB) + if isinstance(ctx, discord.ApplicationContext): + await ctx.send_modal(modal) + await ctx.send_followup( + "You must set up your API key before using this command." + ) + else: + await ctx.reply( + "You must set up your API key before typing in a GPT3 powered channel, type `/setup` to enter your API key." + ) + return user_api_key + + @staticmethod + async def process_conversation_edit(converser_cog, after, original_message): + if after.author.id in converser_cog.redo_users: + if after.id == original_message[after.author.id]: + response_message = converser_cog.redo_users[after.author.id].response + ctx = converser_cog.redo_users[after.author.id].ctx + await response_message.edit(content="Redoing prompt 🔄...") + + edited_content = await converser_cog.mention_to_username(after, after.content) + + if after.channel.id in converser_cog.conversation_threads: + # Remove the last two elements from the history array and add the new : prompt + converser_cog.conversation_threads[ + after.channel.id + ].history = converser_cog.conversation_threads[after.channel.id].history[:-2] + + pinecone_dont_reinsert = None + if not converser_cog.pinecone_service: + converser_cog.conversation_threads[after.channel.id].history.append( + EmbeddedConversationItem( + f"\n{after.author.display_name}: {after.content}<|endofstatement|>\n", + 0, + ) + ) + + converser_cog.conversation_threads[after.channel.id].count += 1 + + overrides = converser_cog.conversation_threads[after.channel.id].get_overrides() + + await TextService.encapsulated_send( + converser_cog, + id=after.channel.id, + prompt=edited_content, + ctx=ctx, + response_message=response_message, + temp_override=overrides["temperature"], + top_p_override=overrides["top_p"], + frequency_penalty_override=overrides["frequency_penalty"], + presence_penalty_override=overrides["presence_penalty"], + model=converser_cog.conversation_threads[after.channel.id].model, + edited_request=True, + ) + + if not converser_cog.pinecone_service: + converser_cog.redo_users[after.author.id].prompt = edited_content + + +""" +Conversation interaction buttons +""" +class ConversationView(discord.ui.View): + def __init__( + self, + ctx, + converser_cog, + id, + model, + from_ask_command=False, + from_edit_command=False, + custom_api_key=None, + ): + super().__init__(timeout=3600) # 1 hour interval to redo. + self.converser_cog = converser_cog + self.ctx = ctx + self.model = model + self.from_ask_command = from_ask_command + self.from_edit_command = from_edit_command + self.custom_api_key = custom_api_key + self.add_item( + RedoButton( + self.converser_cog, + model=model, + from_ask_command=from_ask_command, + from_edit_command=from_edit_command, + custom_api_key=self.custom_api_key, + ) + ) + + if id in self.converser_cog.conversation_threads: + self.add_item(EndConvoButton(self.converser_cog)) + + async def on_timeout(self): + # Remove the button from the view/message + self.clear_items() + # Send a message to the user saying the view has timed out + if self.message: + await self.message.edit( + view=None, + ) + else: + await self.ctx.edit( + view=None, + ) + + +class EndConvoButton(discord.ui.Button["ConversationView"]): + def __init__(self, converser_cog): + super().__init__(style=discord.ButtonStyle.danger, label="End Conversation", custom_id="conversation_end") + self.converser_cog = converser_cog + + async def callback(self, interaction: discord.Interaction): + + # Get the user + user_id = interaction.user.id + if ( + user_id in self.converser_cog.conversation_thread_owners + and self.converser_cog.conversation_thread_owners[user_id] + == interaction.channel.id + ): + try: + await self.converser_cog.end_conversation( + interaction, opener_user_id=interaction.user.id + ) + except Exception as e: + print(e) + traceback.print_exc() + await interaction.response.send_message( + e, ephemeral=True, delete_after=30 + ) + pass + else: + await interaction.response.send_message( + "This is not your conversation to end!", ephemeral=True, delete_after=10 + ) + + +class RedoButton(discord.ui.Button["ConversationView"]): + def __init__(self, converser_cog, model, from_ask_command, from_edit_command, custom_api_key): + super().__init__(style=discord.ButtonStyle.danger, label="Retry", custom_id="conversation_redo") + self.converser_cog = converser_cog + self.model = model + self.from_ask_command = from_ask_command + self.from_edit_command = from_edit_command + self.custom_api_key = custom_api_key + + async def callback(self, interaction: discord.Interaction): + + # Get the user + user_id = interaction.user.id + if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[ + user_id + ].in_interaction(interaction.message.id): + # Get the message and the prompt and call encapsulated_send + prompt = self.converser_cog.redo_users[user_id].prompt + instruction = self.converser_cog.redo_users[user_id].instruction + ctx = self.converser_cog.redo_users[user_id].ctx + response_message = self.converser_cog.redo_users[user_id].response + codex = self.converser_cog.redo_users[user_id].codex + + msg = await interaction.response.send_message( + "Retrying your original request...", ephemeral=True, delete_after=15 + ) + + await TextService.encapsulated_send( + self.converser_cog, + id=user_id, + prompt=prompt, + instruction=instruction, + ctx=ctx, + model=self.model, + response_message=response_message, + codex=codex, + custom_api_key=self.custom_api_key, + redo_request=True, + from_ask_command=self.from_ask_command, + from_edit_command=self.from_edit_command, + ) + else: + await interaction.response.send_message( + "You can only redo the most recent prompt that you sent yourself.", + ephemeral=True, + delete_after=10, + ) + + +""" +The setup modal when using user input API keys +""" +class SetupModal(discord.ui.Modal): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + # Get the argument named "user_key_db" and save it as USER_KEY_DB + self.USER_KEY_DB = kwargs.pop("user_key_db") + + self.add_item( + discord.ui.InputText( + label="OpenAI API Key", + placeholder="sk--......", + ) + ) + + async def callback(self, interaction: discord.Interaction): + user = interaction.user + api_key = self.children[0].value + # Validate that api_key is indeed in this format + if not re.match(r"sk-[a-zA-Z0-9]{32}", api_key): + await interaction.response.send_message( + "Your API key looks invalid, please check that it is correct before proceeding. Please run the /setup command to set your key.", + ephemeral=True, + delete_after=100, + ) + else: + # We can save the key for the user to the database. + + # Make a test request using the api key to ensure that it is valid. + try: + await Model.send_test_request(api_key) + await interaction.response.send_message( + "Your API key was successfully validated.", + ephemeral=True, + delete_after=10, + ) + + except aiohttp.ClientResponseError as e: + await interaction.response.send_message( + f"The API returned an invalid response: **{e.status}: {e.message}**", + ephemeral=True, + delete_after=30, + ) + return + + except Exception as e: + await interaction.response.send_message( + f"Your API key looks invalid, the API returned: {e}. Please check that your API key is correct before proceeding", + ephemeral=True, + delete_after=30, + ) + return + + # Save the key to the database + try: + self.USER_KEY_DB[user.id] = api_key + self.USER_KEY_DB.commit() + await interaction.followup.send( + "Your API key was successfully saved.", + ephemeral=True, + delete_after=10, + ) + except Exception as e: + traceback.print_exc() + await interaction.followup.send( + "There was an error saving your API key.", + ephemeral=True, + delete_after=30, + ) + return + + pass diff --git a/models/usage_service_model.py b/services/usage_service.py similarity index 100% rename from models/usage_service_model.py rename to services/usage_service.py