Kaveen Kumarasinghe 2 years ago
commit 1f96dda52a

@ -23,7 +23,7 @@ users_to_interactions = {}
class DrawDallEService(commands.Cog, name="DrawDallEService"): class DrawDallEService(commands.Cog, name="DrawDallEService"):
def __init__( def __init__(
self, bot, usage_service, model, message_queue, deletion_queue, converser_cog self, bot, usage_service, model, message_queue, deletion_queue, converser_cog
): ):
self.bot = bot self.bot = bot
self.usage_service = usage_service self.usage_service = usage_service
@ -47,13 +47,13 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
print(f"Image prompt optimizer was added") print(f"Image prompt optimizer was added")
async def encapsulated_send( async def encapsulated_send(
self, self,
prompt, prompt,
message, message,
response_message=None, response_message=None,
vary=None, vary=None,
draw_from_optimizer=None, draw_from_optimizer=None,
user_id=None, user_id=None,
): ):
await asyncio.sleep(0) await asyncio.sleep(0)
# send the prompt to the model # send the prompt to the model
@ -260,8 +260,8 @@ class VaryButton(discord.ui.Button):
if len(self.converser_cog.users_to_interactions[user_id]) >= 2: if len(self.converser_cog.users_to_interactions[user_id]) >= 2:
interaction_id2 = interaction.id interaction_id2 = interaction.id
if ( if (
interaction_id2 interaction_id2
not in self.converser_cog.users_to_interactions[user_id] not in self.converser_cog.users_to_interactions[user_id]
): ):
await interaction.response.send_message( await interaction.response.send_message(
content="You can not vary images in someone else's chain!", content="You can not vary images in someone else's chain!",
@ -286,13 +286,14 @@ class VaryButton(discord.ui.Button):
) )
prompt = redo_users[user_id].prompt prompt = redo_users[user_id].prompt
asyncio.ensure_future(self.cog.encapsulated_send( asyncio.ensure_future(
prompt, self.cog.encapsulated_send(
interaction.message, prompt,
response_message=response_message, interaction.message,
vary=self.image_url, response_message=response_message,
user_id=user_id, vary=self.image_url,
) user_id=user_id,
)
) )
@ -356,4 +357,6 @@ class RedoButton(discord.ui.Button["SaveView"]):
) )
self.converser_cog.users_to_interactions[user_id].append(message.id) self.converser_cog.users_to_interactions[user_id].append(message.id)
asyncio.ensure_future(self.cog.encapsulated_send(prompt, message, response_message)) asyncio.ensure_future(
self.cog.encapsulated_send(prompt, message, response_message)
)

@ -24,14 +24,14 @@ original_message = {}
class GPT3ComCon(commands.Cog, name="GPT3ComCon"): class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
def __init__( def __init__(
self, self,
bot, bot,
usage_service, usage_service,
model, model,
message_queue, message_queue,
deletion_queue, deletion_queue,
DEBUG_GUILD, DEBUG_GUILD,
DEBUG_CHANNEL, DEBUG_CHANNEL,
): ):
self.debug_channel = None self.debug_channel = None
self.bot = bot self.bot = bot
@ -130,13 +130,13 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
def check_conversing(self, message): def check_conversing(self, message):
cond1 = ( cond1 = (
message.author.id in self.conversating_users message.author.id in self.conversating_users
and message.channel.name in ["gpt3", "general-bot", "bot"] and message.channel.name in ["gpt3", "general-bot", "bot"]
) )
cond2 = ( cond2 = (
message.author.id in self.conversating_users message.author.id in self.conversating_users
and message.author.id in self.conversation_threads and message.author.id in self.conversation_threads
and message.channel.id == self.conversation_threads[message.author.id] and message.channel.id == self.conversation_threads[message.author.id]
) )
# If the trimmed message starts with a Tilde, then we want to not contribute this to the conversation # If the trimmed message starts with a Tilde, then we want to not contribute this to the conversation
@ -282,7 +282,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
async def paginate_and_send(self, response_text, message): async def paginate_and_send(self, response_text, message):
response_text = [ response_text = [
response_text[i: i + self.TEXT_CUTOFF] response_text[i : i + self.TEXT_CUTOFF]
for i in range(0, len(response_text), self.TEXT_CUTOFF) for i in range(0, len(response_text), self.TEXT_CUTOFF)
] ]
# Send each chunk as a message # Send each chunk as a message
@ -299,7 +299,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
async def queue_debug_chunks(self, debug_message, message, debug_channel): async def queue_debug_chunks(self, debug_message, message, debug_channel):
debug_message_chunks = [ debug_message_chunks = [
debug_message[i: i + self.TEXT_CUTOFF] debug_message[i : i + self.TEXT_CUTOFF]
for i in range(0, len(debug_message), self.TEXT_CUTOFF) for i in range(0, len(debug_message), self.TEXT_CUTOFF)
] ]
@ -342,8 +342,8 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
if message.author.id in self.conversating_users: if message.author.id in self.conversating_users:
# If the user has reached the max conversation length, end the conversation # If the user has reached the max conversation length, end the conversation
if ( if (
self.conversating_users[message.author.id].count self.conversating_users[message.author.id].count
>= self.model.max_conversation_length >= self.model.max_conversation_length
): ):
await message.reply( await message.reply(
"You have reached the maximum conversation length. You have ended the conversation with GPT3, and it has ended." "You have reached the maximum conversation length. You have ended the conversation with GPT3, and it has ended."
@ -393,14 +393,14 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
# Check again if the prompt is about to go past the token limit # Check again if the prompt is about to go past the token limit
new_prompt = ( new_prompt = (
"".join(self.conversating_users[message.author.id].history) "".join(self.conversating_users[message.author.id].history)
+ "\nGPTie: " + "\nGPTie: "
) )
tokens = self.usage_service.count_tokens(new_prompt) tokens = self.usage_service.count_tokens(new_prompt)
if ( if (
tokens > self.model.summarize_threshold - 150 tokens > self.model.summarize_threshold - 150
): # 150 is a buffer for the second stage ): # 150 is a buffer for the second stage
await message.reply( await message.reply(
"I tried to summarize our current conversation so we could keep chatting, " "I tried to summarize our current conversation so we could keep chatting, "
@ -444,7 +444,9 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
# Paginate and send the response back to the users # Paginate and send the response back to the users
if not response_message: if not response_message:
if len(response_text) > self.TEXT_CUTOFF: if len(response_text) > self.TEXT_CUTOFF:
await self.paginate_and_send(response_text, message) # No paginations for multi-messages. await self.paginate_and_send(
response_text, message
) # No paginations for multi-messages.
else: else:
response_message = await message.reply( response_message = await message.reply(
response_text.replace("<|endofstatement|>", ""), response_text.replace("<|endofstatement|>", ""),
@ -453,8 +455,12 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
self.redo_users[message.author.id] = RedoUser( self.redo_users[message.author.id] = RedoUser(
prompt, message, response_message prompt, message, response_message
) )
self.redo_users[message.author.id].add_interaction(response_message.id) self.redo_users[message.author.id].add_interaction(
print(f"Added the interaction {response_message.id} to the redo user {message.author.id}") response_message.id
)
print(
f"Added the interaction {response_message.id} to the redo user {message.author.id}"
)
original_message[message.author.id] = message.id original_message[message.author.id] = message.id
else: else:
# We have response_text available, this is the original message that we want to edit # We have response_text available, this is the original message that we want to edit
@ -551,7 +557,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
# A global GLOBAL_COOLDOWN_TIME timer for all users # A global GLOBAL_COOLDOWN_TIME timer for all users
if (message.author.id in self.last_used) and ( if (message.author.id in self.last_used) and (
time.time() - self.last_used[message.author.id] < self.GLOBAL_COOLDOWN_TIME time.time() - self.last_used[message.author.id] < self.GLOBAL_COOLDOWN_TIME
): ):
await message.reply( await message.reply(
"You must wait " "You must wait "
@ -661,12 +667,15 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
# This is because encapsulated_send is a coroutine, and we need to await it to get the response from the model. # This is because encapsulated_send is a coroutine, and we need to await it to get the response from the model.
# We can't await it in the main thread, so we need to create a new thread to run it in. # We can't await it in the main thread, so we need to create a new thread to run it in.
# We can make sure that when the thread executes it executes in an async fashion by # We can make sure that when the thread executes it executes in an async fashion by
asyncio.run_coroutine_threadsafe(self.encapsulated_send( asyncio.run_coroutine_threadsafe(
message, self.encapsulated_send(
prompt message,
if message.author.id not in self.conversating_users prompt
else "".join(self.conversating_users[message.author.id].history), if message.author.id not in self.conversating_users
), asyncio.get_running_loop()) else "".join(self.conversating_users[message.author.id].history),
),
asyncio.get_running_loop(),
)
class RedoView(discord.ui.View): class RedoView(discord.ui.View):
@ -696,18 +705,27 @@ class EndConvoButton(discord.ui.Button["RedoView"]):
# Get the user # Get the user
user_id = interaction.user.id user_id = interaction.user.id
if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[user_id].in_interaction(interaction.message.id): if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[
user_id
].in_interaction(interaction.message.id):
try: try:
await self.converser_cog.end_conversation(self.converser_cog.redo_users[user_id].message) await self.converser_cog.end_conversation(
await interaction.response.send_message("Your conversation has ended!", ephemeral=True, self.converser_cog.redo_users[user_id].message
delete_after=10) )
await interaction.response.send_message(
"Your conversation has ended!", ephemeral=True, delete_after=10
)
except Exception as e: except Exception as e:
print(e) print(e)
traceback.print_exc() traceback.print_exc()
await interaction.response.send_message(e, ephemeral=True, delete_after=30) await interaction.response.send_message(
e, ephemeral=True, delete_after=30
)
pass pass
else: else:
await interaction.response.send_message("This is not your conversation to end!", ephemeral=True, delete_after=10) await interaction.response.send_message(
"This is not your conversation to end!", ephemeral=True, delete_after=10
)
class RedoButton(discord.ui.Button["RedoView"]): class RedoButton(discord.ui.Button["RedoView"]):
@ -719,7 +737,9 @@ class RedoButton(discord.ui.Button["RedoView"]):
# Get the user # Get the user
user_id = interaction.user.id user_id = interaction.user.id
if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[user_id].in_interaction(interaction.message.id): if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[
user_id
].in_interaction(interaction.message.id):
# Get the message and the prompt and call encapsulated_send # Get the message and the prompt and call encapsulated_send
message = self.converser_cog.redo_users[user_id].message message = self.converser_cog.redo_users[user_id].message
prompt = self.converser_cog.redo_users[user_id].prompt prompt = self.converser_cog.redo_users[user_id].prompt
@ -733,4 +753,8 @@ class RedoButton(discord.ui.Button["RedoView"]):
message, prompt, response_message message, prompt, response_message
) )
else: else:
await interaction.response.send_message("You can only redo the most recent prompt that you sent yourself.", ephemeral=True, delete_after=10) await interaction.response.send_message(
"You can only redo the most recent prompt that you sent yourself.",
ephemeral=True,
delete_after=10,
)

@ -96,8 +96,12 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"):
response_message.id response_message.id
) )
self.converser_cog.redo_users[ctx.author.id] = RedoUser(prompt, ctx.message, response_message) self.converser_cog.redo_users[ctx.author.id] = RedoUser(
self.converser_cog.redo_users[ctx.author.id].add_interaction(response_message.id) prompt, ctx.message, response_message
)
self.converser_cog.redo_users[ctx.author.id].add_interaction(
response_message.id
)
await response_message.edit( await response_message.edit(
view=OptimizeView( view=OptimizeView(
self.converser_cog, self.image_service_cog, self.deletion_queue self.converser_cog, self.image_service_cog, self.deletion_queue
@ -140,7 +144,10 @@ class DrawButton(discord.ui.Button["OptimizeView"]):
user_id = interaction.user.id user_id = interaction.user.id
interaction_id = interaction.message.id interaction_id = interaction.message.id
if interaction_id not in self.converser_cog.users_to_interactions[user_id] or interaction_id not in self.converser_cog.redo_users[user_id].interactions: if (
interaction_id not in self.converser_cog.users_to_interactions[user_id]
or interaction_id not in self.converser_cog.redo_users[user_id].interactions
):
await interaction.response.send_message( await interaction.response.send_message(
content="You can only draw for prompts that you generated yourself!", content="You can only draw for prompts that you generated yourself!",
ephemeral=True, ephemeral=True,
@ -184,7 +191,9 @@ class RedoButton(discord.ui.Button["OptimizeView"]):
# Get the user # Get the user
user_id = interaction.user.id user_id = interaction.user.id
if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[user_id].in_interaction(interaction_id): if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[
user_id
].in_interaction(interaction_id):
# Get the message and the prompt and call encapsulated_send # Get the message and the prompt and call encapsulated_send
message = self.converser_cog.redo_users[user_id].message message = self.converser_cog.redo_users[user_id].message
prompt = self.converser_cog.redo_users[user_id].prompt prompt = self.converser_cog.redo_users[user_id].prompt
@ -198,5 +207,6 @@ class RedoButton(discord.ui.Button["OptimizeView"]):
else: else:
await interaction.response.send_message( await interaction.response.send_message(
content="You can only redo for prompts that you generated yourself!", content="You can only redo for prompts that you generated yourself!",
ephemeral=True, delete_after=10 ephemeral=True,
delete_after=10,
) )

@ -3,6 +3,7 @@ Store information about a discord user, for the purposes of enabling conversatio
history, message count, and the id of the user in order to track them. history, message count, and the id of the user in order to track them.
""" """
class RedoUser: class RedoUser:
def __init__(self, prompt, message, response): def __init__(self, prompt, message, response):
self.prompt = prompt self.prompt = prompt
@ -27,6 +28,7 @@ class RedoUser:
def __repr__(self): def __repr__(self):
return f"RedoUser({self.message.author.id})" return f"RedoUser({self.message.author.id})"
class User: class User:
def __init__(self, id): def __init__(self, id):
self.id = id self.id = id

Loading…
Cancel
Save