diff --git a/cogs/draw_image_generation.py b/cogs/draw_image_generation.py index 1471355..64b10f8 100644 --- a/cogs/draw_image_generation.py +++ b/cogs/draw_image_generation.py @@ -23,7 +23,7 @@ users_to_interactions = {} class DrawDallEService(commands.Cog, name="DrawDallEService"): def __init__( - self, bot, usage_service, model, message_queue, deletion_queue, converser_cog + self, bot, usage_service, model, message_queue, deletion_queue, converser_cog ): self.bot = bot self.usage_service = usage_service @@ -47,13 +47,13 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"): print(f"Image prompt optimizer was added") async def encapsulated_send( - self, - prompt, - message, - response_message=None, - vary=None, - draw_from_optimizer=None, - user_id=None, + self, + prompt, + message, + response_message=None, + vary=None, + draw_from_optimizer=None, + user_id=None, ): await asyncio.sleep(0) # send the prompt to the model @@ -260,8 +260,8 @@ class VaryButton(discord.ui.Button): if len(self.converser_cog.users_to_interactions[user_id]) >= 2: interaction_id2 = interaction.id if ( - interaction_id2 - not in self.converser_cog.users_to_interactions[user_id] + interaction_id2 + not in self.converser_cog.users_to_interactions[user_id] ): await interaction.response.send_message( content="You can not vary images in someone else's chain!", @@ -286,13 +286,14 @@ class VaryButton(discord.ui.Button): ) prompt = redo_users[user_id].prompt - asyncio.ensure_future(self.cog.encapsulated_send( - prompt, - interaction.message, - response_message=response_message, - vary=self.image_url, - user_id=user_id, - ) + asyncio.ensure_future( + self.cog.encapsulated_send( + prompt, + interaction.message, + response_message=response_message, + vary=self.image_url, + user_id=user_id, + ) ) @@ -356,4 +357,6 @@ class RedoButton(discord.ui.Button["SaveView"]): ) self.converser_cog.users_to_interactions[user_id].append(message.id) - asyncio.ensure_future(self.cog.encapsulated_send(prompt, message, response_message)) + asyncio.ensure_future( + self.cog.encapsulated_send(prompt, message, response_message) + ) diff --git a/cogs/gpt_3_commands_and_converser.py b/cogs/gpt_3_commands_and_converser.py index e18af5f..e3dfd33 100644 --- a/cogs/gpt_3_commands_and_converser.py +++ b/cogs/gpt_3_commands_and_converser.py @@ -24,14 +24,14 @@ original_message = {} class GPT3ComCon(commands.Cog, name="GPT3ComCon"): def __init__( - self, - bot, - usage_service, - model, - message_queue, - deletion_queue, - DEBUG_GUILD, - DEBUG_CHANNEL, + self, + bot, + usage_service, + model, + message_queue, + deletion_queue, + DEBUG_GUILD, + DEBUG_CHANNEL, ): self.debug_channel = None self.bot = bot @@ -130,13 +130,13 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): def check_conversing(self, message): cond1 = ( - message.author.id in self.conversating_users - and message.channel.name in ["gpt3", "general-bot", "bot"] + message.author.id in self.conversating_users + and message.channel.name in ["gpt3", "general-bot", "bot"] ) cond2 = ( - message.author.id in self.conversating_users - and message.author.id in self.conversation_threads - and message.channel.id == self.conversation_threads[message.author.id] + message.author.id in self.conversating_users + and message.author.id in self.conversation_threads + and message.channel.id == self.conversation_threads[message.author.id] ) # If the trimmed message starts with a Tilde, then we want to not contribute this to the conversation @@ -282,7 +282,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): async def paginate_and_send(self, response_text, message): response_text = [ - response_text[i: i + self.TEXT_CUTOFF] + response_text[i : i + self.TEXT_CUTOFF] for i in range(0, len(response_text), self.TEXT_CUTOFF) ] # Send each chunk as a message @@ -299,7 +299,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): async def queue_debug_chunks(self, debug_message, message, debug_channel): debug_message_chunks = [ - debug_message[i: i + self.TEXT_CUTOFF] + debug_message[i : i + self.TEXT_CUTOFF] for i in range(0, len(debug_message), self.TEXT_CUTOFF) ] @@ -342,8 +342,8 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): if message.author.id in self.conversating_users: # If the user has reached the max conversation length, end the conversation if ( - self.conversating_users[message.author.id].count - >= self.model.max_conversation_length + self.conversating_users[message.author.id].count + >= self.model.max_conversation_length ): await message.reply( "You have reached the maximum conversation length. You have ended the conversation with GPT3, and it has ended." @@ -393,14 +393,14 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): # Check again if the prompt is about to go past the token limit new_prompt = ( - "".join(self.conversating_users[message.author.id].history) - + "\nGPTie: " + "".join(self.conversating_users[message.author.id].history) + + "\nGPTie: " ) tokens = self.usage_service.count_tokens(new_prompt) if ( - tokens > self.model.summarize_threshold - 150 + tokens > self.model.summarize_threshold - 150 ): # 150 is a buffer for the second stage await message.reply( "I tried to summarize our current conversation so we could keep chatting, " @@ -444,7 +444,9 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): # Paginate and send the response back to the users if not response_message: if len(response_text) > self.TEXT_CUTOFF: - await self.paginate_and_send(response_text, message) # No paginations for multi-messages. + await self.paginate_and_send( + response_text, message + ) # No paginations for multi-messages. else: response_message = await message.reply( response_text.replace("<|endofstatement|>", ""), @@ -453,8 +455,12 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): self.redo_users[message.author.id] = RedoUser( prompt, message, response_message ) - self.redo_users[message.author.id].add_interaction(response_message.id) - print(f"Added the interaction {response_message.id} to the redo user {message.author.id}") + self.redo_users[message.author.id].add_interaction( + response_message.id + ) + print( + f"Added the interaction {response_message.id} to the redo user {message.author.id}" + ) original_message[message.author.id] = message.id else: # We have response_text available, this is the original message that we want to edit @@ -551,7 +557,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): # A global GLOBAL_COOLDOWN_TIME timer for all users if (message.author.id in self.last_used) and ( - time.time() - self.last_used[message.author.id] < self.GLOBAL_COOLDOWN_TIME + time.time() - self.last_used[message.author.id] < self.GLOBAL_COOLDOWN_TIME ): await message.reply( "You must wait " @@ -661,12 +667,15 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): # This is because encapsulated_send is a coroutine, and we need to await it to get the response from the model. # We can't await it in the main thread, so we need to create a new thread to run it in. # We can make sure that when the thread executes it executes in an async fashion by - asyncio.run_coroutine_threadsafe(self.encapsulated_send( - message, - prompt - if message.author.id not in self.conversating_users - else "".join(self.conversating_users[message.author.id].history), - ), asyncio.get_running_loop()) + asyncio.run_coroutine_threadsafe( + self.encapsulated_send( + message, + prompt + if message.author.id not in self.conversating_users + else "".join(self.conversating_users[message.author.id].history), + ), + asyncio.get_running_loop(), + ) class RedoView(discord.ui.View): @@ -696,18 +705,27 @@ class EndConvoButton(discord.ui.Button["RedoView"]): # Get the user user_id = interaction.user.id - if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[user_id].in_interaction(interaction.message.id): + if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[ + user_id + ].in_interaction(interaction.message.id): try: - await self.converser_cog.end_conversation(self.converser_cog.redo_users[user_id].message) - await interaction.response.send_message("Your conversation has ended!", ephemeral=True, - delete_after=10) + await self.converser_cog.end_conversation( + self.converser_cog.redo_users[user_id].message + ) + await interaction.response.send_message( + "Your conversation has ended!", ephemeral=True, delete_after=10 + ) except Exception as e: print(e) traceback.print_exc() - await interaction.response.send_message(e, ephemeral=True, delete_after=30) + await interaction.response.send_message( + e, ephemeral=True, delete_after=30 + ) pass else: - await interaction.response.send_message("This is not your conversation to end!", ephemeral=True, delete_after=10) + await interaction.response.send_message( + "This is not your conversation to end!", ephemeral=True, delete_after=10 + ) class RedoButton(discord.ui.Button["RedoView"]): @@ -719,7 +737,9 @@ class RedoButton(discord.ui.Button["RedoView"]): # Get the user user_id = interaction.user.id - if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[user_id].in_interaction(interaction.message.id): + if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[ + user_id + ].in_interaction(interaction.message.id): # Get the message and the prompt and call encapsulated_send message = self.converser_cog.redo_users[user_id].message prompt = self.converser_cog.redo_users[user_id].prompt @@ -733,4 +753,8 @@ class RedoButton(discord.ui.Button["RedoView"]): message, prompt, response_message ) else: - await interaction.response.send_message("You can only redo the most recent prompt that you sent yourself.", ephemeral=True, delete_after=10) + await interaction.response.send_message( + "You can only redo the most recent prompt that you sent yourself.", + ephemeral=True, + delete_after=10, + ) diff --git a/cogs/image_prompt_optimizer.py b/cogs/image_prompt_optimizer.py index 0a046e2..9923589 100644 --- a/cogs/image_prompt_optimizer.py +++ b/cogs/image_prompt_optimizer.py @@ -96,8 +96,12 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"): response_message.id ) - self.converser_cog.redo_users[ctx.author.id] = RedoUser(prompt, ctx.message, response_message) - self.converser_cog.redo_users[ctx.author.id].add_interaction(response_message.id) + self.converser_cog.redo_users[ctx.author.id] = RedoUser( + prompt, ctx.message, response_message + ) + self.converser_cog.redo_users[ctx.author.id].add_interaction( + response_message.id + ) await response_message.edit( view=OptimizeView( self.converser_cog, self.image_service_cog, self.deletion_queue @@ -140,7 +144,10 @@ class DrawButton(discord.ui.Button["OptimizeView"]): user_id = interaction.user.id interaction_id = interaction.message.id - if interaction_id not in self.converser_cog.users_to_interactions[user_id] or interaction_id not in self.converser_cog.redo_users[user_id].interactions: + if ( + interaction_id not in self.converser_cog.users_to_interactions[user_id] + or interaction_id not in self.converser_cog.redo_users[user_id].interactions + ): await interaction.response.send_message( content="You can only draw for prompts that you generated yourself!", ephemeral=True, @@ -184,7 +191,9 @@ class RedoButton(discord.ui.Button["OptimizeView"]): # Get the user user_id = interaction.user.id - if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[user_id].in_interaction(interaction_id): + if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[ + user_id + ].in_interaction(interaction_id): # Get the message and the prompt and call encapsulated_send message = self.converser_cog.redo_users[user_id].message prompt = self.converser_cog.redo_users[user_id].prompt @@ -198,5 +207,6 @@ class RedoButton(discord.ui.Button["OptimizeView"]): else: await interaction.response.send_message( content="You can only redo for prompts that you generated yourself!", - ephemeral=True, delete_after=10 + ephemeral=True, + delete_after=10, ) diff --git a/models/user_model.py b/models/user_model.py index 7990aaf..f8f39bc 100644 --- a/models/user_model.py +++ b/models/user_model.py @@ -3,6 +3,7 @@ Store information about a discord user, for the purposes of enabling conversatio history, message count, and the id of the user in order to track them. """ + class RedoUser: def __init__(self, prompt, message, response): self.prompt = prompt @@ -27,6 +28,7 @@ class RedoUser: def __repr__(self): return f"RedoUser({self.message.author.id})" + class User: def __init__(self, id): self.id = id