Changed conversations to be keyed by thread

The one starting it is the thread owner
Can only own one thread but can converse in multiple

Added a check dmchannel for the message edit moderation
Gave an error when opening a conversation then ending it
Rene Teigen 2 years ago
parent 602742eb59
commit c0e0d8ba8b

@ -14,7 +14,7 @@ from models.deletion_service_model import Deletion
from models.env_service_model import EnvService
from models.message_model import Message
from models.moderations_service_model import Moderation
from models.user_model import User, RedoUser
from models.user_model import RedoUser, Thread
from models.check_model import Check
from models.autocomplete_model import Settings_autocompleter, File_autocompleter
from collections import defaultdict
@ -44,7 +44,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
self.debug_channel = None
self.bot = bot
self._last_member_ = None
self.conversating_users = {}
self.conversation_threads = {}
self.DAVINCI_ROLES = ["admin", "Admin", "GPT", "gpt"]
self.END_PROMPTS = [
"end",
@ -103,7 +103,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
)
self.TEXT_CUTOFF = 1900
self.message_queue = message_queue
self.conversation_threads = {}
self.conversation_thread_owners = {}
# Create slash command groups
dalle = discord.SlashCommandGroup(
@ -220,11 +220,12 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
await thread.delete()
await ctx.respond("All conversation threads have been deleted.")
def check_conversing(self, user_id, channel_id, message_content):
#TODO: add extra condition to check if multi is enabled for the thread, stated in conversation_threads
def check_conversing(self, user_id, channel_id, message_content, multi=None):
cond1 = (
user_id in self.conversating_users
and user_id in self.conversation_threads
and channel_id == self.conversation_threads[user_id]
channel_id in self.conversation_threads
#and user_id in self.conversation_thread_owners
#and channel_id == self.conversation_thread_owners[user_id]
)
# If the trimmed message starts with a Tilde, then we want to not contribute this to the conversation
try:
@ -235,23 +236,33 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
return (cond1) and cond2
async def end_conversation(self, message, opener_user_id=None):
normalized_user_id = opener_user_id if opener_user_id else message.author.id
self.conversating_users.pop(normalized_user_id)
async def end_conversation(self, ctx, opener_user_id=None):
normalized_user_id = opener_user_id if opener_user_id else ctx.author.id
try:
channel_id = self.conversation_thread_owners[normalized_user_id]
except:
await ctx.delete(delay=5)
await ctx.reply("Only the conversation starter can end this.", delete_after=5)
return
self.conversation_threads.pop(channel_id)
if isinstance(message, discord.ApplicationContext):
await message.respond(
"Your conversation has ended!", ephemeral=True, delete_after=10
if isinstance(ctx, discord.ApplicationContext):
await ctx.respond(
"You have ended the conversation with GPT3. Start a conversation with /gpt converse", ephemeral=True, delete_after=10
)
elif isinstance(ctx, discord.Interaction):
await ctx.response.send_message(
"You have ended the conversation with GPT3. Start a conversation with /gpt converse", ephemeral=True, delete_after=10
)
else:
await message.reply(
await ctx.reply(
"You have ended the conversation with GPT3. Start a conversation with /gpt converse"
)
# Close all conversation threads for the user
if normalized_user_id in self.conversation_threads:
thread_id = self.conversation_threads[normalized_user_id]
self.conversation_threads.pop(normalized_user_id)
if normalized_user_id in self.conversation_thread_owners:
thread_id = self.conversation_thread_owners[normalized_user_id]
self.conversation_thread_owners.pop(normalized_user_id)
# Attempt to close and lock the thread.
try:
@ -449,10 +460,10 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
async def check_conversation_limit(self, message):
# After each response, check if the user has reached the conversation limit in terms of messages or time.
if message.author.id in self.conversating_users:
if message.channel.id in self.conversation_threads:
# If the user has reached the max conversation length, end the conversation
if (
self.conversating_users[message.author.id].count
self.conversation_threads[message.channel.id].count
>= self.model.max_conversation_length
):
await message.reply(
@ -475,15 +486,16 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
)
# Get the last entry from the user's conversation history
new_conversation_history.append(
self.conversating_users[message.author.id].history[-1] + "\n"
self.conversation_threads[message.channel.id].history[-1] + "\n"
)
self.conversating_users[message.author.id].history = new_conversation_history
self.conversation_threads[message.channel.id].history = new_conversation_history
# A listener for message edits to redo prompts if they are edited
@discord.Cog.listener()
async def on_message_edit(self, before, after):
# Moderation
if not isinstance(after.channel, discord.DMChannel):
if (
after.guild.id in self.moderation_queues
and self.moderation_queues[after.guild.id] is not None
@ -506,22 +518,22 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
# If the user is conversing, we need to get their conversation history, delete the last
# "Human:" message, create a new Human: section with the new prompt, and then set the prompt to
# the new prompt, then send that new prompt as the new prompt.
if after.author.id in self.conversating_users:
if after.channel.id in self.conversation_threads:
# Remove the last two elements from the history array and add the new Human: prompt
self.conversating_users[
after.author.id
].history = self.conversating_users[after.author.id].history[:-2]
self.conversating_users[after.author.id].history.append(
self.conversation_threads[
after.channel.id
].history = self.conversation_threads[after.channel.id].history[:-2]
self.conversation_threads[after.channel.id].history.append(
f"\nHuman: {after.content}<|endofstatement|>\n"
)
edited_content = "".join(
self.conversating_users[after.author.id].history
self.conversation_threads[after.channel.id].history
)
self.conversating_users[after.author.id].count += 1
self.conversation_threads[after.channel.id].count += 1
print("Doing the encapsulated send")
await self.encapsulated_send(
user_id=after.author.id,
id=after.channel.id,
prompt=edited_content,
ctx=ctx,
response_message=response_message,
@ -570,7 +582,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
# We want to have conversationality functionality. To have gpt3 remember context, we need to append the conversation/prompt
# history to the prompt. We can do this by checking if the user is in the conversating_users dictionary, and if they are,
# we can append their history to the prompt.
if message.author.id in self.conversating_users:
if message.channel.id in self.conversation_threads:
# Since this is async, we don't want to allow the user to send another prompt while a conversation
# prompt is processing, that'll mess up the conversation history!
@ -595,28 +607,28 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
original_message[message.author.id] = message.id
self.conversating_users[message.author.id].history.append(
self.conversation_threads[message.channel.id].history.append(
"\nHuman: " + prompt + "<|endofstatement|>\n"
)
# increment the conversation counter for the user
self.conversating_users[message.author.id].count += 1
self.conversation_threads[message.channel.id].count += 1
# Send the request to the model
# If conversing, the prompt to send is the history, otherwise, it's just the prompt
await self.encapsulated_send(
message.author.id,
message.channel.id,
prompt
if message.author.id not in self.conversating_users
else "".join(self.conversating_users[message.author.id].history),
if message.channel.id not in self.conversation_threads
else "".join(self.conversation_threads[message.channel.id].history),
message,
)
# ctx can be of type AppContext(interaction) or Message
async def encapsulated_send(
self,
user_id,
id,
prompt,
ctx,
temp_override=None,
@ -646,7 +658,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
# Check if the prompt is about to go past the token limit
if (
user_id in self.conversating_users
id in self.conversation_threads
and tokens > self.model.summarize_threshold
and not from_g_command
):
@ -663,7 +675,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
# Check again if the prompt is about to go past the token limit
new_prompt = (
"".join(self.conversating_users[user_id].history) + "\nGPTie: "
"".join(self.conversation_threads[id].history) + "\nGPTie: "
)
tokens = self.usage_service.count_tokens(new_prompt)
@ -713,8 +725,8 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
)
# If the user is conversing, add the GPT response to their conversation history.
if user_id in self.conversating_users and not from_g_command:
self.conversating_users[user_id].history.append(
if id in self.conversation_threads and not from_g_command:
self.conversation_threads[id].history.append(
"\nGPTie: " + str(response_text) + "<|endofstatement|>\n"
)
@ -726,12 +738,12 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
response_message = (
await ctx.respond(
response_text,
view=RedoView(ctx, self, user_id),
view=ConversationView(ctx, self, ctx.channel.id),
)
if from_context
else await ctx.reply(
response_text,
view=RedoView(ctx, self, user_id),
view=ConversationView(ctx, self, ctx.channel.id),
)
)
@ -742,10 +754,10 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
else await ctx.fetch_message(response_message.id)
)
self.redo_users[user_id] = RedoUser(
self.redo_users[ctx.author.id] = RedoUser(
prompt, ctx, ctx, actual_response_message
)
self.redo_users[user_id].add_interaction(actual_response_message.id)
self.redo_users[ctx.author.id].add_interaction(actual_response_message.id)
# We are doing a redo, edit the message.
else:
@ -755,8 +767,8 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
self.generate_debug_message(prompt, response), self.debug_channel
)
if user_id in self.awaiting_responses:
self.awaiting_responses.remove(user_id)
if ctx.author.id in self.awaiting_responses:
self.awaiting_responses.remove(ctx.author.id)
# Error catching for OpenAI model value errors
except ValueError as e:
@ -775,8 +787,8 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
await ctx.send_followup(message) if from_context else await ctx.reply(
message
)
if user_id in self.awaiting_responses:
self.awaiting_responses.remove(user_id)
if ctx.author.id in self.awaiting_responses:
self.awaiting_responses.remove(ctx.author.id)
traceback.print_exc()
try:
@ -901,11 +913,11 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
user = ctx.user
if user.id in self.conversating_users:
if user.id in self.conversation_thread_owners:
message = await ctx.respond(
"You are already conversating with GPT3. End the conversation with !g end or just say 'end' in a supported channel"
"You've already created a thread, end it before creating a new one",
delete_after=5
)
await self.deletion_queue(message)
return
if not opener and not opener_file:
@ -930,18 +942,6 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
else:
pass
self.conversating_users[user_id_normalized] = User(user_id_normalized)
# Append the starter text for gpt3 to the user's history so it gets concatenated with the prompt later
if minimal or opener_file:
self.conversating_users[user_id_normalized].history.append(
self.CONVERSATION_STARTER_TEXT_MINIMAL
)
elif not minimal:
self.conversating_users[user_id_normalized].history.append(
self.CONVERSATION_STARTER_TEXT
)
if private:
await ctx.respond(user.name + "'s private conversation with GPT3")
thread = await ctx.channel.create_thread(
@ -957,6 +957,18 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
auto_archive_duration=60,
)
self.conversation_threads[thread.id] = Thread(thread.id)
# Append the starter text for gpt3 to the user's history so it gets concatenated with the prompt later
if minimal or opener_file:
self.conversation_threads[thread.id].history.append(
self.CONVERSATION_STARTER_TEXT_MINIMAL
)
elif not minimal:
self.conversation_threads[thread.id].history.append(
self.CONVERSATION_STARTER_TEXT
)
await thread.send(
"<@"
+ str(user_id_normalized)
@ -966,24 +978,26 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
# send opening
if opener:
thread_message = await thread.send("***Opening prompt*** \n" + opener)
if user_id_normalized in self.conversating_users:
if thread.id in self.conversation_threads:
self.awaiting_responses.append(user_id_normalized)
self.conversating_users[user_id_normalized].history.append(
self.conversation_threads[thread.id].history.append(
"\nHuman: " + opener + "<|endofstatement|>\n"
)
self.conversating_users[user_id_normalized].count += 1
self.conversation_threads[thread.id].count += 1
await self.encapsulated_send(
user_id_normalized,
thread.id,
opener
if user_id_normalized not in self.conversating_users
else "".join(self.conversating_users[user_id_normalized].history),
if thread.id not in self.conversation_threads
else "".join(self.conversation_threads[thread.id].history),
thread_message,
)
self.awaiting_responses.remove(user_id_normalized)
self.conversation_thread_owners[user_id_normalized] = thread.id
self.conversation_threads[user_id_normalized] = thread.id
@add_to_group("system")
@discord.slash_command(
@ -1063,10 +1077,15 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
guild_ids=ALLOWED_GUILDS,
)
@discord.guild_only()
async def end_chat(self, ctx: discord.ApplicationContext):
async def end(self, ctx: discord.ApplicationContext):
await ctx.defer(ephemeral=True)
user_id = ctx.user.id
if user_id in self.conversating_users:
try:
thread_id = self.conversation_thread_owners[user_id]
except:
await ctx.respond("You haven't started any conversations", ephemeral=True, delete_after=10)
return
if thread_id in self.conversation_threads:
try:
await self.end_conversation(ctx)
except Exception as e:
@ -1140,14 +1159,14 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
await self.process_settings_command(ctx, parameter, value)
class RedoView(discord.ui.View):
def __init__(self, ctx, converser_cog, user_id):
class ConversationView(discord.ui.View):
def __init__(self, ctx, converser_cog, id):
super().__init__(timeout=3600) # 1 hour interval to redo.
self.converser_cog = converser_cog
self.ctx = ctx
self.add_item(RedoButton(self.converser_cog))
if user_id in self.converser_cog.conversating_users:
if id in self.converser_cog.conversation_threads:
self.add_item(EndConvoButton(self.converser_cog))
async def on_timeout(self):
@ -1164,7 +1183,7 @@ class RedoView(discord.ui.View):
)
class EndConvoButton(discord.ui.Button["RedoView"]):
class EndConvoButton(discord.ui.Button["ConversationView"]):
def __init__(self, converser_cog):
super().__init__(style=discord.ButtonStyle.danger, label="End Conversation")
self.converser_cog = converser_cog
@ -1173,16 +1192,10 @@ class EndConvoButton(discord.ui.Button["RedoView"]):
# Get the user
user_id = interaction.user.id
if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[
user_id
].in_interaction(interaction.message.id):
if user_id in self.converser_cog.conversation_thread_owners and self.converser_cog.conversation_thread_owners[user_id] == interaction.channel.id:
try:
await self.converser_cog.end_conversation(
self.converser_cog.redo_users[user_id].message,
opener_user_id=user_id,
)
await interaction.response.send_message(
"Your conversation has ended!", ephemeral=True, delete_after=10
interaction, opener_user_id=interaction.user.id
)
except Exception as e:
print(e)
@ -1197,7 +1210,7 @@ class EndConvoButton(discord.ui.Button["RedoView"]):
)
class RedoButton(discord.ui.Button["RedoView"]):
class RedoButton(discord.ui.Button["ConversationView"]):
def __init__(self, converser_cog):
super().__init__(style=discord.ButtonStyle.danger, label="Retry")
self.converser_cog = converser_cog
@ -1219,7 +1232,7 @@ class RedoButton(discord.ui.Button["RedoView"]):
)
await self.converser_cog.encapsulated_send(
user_id, prompt, ctx, response_message
id=user_id, prompt=prompt, ctx=ctx, response_message=response_message
)
else:
await interaction.response.send_message(

@ -50,3 +50,24 @@ class User:
def __str__(self):
return self.__repr__()
class Thread:
def __init__(self, id):
self.id = id
self.history = []
self.count = 0
# These user objects should be accessible by ID, for example if we had a bunch of user
# objects in a list, and we did `if 1203910293001 in user_list`, it would return True
# if the user with that ID was in the list
def __eq__(self, other):
return self.id == other.id
def __hash__(self):
return hash(self.id)
def __repr__(self):
return f"Thread(id={self.id}, history={self.history})"
def __str__(self):
return self.__repr__()
Loading…
Cancel
Save