diff --git a/cogs/text_service_cog.py b/cogs/text_service_cog.py index 760bb0b..0ef8b3f 100644 --- a/cogs/text_service_cog.py +++ b/cogs/text_service_cog.py @@ -1017,6 +1017,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): top_p: float, frequency_penalty: float, presence_penalty: float, + use_threads: bool = True, # Add this parameter ): """Command handler. Starts a conversation with the bot @@ -1059,32 +1060,40 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): # ) # return - if private: - await ctx.respond( - embed=discord.Embed( - title=f"{user.name}'s private conversation with GPT3", - color=0x808080, + # Add a variable to store the channel or thread + target = None + + if use_threads: + if private: + embed_title = f"{user.name}'s private conversation with GPT3" + thread = await ctx.channel.create_thread( + name=user.name + "'s private conversation with GPT3", + auto_archive_duration=60, ) - ) - thread = await ctx.channel.create_thread( - name=user.name + "'s private conversation with GPT3", - auto_archive_duration=60, - ) - elif not private: - message_thread = await ctx.respond( - embed=discord.Embed( - title=f"{user.name}'s conversation with GPT3", color=0x808080 + target = thread + else: + embed_title = f"{user.name}'s conversation with GPT3" + message_embed = discord.Embed(title=embed_title, color=0x808080) + message_thread = await ctx.send(embed=message_embed) + thread = await message_thread.create_thread( + name=user.name + "'s conversation with GPT3", + auto_archive_duration=60, ) - ) - # Get the actual message object for the message_thread - message_thread_real = await ctx.fetch_message(message_thread.id) - thread = await message_thread_real.create_thread( - name=user.name + "'s conversation with GPT3", - auto_archive_duration=60, - ) + target = thread + else: + target = ctx.channel + if private: + embed_title = f"{user.name}'s private conversation with GPT3" + else: + embed_title = f"{user.name}'s conversation with GPT3" + + embed = discord.Embed(title=embed_title, color=0x808080) + await ctx.send(embed=embed) + + - self.conversation_threads[thread.id] = Thread(thread.id) - self.conversation_threads[thread.id].model = ( + self.conversation_threads[target.id] = Thread(target.id) + self.conversation_threads[target.id].model = ( self.model.model if not model else model ) @@ -1159,32 +1168,24 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): self.conversation_thread_owners[user_id_normalized].append(thread.id) overrides = self.conversation_threads[thread.id].get_overrides() - await thread.send(f"<@{str(ctx.user.id)}> is the thread owner.") + await target.send(f"<@{str(ctx.user.id)}> is the thread owner.") - await thread.send( + await target.send( embed=EmbedStatics.generate_conversation_embed( - self.conversation_threads, thread, opener, overrides + self.conversation_threads, target, opener, overrides ) ) # send opening if opener: - thread_message = await thread.send( + target_message = await target.send( embed=EmbedStatics.generate_opener_embed(opener) ) - if thread.id in self.conversation_threads: + if target.id in self.conversation_threads: self.awaiting_responses.append(user_id_normalized) - self.awaiting_thread_responses.append(thread.id) - - if not self.pinecone_service: - self.conversation_threads[thread.id].history.append( - EmbeddedConversationItem( - f"\n{ctx.author.display_name}: {opener} <|endofstatement|>\n", - 0, - ) - ) + self.awaiting_target_responses.append(target.id) - self.conversation_threads[thread.id].count += 1 + # ... (no other changes in the middle part of the function) overrides = Override( overrides["temperature"], @@ -1195,21 +1196,21 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): await TextService.encapsulated_send( self, - thread.id, + target.id, opener - if thread.id not in self.conversation_threads or self.pinecone_service + if target.id not in self.conversation_threads or self.pinecone_service else "".join( - [item.text for item in self.conversation_threads[thread.id].history] + [item.text for item in self.conversation_threads[target.id].history] ), - thread_message, + target_message, overrides=overrides, user=user, - model=self.conversation_threads[thread.id].model, + model=self.conversation_threads[target.id].model, custom_api_key=user_api_key, ) self.awaiting_responses.remove(user_id_normalized) - if thread.id in self.awaiting_thread_responses: - self.awaiting_thread_responses.remove(thread.id) + if target.id in self.awaiting_target_responses: + self.awaiting_target_responses.remove(target.id) async def end_command(self, ctx: discord.ApplicationContext): """Command handler. Gets the user's thread and ends it"""