|
|
|
@ -24,8 +24,11 @@ class RedoUser:
|
|
|
|
|
redo_users = {}
|
|
|
|
|
users_to_interactions = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DrawDallEService(commands.Cog, name="DrawDallEService"):
|
|
|
|
|
def __init__(self, bot, usage_service, model, message_queue, deletion_queue, converser_cog):
|
|
|
|
|
def __init__(
|
|
|
|
|
self, bot, usage_service, model, message_queue, deletion_queue, converser_cog
|
|
|
|
|
):
|
|
|
|
|
self.bot = bot
|
|
|
|
|
self.usage_service = usage_service
|
|
|
|
|
self.model = model
|
|
|
|
@ -47,8 +50,15 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
|
|
|
|
|
)
|
|
|
|
|
print(f"Image prompt optimizer was added")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def encapsulated_send(self, prompt, message, response_message=None, vary=None, draw_from_optimizer=None, user_id=None):
|
|
|
|
|
async def encapsulated_send(
|
|
|
|
|
self,
|
|
|
|
|
prompt,
|
|
|
|
|
message,
|
|
|
|
|
response_message=None,
|
|
|
|
|
vary=None,
|
|
|
|
|
draw_from_optimizer=None,
|
|
|
|
|
user_id=None,
|
|
|
|
|
):
|
|
|
|
|
# send the prompt to the model
|
|
|
|
|
file, image_urls = self.model.send_image_request(
|
|
|
|
|
prompt, vary=vary if not draw_from_optimizer else None
|
|
|
|
@ -56,7 +66,11 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
|
|
|
|
|
|
|
|
|
|
# Start building an embed to send to the user with the results of the image generation
|
|
|
|
|
embed = discord.Embed(
|
|
|
|
|
title="Image Generation Results" if not vary else "Image Generation Results (Varying)" if not draw_from_optimizer else "Image Generation Results (Drawing from Optimizer)",
|
|
|
|
|
title="Image Generation Results"
|
|
|
|
|
if not vary
|
|
|
|
|
else "Image Generation Results (Varying)"
|
|
|
|
|
if not draw_from_optimizer
|
|
|
|
|
else "Image Generation Results (Drawing from Optimizer)",
|
|
|
|
|
description=f"{prompt}",
|
|
|
|
|
color=0xC730C7,
|
|
|
|
|
)
|
|
|
|
@ -64,42 +78,55 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
|
|
|
|
|
# Add the image file to the embed
|
|
|
|
|
embed.set_image(url=f"attachment://{file.filename}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not response_message: # Original generation case
|
|
|
|
|
# Start an interaction with the user, we also want to send data embed=embed, file=file, view=SaveView(image_urls, self, self.converser_cog)
|
|
|
|
|
result_message = await message.channel.send(
|
|
|
|
|
embed=embed, file=file, view=SaveView(image_urls, self, self.converser_cog)
|
|
|
|
|
embed=embed,
|
|
|
|
|
file=file,
|
|
|
|
|
view=SaveView(image_urls, self, self.converser_cog),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.converser_cog.users_to_interactions[message.author.id] = []
|
|
|
|
|
self.converser_cog.users_to_interactions[message.author.id].append(result_message.id)
|
|
|
|
|
|
|
|
|
|
redo_users[message.author.id] = RedoUser(
|
|
|
|
|
prompt, message, result_message
|
|
|
|
|
self.converser_cog.users_to_interactions[message.author.id].append(
|
|
|
|
|
result_message.id
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
redo_users[message.author.id] = RedoUser(prompt, message, result_message)
|
|
|
|
|
else:
|
|
|
|
|
if not vary: # Editing case
|
|
|
|
|
message = await response_message.edit(embed=embed, file=file, view=SaveView(image_urls, self, self.converser_cog))
|
|
|
|
|
message = await response_message.edit(
|
|
|
|
|
embed=embed,
|
|
|
|
|
file=file,
|
|
|
|
|
view=SaveView(image_urls, self, self.converser_cog),
|
|
|
|
|
)
|
|
|
|
|
else: # Varying case
|
|
|
|
|
if not draw_from_optimizer:
|
|
|
|
|
result_message = await response_message.edit_original_response(content="Image variation completed!",
|
|
|
|
|
embed=embed, file=file,
|
|
|
|
|
view=SaveView(image_urls, self, self.converser_cog, True))
|
|
|
|
|
result_message = await response_message.edit_original_response(
|
|
|
|
|
content="Image variation completed!",
|
|
|
|
|
embed=embed,
|
|
|
|
|
file=file,
|
|
|
|
|
view=SaveView(image_urls, self, self.converser_cog, True),
|
|
|
|
|
)
|
|
|
|
|
redo_users[message.author.id] = RedoUser(
|
|
|
|
|
prompt, message, result_message
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
result_message = await response_message.edit_original_response(
|
|
|
|
|
content="I've drawn the optimized prompt!", embed=embed, file=file,
|
|
|
|
|
view=SaveView(image_urls, self, self.converser_cog))
|
|
|
|
|
|
|
|
|
|
redo_users[user_id] = RedoUser(
|
|
|
|
|
prompt, message, result_message
|
|
|
|
|
content="I've drawn the optimized prompt!",
|
|
|
|
|
embed=embed,
|
|
|
|
|
file=file,
|
|
|
|
|
view=SaveView(image_urls, self, self.converser_cog),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if (user_id):
|
|
|
|
|
self.converser_cog.users_to_interactions[user_id].append(response_message.id)
|
|
|
|
|
self.converser_cog.users_to_interactions[user_id].append(result_message.id)
|
|
|
|
|
redo_users[user_id] = RedoUser(prompt, message, result_message)
|
|
|
|
|
|
|
|
|
|
if user_id:
|
|
|
|
|
self.converser_cog.users_to_interactions[user_id].append(
|
|
|
|
|
response_message.id
|
|
|
|
|
)
|
|
|
|
|
self.converser_cog.users_to_interactions[user_id].append(
|
|
|
|
|
result_message.id
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@commands.command()
|
|
|
|
|
async def draw(self, ctx, *args):
|
|
|
|
@ -110,14 +137,17 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
|
|
|
|
|
|
|
|
|
|
# Only allow the bot to be used by people who have the role "Admin" or "GPT"
|
|
|
|
|
general_user = not any(
|
|
|
|
|
role in set(self.converser_cog.DAVINCI_ROLES).union(set(self.converser_cog.CURIE_ROLES))
|
|
|
|
|
role
|
|
|
|
|
in set(self.converser_cog.DAVINCI_ROLES).union(
|
|
|
|
|
set(self.converser_cog.CURIE_ROLES)
|
|
|
|
|
)
|
|
|
|
|
for role in message.author.roles
|
|
|
|
|
)
|
|
|
|
|
admin_user = not any(
|
|
|
|
|
role in self.converser_cog.DAVINCI_ROLES for role in message.author.roles
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if (not admin_user and not general_user):
|
|
|
|
|
if not admin_user and not general_user:
|
|
|
|
|
return
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
|
@ -132,9 +162,6 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
|
|
|
|
|
await ctx.reply("Something went wrong. Please try again later.")
|
|
|
|
|
await ctx.reply(e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@commands.command()
|
|
|
|
|
async def local_size(self, ctx):
|
|
|
|
|
# Get the size of the dall-e images folder that we have on the current system.
|
|
|
|
@ -180,20 +207,24 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SaveView(discord.ui.View):
|
|
|
|
|
def __init__(self, image_urls, cog,converser_cog, no_retry=False):
|
|
|
|
|
def __init__(self, image_urls, cog, converser_cog, no_retry=False):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.cog = cog
|
|
|
|
|
self.converser_cog = converser_cog
|
|
|
|
|
for x in range(1, len(image_urls)+1):
|
|
|
|
|
self.add_item(SaveButton(x, image_urls[x-1]))
|
|
|
|
|
self.add_item(VaryButton(x, image_urls[x-1], self.cog, converser_cog=self.converser_cog))
|
|
|
|
|
for x in range(1, len(image_urls) + 1):
|
|
|
|
|
self.add_item(SaveButton(x, image_urls[x - 1]))
|
|
|
|
|
self.add_item(
|
|
|
|
|
VaryButton(
|
|
|
|
|
x, image_urls[x - 1], self.cog, converser_cog=self.converser_cog
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
if not no_retry:
|
|
|
|
|
self.add_item(RedoButton(self.cog, converser_cog=self.converser_cog))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VaryButton(discord.ui.Button):
|
|
|
|
|
def __init__(self, number, image_url, cog, converser_cog):
|
|
|
|
|
super().__init__(style=discord.ButtonStyle.blurple, label='Vary ' + str(number))
|
|
|
|
|
super().__init__(style=discord.ButtonStyle.blurple, label="Vary " + str(number))
|
|
|
|
|
self.number = number
|
|
|
|
|
self.image_url = image_url
|
|
|
|
|
self.cog = cog
|
|
|
|
@ -202,20 +233,27 @@ class VaryButton(discord.ui.Button):
|
|
|
|
|
async def callback(self, interaction: discord.Interaction):
|
|
|
|
|
user_id = interaction.user.id
|
|
|
|
|
interaction_id = interaction.message.id
|
|
|
|
|
print(f"The interactions for the user is {self.converser_cog.users_to_interactions[user_id]}")
|
|
|
|
|
print(
|
|
|
|
|
f"The interactions for the user is {self.converser_cog.users_to_interactions[user_id]}"
|
|
|
|
|
)
|
|
|
|
|
print(f"The current interaction message id is {interaction_id}")
|
|
|
|
|
print(f"The current interaction ID is {interaction.id}")
|
|
|
|
|
|
|
|
|
|
if (interaction_id not in self.converser_cog.users_to_interactions[user_id]):
|
|
|
|
|
if interaction_id not in self.converser_cog.users_to_interactions[user_id]:
|
|
|
|
|
if len(self.converser_cog.users_to_interactions[user_id]) >= 2:
|
|
|
|
|
interaction_id2 = interaction.id
|
|
|
|
|
if (interaction_id2 not in self.converser_cog.users_to_interactions[user_id]):
|
|
|
|
|
if (
|
|
|
|
|
interaction_id2
|
|
|
|
|
not in self.converser_cog.users_to_interactions[user_id]
|
|
|
|
|
):
|
|
|
|
|
await interaction.response.send_message(
|
|
|
|
|
content="You can not vary images in someone else's chain!", ephemeral=True
|
|
|
|
|
content="You can not vary images in someone else's chain!",
|
|
|
|
|
ephemeral=True,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
await interaction.response.send_message(
|
|
|
|
|
content="You can only vary for images that you generated yourself!", ephemeral=True
|
|
|
|
|
content="You can only vary for images that you generated yourself!",
|
|
|
|
|
ephemeral=True,
|
|
|
|
|
)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
@ -223,21 +261,28 @@ class VaryButton(discord.ui.Button):
|
|
|
|
|
response_message = await interaction.response.send_message(
|
|
|
|
|
content="Varying image number " + str(self.number) + "..."
|
|
|
|
|
)
|
|
|
|
|
self.converser_cog.users_to_interactions[user_id].append(response_message.message.id)
|
|
|
|
|
self.converser_cog.users_to_interactions[user_id].append(response_message.id)
|
|
|
|
|
self.converser_cog.users_to_interactions[user_id].append(
|
|
|
|
|
response_message.message.id
|
|
|
|
|
)
|
|
|
|
|
self.converser_cog.users_to_interactions[user_id].append(
|
|
|
|
|
response_message.id
|
|
|
|
|
)
|
|
|
|
|
prompt = redo_users[user_id].prompt
|
|
|
|
|
await self.cog.encapsulated_send(prompt, interaction.message, response_message=response_message,
|
|
|
|
|
vary=self.image_url, user_id=user_id)
|
|
|
|
|
|
|
|
|
|
await self.cog.encapsulated_send(
|
|
|
|
|
prompt,
|
|
|
|
|
interaction.message,
|
|
|
|
|
response_message=response_message,
|
|
|
|
|
vary=self.image_url,
|
|
|
|
|
user_id=user_id,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
class SaveButton(discord.ui.Button['SaveView']):
|
|
|
|
|
|
|
|
|
|
class SaveButton(discord.ui.Button["SaveView"]):
|
|
|
|
|
def __init__(self, number: int, image_url: str):
|
|
|
|
|
super().__init__(style=discord.ButtonStyle.gray, label='Save ' + str(number))
|
|
|
|
|
super().__init__(style=discord.ButtonStyle.gray, label="Save " + str(number))
|
|
|
|
|
self.number = number
|
|
|
|
|
self.image_url = image_url
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def callback(self, interaction: discord.Interaction):
|
|
|
|
|
# If the image url doesn't start with "http", then we need to read the file from the URI, and then send the
|
|
|
|
|
# file to the user as an attachment.
|
|
|
|
@ -250,20 +295,22 @@ class SaveButton(discord.ui.Button['SaveView']):
|
|
|
|
|
|
|
|
|
|
await interaction.response.send_message(
|
|
|
|
|
content="Here is your image for download (open original and save)",
|
|
|
|
|
file=discord.File(temp_file.name), ephemeral=True
|
|
|
|
|
file=discord.File(temp_file.name),
|
|
|
|
|
ephemeral=True,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
await interaction.response.send_message(f'You can directly download this image from {self.image_url}',
|
|
|
|
|
ephemeral=True)
|
|
|
|
|
await interaction.response.send_message(
|
|
|
|
|
f"You can directly download this image from {self.image_url}",
|
|
|
|
|
ephemeral=True,
|
|
|
|
|
)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
await interaction.response.send_message(f'Error: {e}', ephemeral=True)
|
|
|
|
|
await interaction.response.send_message(f"Error: {e}", ephemeral=True)
|
|
|
|
|
traceback.print_exc()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RedoButton(discord.ui.Button['SaveView']):
|
|
|
|
|
|
|
|
|
|
class RedoButton(discord.ui.Button["SaveView"]):
|
|
|
|
|
def __init__(self, cog, converser_cog):
|
|
|
|
|
super().__init__(style=discord.ButtonStyle.danger, label='Retry')
|
|
|
|
|
super().__init__(style=discord.ButtonStyle.danger, label="Retry")
|
|
|
|
|
self.cog = cog
|
|
|
|
|
self.converser_cog = converser_cog
|
|
|
|
|
|
|
|
|
@ -271,9 +318,10 @@ class RedoButton(discord.ui.Button['SaveView']):
|
|
|
|
|
user_id = interaction.user.id
|
|
|
|
|
interaction_id = interaction.message.id
|
|
|
|
|
|
|
|
|
|
if (interaction_id not in self.converser_cog.users_to_interactions[user_id]):
|
|
|
|
|
if interaction_id not in self.converser_cog.users_to_interactions[user_id]:
|
|
|
|
|
await interaction.response.send_message(
|
|
|
|
|
content="You can only retry for prompts that you generated yourself!", ephemeral=True
|
|
|
|
|
content="You can only retry for prompts that you generated yourself!",
|
|
|
|
|
ephemeral=True,
|
|
|
|
|
)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
@ -285,7 +333,9 @@ class RedoButton(discord.ui.Button['SaveView']):
|
|
|
|
|
prompt = redo_users[user_id].prompt
|
|
|
|
|
response_message = redo_users[user_id].response_message
|
|
|
|
|
message = await interaction.response.send_message(
|
|
|
|
|
f'Regenerating the image for your original prompt, check the original message.', ephemeral=True)
|
|
|
|
|
f"Regenerating the image for your original prompt, check the original message.",
|
|
|
|
|
ephemeral=True,
|
|
|
|
|
)
|
|
|
|
|
self.converser_cog.users_to_interactions[user_id].append(message.id)
|
|
|
|
|
|
|
|
|
|
await self.cog.encapsulated_send(prompt, message, response_message)
|
|
|
|
|