diff --git a/cogs/gpt_3_commands_and_converser.py b/cogs/gpt_3_commands_and_converser.py index 0f3c269..487279f 100644 --- a/cogs/gpt_3_commands_and_converser.py +++ b/cogs/gpt_3_commands_and_converser.py @@ -1147,21 +1147,27 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): if len(response_text) > self.TEXT_CUTOFF: await self.paginate_and_send(response_text, ctx) else: - response_message = ( - await ctx.respond( + if not from_context: + response_message = await ctx.reply( response_text, view=ConversationView( - ctx, self, ctx.channel.id, custom_api_key=custom_api_key + ctx, self, ctx.channel.id, model, custom_api_key=custom_api_key ), ) - if from_context - else await ctx.reply( + elif from_edit_command: + response_message = await ctx.respond( response_text, view=ConversationView( - ctx, self, ctx.channel.id, custom_api_key=custom_api_key + ctx, self, ctx.channel.id, model, from_edit_command, custom_api_key=custom_api_key + ), + ) + else: + response_message = await ctx.respond( + response_text, + view=ConversationView( + ctx, self, ctx.channel.id, model, custom_api_key=custom_api_key ), ) - ) # Get the actual message object of response_message in case it's an WebhookMessage actual_response_message = ( @@ -1171,7 +1177,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): ) self.redo_users[ctx.author.id] = RedoUser( - new_prompt, ctx, ctx, actual_response_message + prompt=new_prompt, instruction=instruction, ctx=ctx, message=ctx, response=actual_response_message ) self.redo_users[ctx.author.id].add_interaction( actual_response_message.id @@ -1771,13 +1777,15 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): class ConversationView(discord.ui.View): - def __init__(self, ctx, converser_cog, id, custom_api_key=None): + def __init__(self, ctx, converser_cog, id, model, from_edit_command=False, custom_api_key=None): super().__init__(timeout=3600) # 1 hour interval to redo. self.converser_cog = converser_cog self.ctx = ctx + self.model = model + self.from_edit_command = from_edit_command self.custom_api_key = custom_api_key self.add_item( - RedoButton(self.converser_cog, custom_api_key=self.custom_api_key) + RedoButton(self.converser_cog, model, from_edit_command, custom_api_key=self.custom_api_key) ) if id in self.converser_cog.conversation_threads: @@ -1829,9 +1837,11 @@ class EndConvoButton(discord.ui.Button["ConversationView"]): class RedoButton(discord.ui.Button["ConversationView"]): - def __init__(self, converser_cog, custom_api_key): + def __init__(self, converser_cog, model, from_edit_command, custom_api_key): super().__init__(style=discord.ButtonStyle.danger, label="Retry") self.converser_cog = converser_cog + self.model = model + self.from_edit_command = from_edit_command self.custom_api_key = custom_api_key async def callback(self, interaction: discord.Interaction): @@ -1843,6 +1853,7 @@ class RedoButton(discord.ui.Button["ConversationView"]): ].in_interaction(interaction.message.id): # Get the message and the prompt and call encapsulated_send prompt = self.converser_cog.redo_users[user_id].prompt + instruction = self.converser_cog.redo_users[user_id].instruction ctx = self.converser_cog.redo_users[user_id].ctx response_message = self.converser_cog.redo_users[user_id].response @@ -1853,10 +1864,13 @@ class RedoButton(discord.ui.Button["ConversationView"]): await self.converser_cog.encapsulated_send( id=user_id, prompt=prompt, + instruction=instruction, ctx=ctx, + model=self.model, response_message=response_message, custom_api_key=self.custom_api_key, redo_request=True, + from_edit_command=self.from_edit_command ) else: await interaction.response.send_message( diff --git a/models/user_model.py b/models/user_model.py index f8521d5..5a58741 100644 --- a/models/user_model.py +++ b/models/user_model.py @@ -5,8 +5,9 @@ history, message count, and the id of the user in order to track them. class RedoUser: - def __init__(self, prompt, message, ctx, response): + def __init__(self, prompt, instruction, message, ctx, response): self.prompt = prompt + self.instruction = instruction self.message = message self.ctx = ctx self.response = response