Update text_service_cog.py

Uhhh, channel chat support

Signed-off-by: Jiankun-Huang <58890715+Jiankun-Huang@users.noreply.github.com>
Jiankun-Huang 2 years ago committed by GitHub
parent 4f5cacca46
commit 66efa5abd5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1017,6 +1017,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
top_p: float, top_p: float,
frequency_penalty: float, frequency_penalty: float,
presence_penalty: float, presence_penalty: float,
use_threads: bool = True, # Add this parameter
): ):
"""Command handler. Starts a conversation with the bot """Command handler. Starts a conversation with the bot
@ -1059,32 +1060,40 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
# ) # )
# return # return
if private: # Add a variable to store the channel or thread
await ctx.respond( target = None
embed=discord.Embed(
title=f"{user.name}'s private conversation with GPT3", if use_threads:
color=0x808080, if private:
embed_title = f"{user.name}'s private conversation with GPT3"
thread = await ctx.channel.create_thread(
name=user.name + "'s private conversation with GPT3",
auto_archive_duration=60,
) )
) target = thread
thread = await ctx.channel.create_thread( else:
name=user.name + "'s private conversation with GPT3", embed_title = f"{user.name}'s conversation with GPT3"
auto_archive_duration=60, message_embed = discord.Embed(title=embed_title, color=0x808080)
) message_thread = await ctx.send(embed=message_embed)
elif not private: thread = await message_thread.create_thread(
message_thread = await ctx.respond( name=user.name + "'s conversation with GPT3",
embed=discord.Embed( auto_archive_duration=60,
title=f"{user.name}'s conversation with GPT3", color=0x808080
) )
) target = thread
# Get the actual message object for the message_thread else:
message_thread_real = await ctx.fetch_message(message_thread.id) target = ctx.channel
thread = await message_thread_real.create_thread( if private:
name=user.name + "'s conversation with GPT3", embed_title = f"{user.name}'s private conversation with GPT3"
auto_archive_duration=60, else:
) embed_title = f"{user.name}'s conversation with GPT3"
embed = discord.Embed(title=embed_title, color=0x808080)
await ctx.send(embed=embed)
self.conversation_threads[thread.id] = Thread(thread.id) self.conversation_threads[target.id] = Thread(target.id)
self.conversation_threads[thread.id].model = ( self.conversation_threads[target.id].model = (
self.model.model if not model else model self.model.model if not model else model
) )
@ -1159,32 +1168,24 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
self.conversation_thread_owners[user_id_normalized].append(thread.id) self.conversation_thread_owners[user_id_normalized].append(thread.id)
overrides = self.conversation_threads[thread.id].get_overrides() overrides = self.conversation_threads[thread.id].get_overrides()
await thread.send(f"<@{str(ctx.user.id)}> is the thread owner.") await target.send(f"<@{str(ctx.user.id)}> is the thread owner.")
await thread.send( await target.send(
embed=EmbedStatics.generate_conversation_embed( embed=EmbedStatics.generate_conversation_embed(
self.conversation_threads, thread, opener, overrides self.conversation_threads, target, opener, overrides
) )
) )
# send opening # send opening
if opener: if opener:
thread_message = await thread.send( target_message = await target.send(
embed=EmbedStatics.generate_opener_embed(opener) embed=EmbedStatics.generate_opener_embed(opener)
) )
if thread.id in self.conversation_threads: if target.id in self.conversation_threads:
self.awaiting_responses.append(user_id_normalized) self.awaiting_responses.append(user_id_normalized)
self.awaiting_thread_responses.append(thread.id) self.awaiting_target_responses.append(target.id)
if not self.pinecone_service:
self.conversation_threads[thread.id].history.append(
EmbeddedConversationItem(
f"\n{ctx.author.display_name}: {opener} <|endofstatement|>\n",
0,
)
)
self.conversation_threads[thread.id].count += 1 # ... (no other changes in the middle part of the function)
overrides = Override( overrides = Override(
overrides["temperature"], overrides["temperature"],
@ -1195,21 +1196,21 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
await TextService.encapsulated_send( await TextService.encapsulated_send(
self, self,
thread.id, target.id,
opener opener
if thread.id not in self.conversation_threads or self.pinecone_service if target.id not in self.conversation_threads or self.pinecone_service
else "".join( else "".join(
[item.text for item in self.conversation_threads[thread.id].history] [item.text for item in self.conversation_threads[target.id].history]
), ),
thread_message, target_message,
overrides=overrides, overrides=overrides,
user=user, user=user,
model=self.conversation_threads[thread.id].model, model=self.conversation_threads[target.id].model,
custom_api_key=user_api_key, custom_api_key=user_api_key,
) )
self.awaiting_responses.remove(user_id_normalized) self.awaiting_responses.remove(user_id_normalized)
if thread.id in self.awaiting_thread_responses: if target.id in self.awaiting_target_responses:
self.awaiting_thread_responses.remove(thread.id) self.awaiting_target_responses.remove(target.id)
async def end_command(self, ctx: discord.ApplicationContext): async def end_command(self, ctx: discord.ApplicationContext):
"""Command handler. Gets the user's thread and ends it""" """Command handler. Gets the user's thread and ends it"""

Loading…
Cancel
Save