Fix bugs and add end conversation button to convo messages

Kaveen Kumarasinghe 2 years ago
parent 4366c7f089
commit 3ea5f544e8

@ -1,3 +1,4 @@
import asyncio
import datetime
import os
import re
@ -13,13 +14,8 @@ from discord.ext import commands
from cogs.image_prompt_optimizer import ImgPromptOptimizer
class RedoUser:
def __init__(self, prompt, message, response_message):
self.prompt = prompt
self.message = message
self.response_message = response_message
# We don't use the converser cog here because we want to be able to redo for the last images and text prompts at the same time
from models.user_model import RedoUser
redo_users = {}
users_to_interactions = {}
@ -27,7 +23,7 @@ users_to_interactions = {}
class DrawDallEService(commands.Cog, name="DrawDallEService"):
def __init__(
self, bot, usage_service, model, message_queue, deletion_queue, converser_cog
self, bot, usage_service, model, message_queue, deletion_queue, converser_cog
):
self.bot = bot
self.usage_service = usage_service
@ -51,14 +47,15 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
print(f"Image prompt optimizer was added")
async def encapsulated_send(
self,
prompt,
message,
response_message=None,
vary=None,
draw_from_optimizer=None,
user_id=None,
self,
prompt,
message,
response_message=None,
vary=None,
draw_from_optimizer=None,
user_id=None,
):
await asyncio.sleep(0)
# send the prompt to the model
file, image_urls = self.model.send_image_request(
prompt, vary=vary if not draw_from_optimizer else None
@ -154,7 +151,7 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
# The image prompt is everything after the command
prompt = " ".join(args)
await self.encapsulated_send(prompt, message)
asyncio.ensure_future(self.encapsulated_send(prompt, message))
except Exception as e:
print(e)
@ -263,8 +260,8 @@ class VaryButton(discord.ui.Button):
if len(self.converser_cog.users_to_interactions[user_id]) >= 2:
interaction_id2 = interaction.id
if (
interaction_id2
not in self.converser_cog.users_to_interactions[user_id]
interaction_id2
not in self.converser_cog.users_to_interactions[user_id]
):
await interaction.response.send_message(
content="You can not vary images in someone else's chain!",
@ -288,13 +285,15 @@ class VaryButton(discord.ui.Button):
response_message.id
)
prompt = redo_users[user_id].prompt
await self.cog.encapsulated_send(
asyncio.ensure_future(self.cog.encapsulated_send(
prompt,
interaction.message,
response_message=response_message,
vary=self.image_url,
user_id=user_id,
)
)
class SaveButton(discord.ui.Button["SaveView"]):
@ -346,7 +345,6 @@ class RedoButton(discord.ui.Button["SaveView"]):
return
# We have passed the intial check of if the interaction belongs to the user
if user_id in redo_users:
# Get the message and the prompt and call encapsulated_send
message = redo_users[user_id].message
@ -358,4 +356,4 @@ class RedoButton(discord.ui.Button["SaveView"]):
)
self.converser_cog.users_to_interactions[user_id].append(message.id)
await self.cog.encapsulated_send(prompt, message, response_message)
asyncio.ensure_future(self.cog.encapsulated_send(prompt, message, response_message))

@ -1,7 +1,10 @@
import asyncio
import datetime
import functools
import json
import os
import re
import threading
import time
import traceback
@ -12,31 +15,23 @@ from cogs.draw_image_generation import DrawDallEService
from cogs.image_prompt_optimizer import ImgPromptOptimizer
from models.deletion_service import Deletion
from models.message_model import Message
from models.user_model import User
from models.user_model import User, RedoUser
from collections import defaultdict
class RedoUser:
def __init__(self, prompt, message, response):
self.prompt = prompt
self.message = message
self.response = response
redo_users = {}
original_message = {}
class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
def __init__(
self,
bot,
usage_service,
model,
message_queue,
deletion_queue,
DEBUG_GUILD,
DEBUG_CHANNEL,
self,
bot,
usage_service,
model,
message_queue,
deletion_queue,
DEBUG_GUILD,
DEBUG_CHANNEL,
):
self.debug_channel = None
self.bot = bot
@ -59,6 +54,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
self.summarize = self.model.summarize_conversations
self.deletion_queue = deletion_queue
self.users_to_interactions = defaultdict(list)
self.redo_users = {}
try:
# Attempt to read a conversation starter text string from the file.
@ -134,13 +130,13 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
def check_conversing(self, message):
cond1 = (
message.author.id in self.conversating_users
and message.channel.name in ["gpt3", "general-bot", "bot"]
message.author.id in self.conversating_users
and message.channel.name in ["gpt3", "general-bot", "bot"]
)
cond2 = (
message.author.id in self.conversating_users
and message.author.id in self.conversation_threads
and message.channel.id == self.conversation_threads[message.author.id]
message.author.id in self.conversating_users
and message.author.id in self.conversation_threads
and message.channel.id == self.conversation_threads[message.author.id]
)
# If the trimmed message starts with a Tilde, then we want to not contribute this to the conversation
@ -286,7 +282,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
async def paginate_and_send(self, response_text, message):
response_text = [
response_text[i : i + self.TEXT_CUTOFF]
response_text[i: i + self.TEXT_CUTOFF]
for i in range(0, len(response_text), self.TEXT_CUTOFF)
]
# Send each chunk as a message
@ -303,7 +299,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
async def queue_debug_chunks(self, debug_message, message, debug_channel):
debug_message_chunks = [
debug_message[i : i + self.TEXT_CUTOFF]
debug_message[i: i + self.TEXT_CUTOFF]
for i in range(0, len(debug_message), self.TEXT_CUTOFF)
]
@ -346,8 +342,8 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
if message.author.id in self.conversating_users:
# If the user has reached the max conversation length, end the conversation
if (
self.conversating_users[message.author.id].count
>= self.model.max_conversation_length
self.conversating_users[message.author.id].count
>= self.model.max_conversation_length
):
await message.reply(
"You have reached the maximum conversation length. You have ended the conversation with GPT3, and it has ended."
@ -374,6 +370,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
self.conversating_users[message.author.id].history = new_conversation_history
async def encapsulated_send(self, message, prompt, response_message=None):
await asyncio.sleep(0)
# Append a newline, and GPTie: to the prompt
new_prompt = prompt + "\nGPTie: "
@ -396,14 +393,14 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
# Check again if the prompt is about to go past the token limit
new_prompt = (
"".join(self.conversating_users[message.author.id].history)
+ "\nGPTie: "
"".join(self.conversating_users[message.author.id].history)
+ "\nGPTie: "
)
tokens = self.usage_service.count_tokens(new_prompt)
if (
tokens > self.model.summarize_threshold - 150
tokens > self.model.summarize_threshold - 150
): # 150 is a buffer for the second stage
await message.reply(
"I tried to summarize our current conversation so we could keep chatting, "
@ -420,6 +417,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
await self.end_conversation(message)
return
# REQUEST!!!!
response = self.model.send_request(new_prompt, message, tokens=tokens)
response_text = response["choices"][0]["text"]
@ -446,15 +444,17 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
# Paginate and send the response back to the users
if not response_message:
if len(response_text) > self.TEXT_CUTOFF:
await self.paginate_and_send(response_text, message)
await self.paginate_and_send(response_text, message) # No paginations for multi-messages.
else:
response_message = await message.reply(
response_text.replace("<|endofstatement|>", ""),
view=RedoView(self),
view=RedoView(self, message.author.id),
)
redo_users[message.author.id] = RedoUser(
self.redo_users[message.author.id] = RedoUser(
prompt, message, response_message
)
self.redo_users[message.author.id].add_interaction(response_message.id)
print(f"Added the interaction {response_message.id} to the redo user {message.author.id}")
original_message[message.author.id] = message.id
else:
# We have response_text available, this is the original message that we want to edit
@ -485,10 +485,10 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
# A listener for message edits to redo prompts if they are edited
@commands.Cog.listener()
async def on_message_edit(self, before, after):
if after.author.id in redo_users:
if after.author.id in self.redo_users:
if after.id == original_message[after.author.id]:
message = redo_users[after.author.id].message
response_message = redo_users[after.author.id].response
message = self.redo_users[after.author.id].message
response_message = self.redo_users[after.author.id].response
await response_message.edit(content="Redoing prompt 🔄...")
edited_content = after.content
@ -497,7 +497,6 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
# "Human:" message, create a new Human: section with the new prompt, and then set the prompt to
# the new prompt, then send that new prompt as the new prompt.
if after.author.id in self.conversating_users:
# Remove the last two elements from the history array and add the new Human: prompt
self.conversating_users[
after.author.id
@ -512,7 +511,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
await self.encapsulated_send(message, edited_content, response_message)
redo_users[after.author.id].prompt = after.content
self.redo_users[after.author.id].prompt = after.content
@commands.Cog.listener()
async def on_message(self, message):
@ -552,7 +551,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
# A global GLOBAL_COOLDOWN_TIME timer for all users
if (message.author.id in self.last_used) and (
time.time() - self.last_used[message.author.id] < self.GLOBAL_COOLDOWN_TIME
time.time() - self.last_used[message.author.id] < self.GLOBAL_COOLDOWN_TIME
):
await message.reply(
"You must wait "
@ -650,20 +649,35 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
# Send the request to the model
# If conversing, the prompt to send is the history, otherwise, it's just the prompt
await self.encapsulated_send(
message,
prompt
if message.author.id not in self.conversating_users
else "".join(self.conversating_users[message.author.id].history),
)
# Create a new thread to run
# self.encapsulated_send(
# message,
# prompt
# if message.author.id not in self.conversating_users
# else "".join(self.conversating_users[message.author.id].history),
# )
# This created thread needs to call encapsulated_send in a coroutine/async fashion.
# This is because encapsulated_send is a coroutine, and we need to await it to get the response from the model.
# We can't await it in the main thread, so we need to create a new thread to run it in.
# We can make sure that when the thread executes it executes in an async fashion by
asyncio.run_coroutine_threadsafe(self.encapsulated_send(
message,
prompt
if message.author.id not in self.conversating_users
else "".join(self.conversating_users[message.author.id].history),
), asyncio.get_running_loop())
class RedoView(discord.ui.View):
def __init__(self, converser_cog):
def __init__(self, converser_cog, user_id):
super().__init__(timeout=3600) # 1 hour interval to redo.
self.converser_cog = converser_cog
self.add_item(RedoButton(self.converser_cog))
if user_id in self.converser_cog.conversating_users:
self.add_item(EndConvoButton(self.converser_cog))
async def on_timeout(self):
# Remove the button from the view/message
self.clear_items()
@ -673,29 +687,50 @@ class RedoView(discord.ui.View):
)
class EndConvoButton(discord.ui.Button["RedoView"]):
def __init__(self, converser_cog):
super().__init__(style=discord.ButtonStyle.danger, label="End Conversation")
self.converser_cog = converser_cog
async def callback(self, interaction: discord.Interaction):
# Get the user
user_id = interaction.user.id
if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[user_id].in_interaction(interaction.message.id):
try:
await self.converser_cog.end_conversation(self.converser_cog.redo_users[user_id].message)
await interaction.response.send_message("Your conversation has ended!", ephemeral=True,
delete_after=10)
except Exception as e:
print(e)
traceback.print_exc()
await interaction.response.send_message(e, ephemeral=True, delete_after=30)
pass
else:
await interaction.response.send_message("This is not your conversation to end!", ephemeral=True, delete_after=10)
class RedoButton(discord.ui.Button["RedoView"]):
def __init__(self, converser_cog):
super().__init__(style=discord.ButtonStyle.danger, label="Retry")
self.converser_cog = converser_cog
async def callback(self, interaction: discord.Interaction):
msg = await interaction.response.send_message(
"Retrying your original request...", ephemeral=True
)
# Put the message into the deletion queue with a timestamp of 10 seconds from now to be deleted
deletion = Deletion(
msg, (datetime.datetime.now() + datetime.timedelta(seconds=10)).timestamp()
)
await self.converser_cog.deletion_queue.put(deletion)
# Get the user
user_id = interaction.user.id
if user_id in redo_users:
if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[user_id].in_interaction(interaction.message.id):
# Get the message and the prompt and call encapsulated_send
message = redo_users[user_id].message
prompt = redo_users[user_id].prompt
response_message = redo_users[user_id].response
message = self.converser_cog.redo_users[user_id].message
prompt = self.converser_cog.redo_users[user_id].prompt
response_message = self.converser_cog.redo_users[user_id].response
msg = await interaction.response.send_message(
"Retrying your original request...", ephemeral=True, delete_after=15
)
await self.converser_cog.encapsulated_send(
message, prompt, response_message
)
else:
await interaction.response.send_message("You can only redo the most recent prompt that you sent yourself.", ephemeral=True, delete_after=10)

@ -8,15 +8,7 @@ import discord
from discord.ext import commands
from models.deletion_service import Deletion
redo_users = {}
class RedoUser:
def __init__(self, prompt, message, response):
self.prompt = prompt
self.message = message
self.response = response
from models.user_model import RedoUser
class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"):
@ -73,11 +65,14 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"):
print(
f"Received an image optimization request for the following prompt: {prompt}"
)
# Get the token amount for the prompt
tokens = self.usage_service.count_tokens(prompt)
try:
response = self.model.send_request(
response = await self.model.send_request(
prompt,
ctx.message,
tokens=tokens,
top_p_override=1.0,
temp_override=0.9,
presence_penalty_override=0.5,
@ -101,7 +96,8 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"):
response_message.id
)
redo_users[ctx.author.id] = RedoUser(prompt, ctx.message, response_message)
self.converser_cog.redo_users[ctx.author.id] = RedoUser(prompt, ctx.message, response_message)
self.converser_cog.redo_users[ctx.author.id].add_interaction(response_message.id)
await response_message.edit(
view=OptimizeView(
self.converser_cog, self.image_service_cog, self.deletion_queue
@ -144,7 +140,7 @@ class DrawButton(discord.ui.Button["OptimizeView"]):
user_id = interaction.user.id
interaction_id = interaction.message.id
if interaction_id not in self.converser_cog.users_to_interactions[user_id]:
if interaction_id not in self.converser_cog.users_to_interactions[user_id] or interaction_id not in self.converser_cog.redo_users[user_id].interactions:
await interaction.response.send_message(
content="You can only draw for prompts that you generated yourself!",
ephemeral=True,
@ -183,34 +179,24 @@ class RedoButton(discord.ui.Button["OptimizeView"]):
self.deletion_queue = deletion_queue
async def callback(self, interaction: discord.Interaction):
user_id = interaction.user.id
interaction_id = interaction.message.id
if interaction_id not in self.converser_cog.users_to_interactions[user_id]:
await interaction.response.send_message(
content="You can only redo for prompts that you generated yourself!",
ephemeral=True,
)
return
msg = await interaction.response.send_message(
"Redoing your original request...", ephemeral=True
)
# Put the message into the deletion queue with a timestamp of 10 seconds from now to be deleted
deletion = Deletion(
msg, (datetime.datetime.now() + datetime.timedelta(seconds=10)).timestamp()
)
await self.deletion_queue.put(deletion)
# Get the user
user_id = interaction.user.id
if user_id in redo_users:
if user_id in self.converser_cog.redo_users and self.converser_cog.redo_users[user_id].in_interaction(interaction_id):
# Get the message and the prompt and call encapsulated_send
message = redo_users[user_id].message
prompt = redo_users[user_id].prompt
response_message = redo_users[user_id].response
message = self.converser_cog.redo_users[user_id].message
prompt = self.converser_cog.redo_users[user_id].prompt
response_message = self.converser_cog.redo_users[user_id].response
msg = await interaction.response.send_message(
"Redoing your original request...", ephemeral=True, delete_after=20
)
await self.converser_cog.encapsulated_send(
message, prompt, response_message
)
else:
await interaction.response.send_message(
content="You can only redo for prompts that you generated yourself!",
ephemeral=True, delete_after=10
)

@ -3,6 +3,29 @@ Store information about a discord user, for the purposes of enabling conversatio
history, message count, and the id of the user in order to track them.
"""
class RedoUser:
def __init__(self, prompt, message, response):
self.prompt = prompt
self.message = message
self.response = response
self.interactions = []
def add_interaction(self, interaction):
self.interactions.append(interaction)
def in_interaction(self, interaction):
return interaction in self.interactions
# Represented by user_id
def __hash__(self):
return hash(self.message.author.id)
def __eq__(self, other):
return self.message.author.id == other.message.author.id
# repr
def __repr__(self):
return f"RedoUser({self.message.author.id})"
class User:
def __init__(self, id):

Loading…
Cancel
Save