From 3ea5f544e84c254eee800574e25cd244446c2594 Mon Sep 17 00:00:00 2001 From: Kaveen Kumarasinghe Date: Thu, 22 Dec 2022 01:00:41 -0500 Subject: [PATCH] Fix bugs and add end conversation button to convo messages --- cogs/draw_image_generation.py | 40 ++++--- cogs/gpt_3_commands_and_converser.py | 151 +++++++++++++++++---------- cogs/image_prompt_optimizer.py | 54 ++++------ models/user_model.py | 23 ++++ 4 files changed, 155 insertions(+), 113 deletions(-) diff --git a/cogs/draw_image_generation.py b/cogs/draw_image_generation.py index fd22174..1471355 100644 --- a/cogs/draw_image_generation.py +++ b/cogs/draw_image_generation.py @@ -1,3 +1,4 @@ +import asyncio import datetime import os import re @@ -13,13 +14,8 @@ from discord.ext import commands from cogs.image_prompt_optimizer import ImgPromptOptimizer - -class RedoUser: - def __init__(self, prompt, message, response_message): - self.prompt = prompt - self.message = message - self.response_message = response_message - +# We don't use the converser cog here because we want to be able to redo for the last images and text prompts at the same time +from models.user_model import RedoUser redo_users = {} users_to_interactions = {} @@ -27,7 +23,7 @@ users_to_interactions = {} class DrawDallEService(commands.Cog, name="DrawDallEService"): def __init__( - self, bot, usage_service, model, message_queue, deletion_queue, converser_cog + self, bot, usage_service, model, message_queue, deletion_queue, converser_cog ): self.bot = bot self.usage_service = usage_service @@ -51,14 +47,15 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"): print(f"Image prompt optimizer was added") async def encapsulated_send( - self, - prompt, - message, - response_message=None, - vary=None, - draw_from_optimizer=None, - user_id=None, + self, + prompt, + message, + response_message=None, + vary=None, + draw_from_optimizer=None, + user_id=None, ): + await asyncio.sleep(0) # send the prompt to the model file, image_urls = self.model.send_image_request( prompt, vary=vary if not draw_from_optimizer else None @@ -154,7 +151,7 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"): # The image prompt is everything after the command prompt = " ".join(args) - await self.encapsulated_send(prompt, message) + asyncio.ensure_future(self.encapsulated_send(prompt, message)) except Exception as e: print(e) @@ -263,8 +260,8 @@ class VaryButton(discord.ui.Button): if len(self.converser_cog.users_to_interactions[user_id]) >= 2: interaction_id2 = interaction.id if ( - interaction_id2 - not in self.converser_cog.users_to_interactions[user_id] + interaction_id2 + not in self.converser_cog.users_to_interactions[user_id] ): await interaction.response.send_message( content="You can not vary images in someone else's chain!", @@ -288,13 +285,15 @@ class VaryButton(discord.ui.Button): response_message.id ) prompt = redo_users[user_id].prompt - await self.cog.encapsulated_send( + + asyncio.ensure_future(self.cog.encapsulated_send( prompt, interaction.message, response_message=response_message, vary=self.image_url, user_id=user_id, ) + ) class SaveButton(discord.ui.Button["SaveView"]): @@ -346,7 +345,6 @@ class RedoButton(discord.ui.Button["SaveView"]): return # We have passed the intial check of if the interaction belongs to the user - if user_id in redo_users: # Get the message and the prompt and call encapsulated_send message = redo_users[user_id].message @@ -358,4 +356,4 @@ class RedoButton(discord.ui.Button["SaveView"]): ) self.converser_cog.users_to_interactions[user_id].append(message.id) - await self.cog.encapsulated_send(prompt, message, response_message) + asyncio.ensure_future(self.cog.encapsulated_send(prompt, message, response_message)) diff --git a/cogs/gpt_3_commands_and_converser.py b/cogs/gpt_3_commands_and_converser.py index 9caf59b..e18af5f 100644 --- a/cogs/gpt_3_commands_and_converser.py +++ b/cogs/gpt_3_commands_and_converser.py @@ -1,7 +1,10 @@ +import asyncio import datetime +import functools import json import os import re +import threading import time import traceback @@ -12,31 +15,23 @@ from cogs.draw_image_generation import DrawDallEService from cogs.image_prompt_optimizer import ImgPromptOptimizer from models.deletion_service import Deletion from models.message_model import Message -from models.user_model import User +from models.user_model import User, RedoUser from collections import defaultdict -class RedoUser: - def __init__(self, prompt, message, response): - self.prompt = prompt - self.message = message - self.response = response - - -redo_users = {} original_message = {} class GPT3ComCon(commands.Cog, name="GPT3ComCon"): def __init__( - self, - bot, - usage_service, - model, - message_queue, - deletion_queue, - DEBUG_GUILD, - DEBUG_CHANNEL, + self, + bot, + usage_service, + model, + message_queue, + deletion_queue, + DEBUG_GUILD, + DEBUG_CHANNEL, ): self.debug_channel = None self.bot = bot @@ -59,6 +54,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): self.summarize = self.model.summarize_conversations self.deletion_queue = deletion_queue self.users_to_interactions = defaultdict(list) + self.redo_users = {} try: # Attempt to read a conversation starter text string from the file. @@ -134,13 +130,13 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): def check_conversing(self, message): cond1 = ( - message.author.id in self.conversating_users - and message.channel.name in ["gpt3", "general-bot", "bot"] + message.author.id in self.conversating_users + and message.channel.name in ["gpt3", "general-bot", "bot"] ) cond2 = ( - message.author.id in self.conversating_users - and message.author.id in self.conversation_threads - and message.channel.id == self.conversation_threads[message.author.id] + message.author.id in self.conversating_users + and message.author.id in self.conversation_threads + and message.channel.id == self.conversation_threads[message.author.id] ) # If the trimmed message starts with a Tilde, then we want to not contribute this to the conversation @@ -286,7 +282,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): async def paginate_and_send(self, response_text, message): response_text = [ - response_text[i : i + self.TEXT_CUTOFF] + response_text[i: i + self.TEXT_CUTOFF] for i in range(0, len(response_text), self.TEXT_CUTOFF) ] # Send each chunk as a message @@ -303,7 +299,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): async def queue_debug_chunks(self, debug_message, message, debug_channel): debug_message_chunks = [ - debug_message[i : i + self.TEXT_CUTOFF] + debug_message[i: i + self.TEXT_CUTOFF] for i in range(0, len(debug_message), self.TEXT_CUTOFF) ] @@ -346,8 +342,8 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): if message.author.id in self.conversating_users: # If the user has reached the max conversation length, end the conversation if ( - self.conversating_users[message.author.id].count - >= self.model.max_conversation_length + self.conversating_users[message.author.id].count + >= self.model.max_conversation_length ): await message.reply( "You have reached the maximum conversation length. You have ended the conversation with GPT3, and it has ended." @@ -374,6 +370,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): self.conversating_users[message.author.id].history = new_conversation_history async def encapsulated_send(self, message, prompt, response_message=None): + await asyncio.sleep(0) # Append a newline, and GPTie: to the prompt new_prompt = prompt + "\nGPTie: " @@ -396,14 +393,14 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): # Check again if the prompt is about to go past the token limit new_prompt = ( - "".join(self.conversating_users[message.author.id].history) - + "\nGPTie: " + "".join(self.conversating_users[message.author.id].history) + + "\nGPTie: " ) tokens = self.usage_service.count_tokens(new_prompt) if ( - tokens > self.model.summarize_threshold - 150 + tokens > self.model.summarize_threshold - 150 ): # 150 is a buffer for the second stage await message.reply( "I tried to summarize our current conversation so we could keep chatting, " @@ -420,6 +417,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): await self.end_conversation(message) return + # REQUEST!!!! response = self.model.send_request(new_prompt, message, tokens=tokens) response_text = response["choices"][0]["text"] @@ -446,15 +444,17 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): # Paginate and send the response back to the users if not response_message: if len(response_text) > self.TEXT_CUTOFF: - await self.paginate_and_send(response_text, message) + await self.paginate_and_send(response_text, message) # No paginations for multi-messages. else: response_message = await message.reply( response_text.replace("<|endofstatement|>", ""), - view=RedoView(self), + view=RedoView(self, message.author.id), ) - redo_users[message.author.id] = RedoUser( + self.redo_users[message.author.id] = RedoUser( prompt, message, response_message ) + self.redo_users[message.author.id].add_interaction(response_message.id) + print(f"Added the interaction {response_message.id} to the redo user {message.author.id}") original_message[message.author.id] = message.id else: # We have response_text available, this is the original message that we want to edit @@ -485,10 +485,10 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): # A listener for message edits to redo prompts if they are edited @commands.Cog.listener() async def on_message_edit(self, before, after): - if after.author.id in redo_users: + if after.author.id in self.redo_users: if after.id == original_message[after.author.id]: - message = redo_users[after.author.id].message - response_message = redo_users[after.author.id].response + message = self.redo_users[after.author.id].message + response_message = self.redo_users[after.author.id].response await response_message.edit(content="Redoing prompt 🔄...") edited_content = after.content @@ -497,7 +497,6 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): # "Human:" message, create a new Human: section with the new prompt, and then set the prompt to # the new prompt, then send that new prompt as the new prompt. if after.author.id in self.conversating_users: - # Remove the last two elements from the history array and add the new Human: prompt self.conversating_users[ after.author.id @@ -512,7 +511,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): await self.encapsulated_send(message, edited_content, response_message) - redo_users[after.author.id].prompt = after.content + self.redo_users[after.author.id].prompt = after.content @commands.Cog.listener() async def on_message(self, message): @@ -552,7 +551,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): # A global GLOBAL_COOLDOWN_TIME timer for all users if (message.author.id in self.last_used) and ( - time.time() - self.last_used[message.author.id] < self.GLOBAL_COOLDOWN_TIME + time.time() - self.last_used[message.author.id] < self.GLOBAL_COOLDOWN_TIME ): await message.reply( "You must wait " @@ -650,20 +649,35 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): # Send the request to the model # If conversing, the prompt to send is the history, otherwise, it's just the prompt - await self.encapsulated_send( - message, - prompt - if message.author.id not in self.conversating_users - else "".join(self.conversating_users[message.author.id].history), - ) + # Create a new thread to run + # self.encapsulated_send( + # message, + # prompt + # if message.author.id not in self.conversating_users + # else "".join(self.conversating_users[message.author.id].history), + # ) + + # This created thread needs to call encapsulated_send in a coroutine/async fashion. + # This is because encapsulated_send is a coroutine, and we need to await it to get the response from the model. + # We can't await it in the main thread, so we need to create a new thread to run it in. + # We can make sure that when the thread executes it executes in an async fashion by + asyncio.run_coroutine_threadsafe(self.encapsulated_send( + message, + prompt + if message.author.id not in self.conversating_users + else "".join(self.conversating_users[message.author.id].history), + ), asyncio.get_running_loop()) class RedoView(discord.ui.View): - def __init__(self, converser_cog): + def __init__(self, converser_cog, user_id): super().__init__(timeout=3600) # 1 hour interval to redo. self.converser_cog = converser_cog self.add_item(RedoButton(self.converser_cog)) + if user_id in self.converser_cog.conversating_users: + self.add_item(EndConvoButton(self.converser_cog)) + async def on_timeout(self): # Remove the button from the view/message self.clear_items() @@ -673,29 +687,50 @@ class RedoView(discord.ui.View): ) +class EndConvoButton(discord.ui.Button["RedoView"]): + def __init__(self, converser_cog): + super().__init__(style=discord.ButtonStyle.danger, label="End Conversation") + self.converser_cog = converser_cog + + async def callback(self, interaction: discord.Interaction): + + # Get the user + user_id = interaction.user.id + if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[user_id].in_interaction(interaction.message.id): + try: + await self.converser_cog.end_conversation(self.converser_cog.redo_users[user_id].message) + await interaction.response.send_message("Your conversation has ended!", ephemeral=True, + delete_after=10) + except Exception as e: + print(e) + traceback.print_exc() + await interaction.response.send_message(e, ephemeral=True, delete_after=30) + pass + else: + await interaction.response.send_message("This is not your conversation to end!", ephemeral=True, delete_after=10) + + class RedoButton(discord.ui.Button["RedoView"]): def __init__(self, converser_cog): super().__init__(style=discord.ButtonStyle.danger, label="Retry") self.converser_cog = converser_cog async def callback(self, interaction: discord.Interaction): - msg = await interaction.response.send_message( - "Retrying your original request...", ephemeral=True - ) - - # Put the message into the deletion queue with a timestamp of 10 seconds from now to be deleted - deletion = Deletion( - msg, (datetime.datetime.now() + datetime.timedelta(seconds=10)).timestamp() - ) - await self.converser_cog.deletion_queue.put(deletion) # Get the user user_id = interaction.user.id - if user_id in redo_users: + if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[user_id].in_interaction(interaction.message.id): # Get the message and the prompt and call encapsulated_send - message = redo_users[user_id].message - prompt = redo_users[user_id].prompt - response_message = redo_users[user_id].response + message = self.converser_cog.redo_users[user_id].message + prompt = self.converser_cog.redo_users[user_id].prompt + response_message = self.converser_cog.redo_users[user_id].response + + msg = await interaction.response.send_message( + "Retrying your original request...", ephemeral=True, delete_after=15 + ) + await self.converser_cog.encapsulated_send( message, prompt, response_message ) + else: + await interaction.response.send_message("You can only redo the most recent prompt that you sent yourself.", ephemeral=True, delete_after=10) diff --git a/cogs/image_prompt_optimizer.py b/cogs/image_prompt_optimizer.py index 2f11fa1..0a046e2 100644 --- a/cogs/image_prompt_optimizer.py +++ b/cogs/image_prompt_optimizer.py @@ -8,15 +8,7 @@ import discord from discord.ext import commands from models.deletion_service import Deletion - -redo_users = {} - - -class RedoUser: - def __init__(self, prompt, message, response): - self.prompt = prompt - self.message = message - self.response = response +from models.user_model import RedoUser class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"): @@ -73,11 +65,14 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"): print( f"Received an image optimization request for the following prompt: {prompt}" ) + # Get the token amount for the prompt + tokens = self.usage_service.count_tokens(prompt) try: - response = self.model.send_request( + response = await self.model.send_request( prompt, ctx.message, + tokens=tokens, top_p_override=1.0, temp_override=0.9, presence_penalty_override=0.5, @@ -101,7 +96,8 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"): response_message.id ) - redo_users[ctx.author.id] = RedoUser(prompt, ctx.message, response_message) + self.converser_cog.redo_users[ctx.author.id] = RedoUser(prompt, ctx.message, response_message) + self.converser_cog.redo_users[ctx.author.id].add_interaction(response_message.id) await response_message.edit( view=OptimizeView( self.converser_cog, self.image_service_cog, self.deletion_queue @@ -144,7 +140,7 @@ class DrawButton(discord.ui.Button["OptimizeView"]): user_id = interaction.user.id interaction_id = interaction.message.id - if interaction_id not in self.converser_cog.users_to_interactions[user_id]: + if interaction_id not in self.converser_cog.users_to_interactions[user_id] or interaction_id not in self.converser_cog.redo_users[user_id].interactions: await interaction.response.send_message( content="You can only draw for prompts that you generated yourself!", ephemeral=True, @@ -183,34 +179,24 @@ class RedoButton(discord.ui.Button["OptimizeView"]): self.deletion_queue = deletion_queue async def callback(self, interaction: discord.Interaction): - user_id = interaction.user.id interaction_id = interaction.message.id - if interaction_id not in self.converser_cog.users_to_interactions[user_id]: - await interaction.response.send_message( - content="You can only redo for prompts that you generated yourself!", - ephemeral=True, - ) - return - - msg = await interaction.response.send_message( - "Redoing your original request...", ephemeral=True - ) - - # Put the message into the deletion queue with a timestamp of 10 seconds from now to be deleted - deletion = Deletion( - msg, (datetime.datetime.now() + datetime.timedelta(seconds=10)).timestamp() - ) - await self.deletion_queue.put(deletion) - # Get the user user_id = interaction.user.id - if user_id in redo_users: + if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[user_id].in_interaction(interaction_id): # Get the message and the prompt and call encapsulated_send - message = redo_users[user_id].message - prompt = redo_users[user_id].prompt - response_message = redo_users[user_id].response + message = self.converser_cog.redo_users[user_id].message + prompt = self.converser_cog.redo_users[user_id].prompt + response_message = self.converser_cog.redo_users[user_id].response + msg = await interaction.response.send_message( + "Redoing your original request...", ephemeral=True, delete_after=20 + ) await self.converser_cog.encapsulated_send( message, prompt, response_message ) + else: + await interaction.response.send_message( + content="You can only redo for prompts that you generated yourself!", + ephemeral=True, delete_after=10 + ) diff --git a/models/user_model.py b/models/user_model.py index 11c6d61..7990aaf 100644 --- a/models/user_model.py +++ b/models/user_model.py @@ -3,6 +3,29 @@ Store information about a discord user, for the purposes of enabling conversatio history, message count, and the id of the user in order to track them. """ +class RedoUser: + def __init__(self, prompt, message, response): + self.prompt = prompt + self.message = message + self.response = response + self.interactions = [] + + def add_interaction(self, interaction): + self.interactions.append(interaction) + + def in_interaction(self, interaction): + return interaction in self.interactions + + # Represented by user_id + def __hash__(self): + return hash(self.message.author.id) + + def __eq__(self, other): + return self.message.author.id == other.message.author.id + + # repr + def __repr__(self): + return f"RedoUser({self.message.author.id})" class User: def __init__(self, id):