diff --git a/cogs/gpt_3_commands_and_converser.py b/cogs/gpt_3_commands_and_converser.py
index b39a20c..d0e9ec0 100644
--- a/cogs/gpt_3_commands_and_converser.py
+++ b/cogs/gpt_3_commands_and_converser.py
@@ -14,7 +14,7 @@ from models.deletion_service_model import Deletion
 from models.env_service_model import EnvService
 from models.message_model import Message
 from models.moderations_service_model import Moderation
-from models.user_model import User, RedoUser
+from models.user_model import RedoUser, Thread
 from models.check_model import Check
 from models.autocomplete_model import Settings_autocompleter, File_autocompleter
 from collections import defaultdict
@@ -44,7 +44,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
         self.debug_channel = None
         self.bot = bot
         self._last_member_ = None
-        self.conversating_users = {}
+        self.conversation_threads = {}
         self.DAVINCI_ROLES = ["admin", "Admin", "GPT", "gpt"]
         self.END_PROMPTS = [
             "end",
@@ -103,7 +103,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
         )
         self.TEXT_CUTOFF = 1900
         self.message_queue = message_queue
-        self.conversation_threads = {}
+        self.conversation_thread_owners = {}

         # Create slash command groups
         dalle = discord.SlashCommandGroup(
@@ -220,11 +220,12 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
                 await thread.delete()
         await ctx.respond("All conversation threads have been deleted.")

-    def check_conversing(self, user_id, channel_id, message_content):
+    # TODO: add extra condition to check if multi is enabled for the thread, stored in conversation_threads
+    def check_conversing(self, user_id, channel_id, message_content, multi=None):
         cond1 = (
-            user_id in self.conversating_users
-            and user_id in self.conversation_threads
-            and channel_id == self.conversation_threads[user_id]
+            channel_id in self.conversation_threads
+            # and user_id in self.conversation_thread_owners
+            # and channel_id == self.conversation_thread_owners[user_id]
         )
         # If the trimmed message starts with a Tilde, then we want to not contribute this to the conversation
         try:
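The TODO on check_conversing above leaves the multi-user check unimplemented. Below is a minimal sketch of what that extra condition could look like, assuming a hypothetical per-thread `multi` flag that is not part of this diff; the standalone helper and its name are illustrative only, not code from the patch.

```python
# Sketch only: `multi` is a hypothetical attribute on the Thread model, not part of this diff.
def is_thread_participant(user_id, channel_id, conversation_threads, conversation_thread_owners):
    thread = conversation_threads.get(channel_id)
    if thread is None:
        return False
    # A multi-user thread accepts anyone; otherwise only the opener of this thread may contribute.
    return getattr(thread, "multi", False) or (
        conversation_thread_owners.get(user_id) == channel_id
    )
```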
@@ -235,23 +236,33 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):

         return (cond1) and cond2

-    async def end_conversation(self, message, opener_user_id=None):
-        normalized_user_id = opener_user_id if opener_user_id else message.author.id
-        self.conversating_users.pop(normalized_user_id)
+    async def end_conversation(self, ctx, opener_user_id=None):
+        normalized_user_id = opener_user_id if opener_user_id else ctx.author.id
+        try:
+            channel_id = self.conversation_thread_owners[normalized_user_id]
+        except:
+            await ctx.delete(delay=5)
+            await ctx.reply("Only the conversation starter can end this.", delete_after=5)
+            return
+        self.conversation_threads.pop(channel_id)

-        if isinstance(message, discord.ApplicationContext):
-            await message.respond(
-                "Your conversation has ended!", ephemeral=True, delete_after=10
+        if isinstance(ctx, discord.ApplicationContext):
+            await ctx.respond(
+                "You have ended the conversation with GPT3. Start a conversation with /gpt converse", ephemeral=True, delete_after=10
+            )
+        elif isinstance(ctx, discord.Interaction):
+            await ctx.response.send_message(
+                "You have ended the conversation with GPT3. Start a conversation with /gpt converse", ephemeral=True, delete_after=10
             )
         else:
-            await message.reply(
+            await ctx.reply(
                 "You have ended the conversation with GPT3. Start a conversation with /gpt converse"
             )

         # Close all conversation threads for the user
-        if normalized_user_id in self.conversation_threads:
-            thread_id = self.conversation_threads[normalized_user_id]
-            self.conversation_threads.pop(normalized_user_id)
+        if normalized_user_id in self.conversation_thread_owners:
+            thread_id = self.conversation_thread_owners[normalized_user_id]
+            self.conversation_thread_owners.pop(normalized_user_id)

             # Attempt to close and lock the thread.
             try:
@@ -449,10 +460,10 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):

     async def check_conversation_limit(self, message):
         # After each response, check if the user has reached the conversation limit in terms of messages or time.
-        if message.author.id in self.conversating_users:
+        if message.channel.id in self.conversation_threads:
             # If the user has reached the max conversation length, end the conversation
             if (
-                self.conversating_users[message.author.id].count
+                self.conversation_threads[message.channel.id].count
                 >= self.model.max_conversation_length
             ):
                 await message.reply(
@@ -475,26 +486,27 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
             )

             # Get the last entry from the user's conversation history
             new_conversation_history.append(
-                self.conversating_users[message.author.id].history[-1] + "\n"
+                self.conversation_threads[message.channel.id].history[-1] + "\n"
             )
-            self.conversating_users[message.author.id].history = new_conversation_history
+            self.conversation_threads[message.channel.id].history = new_conversation_history

     # A listener for message edits to redo prompts if they are edited
     @discord.Cog.listener()
     async def on_message_edit(self, before, after):
         # Moderation
-        if (
-            after.guild.id in self.moderation_queues
-            and self.moderation_queues[after.guild.id] is not None
-        ):
-            # Create a timestamp that is 0.5 seconds from now
-            timestamp = (
-                datetime.datetime.now() + datetime.timedelta(seconds=0.5)
-            ).timestamp()
-            await self.moderation_queues[after.guild.id].put(
-                Moderation(after, timestamp)
-            )
+        if not isinstance(after.channel, discord.DMChannel):
+            if (
+                after.guild.id in self.moderation_queues
+                and self.moderation_queues[after.guild.id] is not None
+            ):
+                # Create a timestamp that is 0.5 seconds from now
+                timestamp = (
+                    datetime.datetime.now() + datetime.timedelta(seconds=0.5)
+                ).timestamp()
+                await self.moderation_queues[after.guild.id].put(
+                    Moderation(after, timestamp)
+                )

         if after.author.id in self.redo_users:
             if after.id == original_message[after.author.id]:
@@ -506,22 +518,22 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
                 # If the user is conversing, we need to get their conversation history, delete the last
                 # "Human:" message, create a new Human: section with the new prompt, and then set the prompt to
                 # the new prompt, then send that new prompt as the new prompt.
-                if after.author.id in self.conversating_users:
+                if after.channel.id in self.conversation_threads:
                     # Remove the last two elements from the history array and add the new Human: prompt
-                    self.conversating_users[
-                        after.author.id
-                    ].history = self.conversating_users[after.author.id].history[:-2]
-                    self.conversating_users[after.author.id].history.append(
+                    self.conversation_threads[
+                        after.channel.id
+                    ].history = self.conversation_threads[after.channel.id].history[:-2]
+                    self.conversation_threads[after.channel.id].history.append(
                         f"\nHuman: {after.content}<|endofstatement|>\n"
                     )
                     edited_content = "".join(
-                        self.conversating_users[after.author.id].history
+                        self.conversation_threads[after.channel.id].history
                     )
-                    self.conversating_users[after.author.id].count += 1
+                    self.conversation_threads[after.channel.id].count += 1
                 print("Doing the encapsulated send")
                 await self.encapsulated_send(
-                    user_id=after.author.id,
+                    id=after.channel.id,
                     prompt=edited_content,
                     ctx=ctx,
                     response_message=response_message,
@@ -570,7 +582,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):

         # We want to have conversationality functionality. To have gpt3 remember context, we need to append the conversation/prompt
         # history to the prompt. We can do this by checking if the user is in the conversating_users dictionary, and if they are,
         # we can append their history to the prompt.
-        if message.author.id in self.conversating_users:
+        if message.channel.id in self.conversation_threads:
             # Since this is async, we don't want to allow the user to send another prompt while a conversation
             # prompt is processing, that'll mess up the conversation history!
@@ -595,28 +607,28 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):

             original_message[message.author.id] = message.id

-            self.conversating_users[message.author.id].history.append(
+            self.conversation_threads[message.channel.id].history.append(
                 "\nHuman: " + prompt + "<|endofstatement|>\n"
             )

             # increment the conversation counter for the user
-            self.conversating_users[message.author.id].count += 1
+            self.conversation_threads[message.channel.id].count += 1

         # Send the request to the model
         # If conversing, the prompt to send is the history, otherwise, it's just the prompt
         await self.encapsulated_send(
-            message.author.id,
+            message.channel.id,
             prompt
-            if message.author.id not in self.conversating_users
-            else "".join(self.conversating_users[message.author.id].history),
+            if message.channel.id not in self.conversation_threads
+            else "".join(self.conversation_threads[message.channel.id].history),
             message,
         )

     # ctx can be of type AppContext(interaction) or Message
     async def encapsulated_send(
         self,
-        user_id,
+        id,
         prompt,
         ctx,
         temp_override=None,
@@ -646,7 +658,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):

            # Check if the prompt is about to go past the token limit
            if (
-                user_id in self.conversating_users
+                id in self.conversation_threads
                and tokens > self.model.summarize_threshold
                and not from_g_command
            ):
@@ -663,7 +675,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):

                # Check again if the prompt is about to go past the token limit
                new_prompt = (
-                    "".join(self.conversating_users[user_id].history) + "\nGPTie: "
+                    "".join(self.conversation_threads[id].history) + "\nGPTie: "
                )
                tokens = self.usage_service.count_tokens(new_prompt)
@@ -713,8 +725,8 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
            )

            # If the user is conversing, add the GPT response to their conversation history.
-            if user_id in self.conversating_users and not from_g_command:
-                self.conversating_users[user_id].history.append(
+            if id in self.conversation_threads and not from_g_command:
+                self.conversation_threads[id].history.append(
                    "\nGPTie: " + str(response_text) + "<|endofstatement|>\n"
                )

@@ -726,12 +738,12 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):

                response_message = (
                    await ctx.respond(
                        response_text,
-                        view=RedoView(ctx, self, user_id),
+                        view=ConversationView(ctx, self, ctx.channel.id),
                    )
                    if from_context
                    else await ctx.reply(
                        response_text,
-                        view=RedoView(ctx, self, user_id),
+                        view=ConversationView(ctx, self, ctx.channel.id),
                    )
                )
@@ -742,10 +754,10 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
                    else await ctx.fetch_message(response_message.id)
                )

-                self.redo_users[user_id] = RedoUser(
+                self.redo_users[ctx.author.id] = RedoUser(
                    prompt, ctx, ctx, actual_response_message
                )
-                self.redo_users[user_id].add_interaction(actual_response_message.id)
+                self.redo_users[ctx.author.id].add_interaction(actual_response_message.id)

            # We are doing a redo, edit the message.
            else:
@@ -755,8 +767,8 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
                    self.generate_debug_message(prompt, response), self.debug_channel
                )

-            if user_id in self.awaiting_responses:
-                self.awaiting_responses.remove(user_id)
+            if ctx.author.id in self.awaiting_responses:
+                self.awaiting_responses.remove(ctx.author.id)

        # Error catching for OpenAI model value errors
        except ValueError as e:
@@ -775,8 +787,8 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
            await ctx.send_followup(message) if from_context else await ctx.reply(
                message
            )
-            if user_id in self.awaiting_responses:
-                self.awaiting_responses.remove(user_id)
+            if ctx.author.id in self.awaiting_responses:
+                self.awaiting_responses.remove(ctx.author.id)
            traceback.print_exc()

            try:
@@ -901,11 +913,11 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):

        user = ctx.user

-        if user.id in self.conversating_users:
+        if user.id in self.conversation_thread_owners:
            message = await ctx.respond(
-                "You are already conversating with GPT3. End the conversation with !g end or just say 'end' in a supported channel"
+                "You've already created a thread, end it before creating a new one",
+                delete_after=5
            )
-            await self.deletion_queue(message)
            return

        if not opener and not opener_file:
@@ -930,18 +942,6 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
        else:
            pass

-        self.conversating_users[user_id_normalized] = User(user_id_normalized)
-
-        # Append the starter text for gpt3 to the user's history so it gets concatenated with the prompt later
-        if minimal or opener_file:
-            self.conversating_users[user_id_normalized].history.append(
-                self.CONVERSATION_STARTER_TEXT_MINIMAL
-            )
-        elif not minimal:
-            self.conversating_users[user_id_normalized].history.append(
-                self.CONVERSATION_STARTER_TEXT
-            )
-
        if private:
            await ctx.respond(user.name + "'s private conversation with GPT3")
            thread = await ctx.channel.create_thread(
@@ -957,6 +957,18 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
                auto_archive_duration=60,
            )

+        self.conversation_threads[thread.id] = Thread(thread.id)
+
+        # Append the starter text for gpt3 to the user's history so it gets concatenated with the prompt later
+        if minimal or opener_file:
+            self.conversation_threads[thread.id].history.append(
+                self.CONVERSATION_STARTER_TEXT_MINIMAL
+            )
+        elif not minimal:
+            self.conversation_threads[thread.id].history.append(
+                self.CONVERSATION_STARTER_TEXT
+            )
+
        await thread.send(
            "<@"
            + str(user_id_normalized)
@@ -966,24 +978,26 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
        # send opening
        if opener:
            thread_message = await thread.send("***Opening prompt*** \n" + opener)
-            if user_id_normalized in self.conversating_users:
+            if thread.id in self.conversation_threads:
                self.awaiting_responses.append(user_id_normalized)

-                self.conversating_users[user_id_normalized].history.append(
+                self.conversation_threads[thread.id].history.append(
                    "\nHuman: " + opener + "<|endofstatement|>\n"
                )

-                self.conversating_users[user_id_normalized].count += 1
+                self.conversation_threads[thread.id].count += 1

            await self.encapsulated_send(
-                user_id_normalized,
+                thread.id,
                opener
-                if user_id_normalized not in self.conversating_users
-                else "".join(self.conversating_users[user_id_normalized].history),
+                if thread.id not in self.conversation_threads
+                else "".join(self.conversation_threads[thread.id].history),
                thread_message,
            )
+            self.awaiting_responses.remove(user_id_normalized)
+
+        self.conversation_thread_owners[user_id_normalized] = thread.id

-        self.conversation_threads[user_id_normalized] = thread.id

    @add_to_group("system")
    @discord.slash_command(
@@ -1063,10 +1077,15 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
        guild_ids=ALLOWED_GUILDS,
    )
    @discord.guild_only()
-    async def end_chat(self, ctx: discord.ApplicationContext):
+    async def end(self, ctx: discord.ApplicationContext):
        await ctx.defer(ephemeral=True)
        user_id = ctx.user.id
-        if user_id in self.conversating_users:
+        try:
+            thread_id = self.conversation_thread_owners[user_id]
+        except:
+            await ctx.respond("You haven't started any conversations", ephemeral=True, delete_after=10)
+            return
+        if thread_id in self.conversation_threads:
            try:
                await self.end_conversation(ctx)
            except Exception as e:
@@ -1140,14 +1159,14 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
        await self.process_settings_command(ctx, parameter, value)


-class RedoView(discord.ui.View):
-    def __init__(self, ctx, converser_cog, user_id):
+class ConversationView(discord.ui.View):
+    def __init__(self, ctx, converser_cog, id):
        super().__init__(timeout=3600)  # 1 hour interval to redo.
        self.converser_cog = converser_cog
        self.ctx = ctx
        self.add_item(RedoButton(self.converser_cog))
-        if user_id in self.converser_cog.conversating_users:
+        if id in self.converser_cog.conversation_threads:
            self.add_item(EndConvoButton(self.converser_cog))

    async def on_timeout(self):
@@ -1164,7 +1183,7 @@
        )


-class EndConvoButton(discord.ui.Button["RedoView"]):
+class EndConvoButton(discord.ui.Button["ConversationView"]):
    def __init__(self, converser_cog):
        super().__init__(style=discord.ButtonStyle.danger, label="End Conversation")
        self.converser_cog = converser_cog
@@ -1173,16 +1192,10 @@
        # Get the user
        user_id = interaction.user.id

-        if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[
-            user_id
-        ].in_interaction(interaction.message.id):
+        if user_id in self.converser_cog.conversation_thread_owners and self.converser_cog.conversation_thread_owners[user_id] == interaction.channel.id:
            try:
                await self.converser_cog.end_conversation(
-                    self.converser_cog.redo_users[user_id].message,
-                    opener_user_id=user_id,
-                )
-                await interaction.response.send_message(
-                    "Your conversation has ended!", ephemeral=True, delete_after=10
+                    interaction, opener_user_id=interaction.user.id
                )
            except Exception as e:
                print(e)
@@ -1197,7 +1210,7 @@
        )


-class RedoButton(discord.ui.Button["RedoView"]):
+class RedoButton(discord.ui.Button["ConversationView"]):
    def __init__(self, converser_cog):
        super().__init__(style=discord.ButtonStyle.danger, label="Retry")
        self.converser_cog = converser_cog
@@ -1219,7 +1232,7 @@
            )

            await self.converser_cog.encapsulated_send(
-                user_id, prompt, ctx, response_message
+                id=user_id, prompt=prompt, ctx=ctx, response_message=response_message
            )
        else:
            await interaction.response.send_message(
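For orientation before the model change below: the cog now keeps two mappings, conversation_threads (thread/channel id -> the Thread model defined in the next file) and conversation_thread_owners (opener's user id -> thread id). A minimal, discord-free sketch of that bookkeeping under those assumptions; the standalone function names are illustrative only, not part of the patch.

```python
from models.user_model import Thread

conversation_threads = {}        # thread id -> Thread (history, message count)
conversation_thread_owners = {}  # opener's user id -> thread id

def open_conversation(user_id, thread_id, starter_text):
    # Mirrors the /gpt converse flow: register the thread, remember its owner,
    # and seed its history with the starter text.
    thread = Thread(thread_id)
    thread.history.append(starter_text)
    conversation_threads[thread_id] = thread
    conversation_thread_owners[user_id] = thread_id
    return thread

def end_conversation(user_id):
    # Mirrors end_conversation above: only the opener has an entry in
    # conversation_thread_owners, so only the opener can tear the thread down.
    thread_id = conversation_thread_owners.pop(user_id)
    conversation_threads.pop(thread_id, None)
```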
diff --git a/models/user_model.py b/models/user_model.py
index 35c1c19..682a3a1 100644
--- a/models/user_model.py
+++ b/models/user_model.py
@@ -50,3 +50,24 @@ class User:

     def __str__(self):
         return self.__repr__()
+
+class Thread:
+    def __init__(self, id):
+        self.id = id
+        self.history = []
+        self.count = 0
+
+    # These thread objects should be accessible by ID, for example if we had a bunch of
+    # thread objects in a list, `Thread(1203910293001) in thread_list` would return
+    # True if a thread with that ID was in the list
+    def __eq__(self, other):
+        return self.id == other.id
+
+    def __hash__(self):
+        return hash(self.id)
+
+    def __repr__(self):
+        return f"Thread(id={self.id}, history={self.history})"
+
+    def __str__(self):
+        return self.__repr__()
\ No newline at end of file
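For reference, a short usage sketch of the new Thread model as the cog above uses it: history entries are "Human:" / "GPTie:" blocks terminated by <|endofstatement|>, count tracks the number of exchanges, and the joined history plus a trailing "GPTie: " cue becomes the next prompt. The id and strings here are made up for illustration.

```python
from models.user_model import Thread

thread = Thread(1203910293001)  # hypothetical thread id
thread.history.append("\nHuman: Hello there<|endofstatement|>\n")
thread.count += 1
thread.history.append("\nGPTie: Hi! How can I help?<|endofstatement|>\n")

# The prompt sent to the model is the concatenated history plus a new "GPTie:" cue.
prompt = "".join(thread.history) + "\nGPTie: "

# Threads compare and hash by id, so lookups by another Thread with the same id work.
assert thread == Thread(1203910293001)
assert hash(thread) == hash(Thread(1203910293001))
```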