Format Python code with psf/black push

github-actions 1 year ago
parent d1c2d24b1c
commit 1312a864df

@ -24,8 +24,11 @@ class RedoUser:
redo_users = {}
users_to_interactions = {}
class DrawDallEService(commands.Cog, name="DrawDallEService"):
def __init__(self, bot, usage_service, model, message_queue, deletion_queue, converser_cog):
def __init__(
self, bot, usage_service, model, message_queue, deletion_queue, converser_cog
):
self.bot = bot
self.usage_service = usage_service
self.model = model
@ -47,8 +50,15 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
)
print(f"Image prompt optimizer was added")
async def encapsulated_send(self, prompt, message, response_message=None, vary=None, draw_from_optimizer=None, user_id=None):
async def encapsulated_send(
self,
prompt,
message,
response_message=None,
vary=None,
draw_from_optimizer=None,
user_id=None,
):
# send the prompt to the model
file, image_urls = self.model.send_image_request(
prompt, vary=vary if not draw_from_optimizer else None
@ -56,7 +66,11 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
# Start building an embed to send to the user with the results of the image generation
embed = discord.Embed(
title="Image Generation Results" if not vary else "Image Generation Results (Varying)" if not draw_from_optimizer else "Image Generation Results (Drawing from Optimizer)",
title="Image Generation Results"
if not vary
else "Image Generation Results (Varying)"
if not draw_from_optimizer
else "Image Generation Results (Drawing from Optimizer)",
description=f"{prompt}",
color=0xC730C7,
)
@ -64,42 +78,55 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
# Add the image file to the embed
embed.set_image(url=f"attachment://{file.filename}")
if not response_message: # Original generation case
# Start an interaction with the user, we also want to send data embed=embed, file=file, view=SaveView(image_urls, self, self.converser_cog)
result_message = await message.channel.send(
embed=embed, file=file, view=SaveView(image_urls, self, self.converser_cog)
embed=embed,
file=file,
view=SaveView(image_urls, self, self.converser_cog),
)
self.converser_cog.users_to_interactions[message.author.id] = []
self.converser_cog.users_to_interactions[message.author.id].append(result_message.id)
redo_users[message.author.id] = RedoUser(
prompt, message, result_message
self.converser_cog.users_to_interactions[message.author.id].append(
result_message.id
)
redo_users[message.author.id] = RedoUser(prompt, message, result_message)
else:
if not vary: # Editing case
message = await response_message.edit(embed=embed, file=file, view=SaveView(image_urls, self, self.converser_cog))
message = await response_message.edit(
embed=embed,
file=file,
view=SaveView(image_urls, self, self.converser_cog),
)
else: # Varying case
if not draw_from_optimizer:
result_message = await response_message.edit_original_response(content="Image variation completed!",
embed=embed, file=file,
view=SaveView(image_urls, self, self.converser_cog, True))
result_message = await response_message.edit_original_response(
content="Image variation completed!",
embed=embed,
file=file,
view=SaveView(image_urls, self, self.converser_cog, True),
)
redo_users[message.author.id] = RedoUser(
prompt, message, result_message
)
else:
result_message = await response_message.edit_original_response(
content="I've drawn the optimized prompt!", embed=embed, file=file,
view=SaveView(image_urls, self, self.converser_cog))
redo_users[user_id] = RedoUser(
prompt, message, result_message
content="I've drawn the optimized prompt!",
embed=embed,
file=file,
view=SaveView(image_urls, self, self.converser_cog),
)
if (user_id):
self.converser_cog.users_to_interactions[user_id].append(response_message.id)
self.converser_cog.users_to_interactions[user_id].append(result_message.id)
redo_users[user_id] = RedoUser(prompt, message, result_message)
if user_id:
self.converser_cog.users_to_interactions[user_id].append(
response_message.id
)
self.converser_cog.users_to_interactions[user_id].append(
result_message.id
)
@commands.command()
async def draw(self, ctx, *args):
@ -110,14 +137,17 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
# Only allow the bot to be used by people who have the role "Admin" or "GPT"
general_user = not any(
role in set(self.converser_cog.DAVINCI_ROLES).union(set(self.converser_cog.CURIE_ROLES))
role
in set(self.converser_cog.DAVINCI_ROLES).union(
set(self.converser_cog.CURIE_ROLES)
)
for role in message.author.roles
)
admin_user = not any(
role in self.converser_cog.DAVINCI_ROLES for role in message.author.roles
)
if (not admin_user and not general_user):
if not admin_user and not general_user:
return
try:
@ -132,9 +162,6 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
await ctx.reply("Something went wrong. Please try again later.")
await ctx.reply(e)
@commands.command()
async def local_size(self, ctx):
# Get the size of the dall-e images folder that we have on the current system.
@ -180,20 +207,24 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
class SaveView(discord.ui.View):
def __init__(self, image_urls, cog,converser_cog, no_retry=False):
def __init__(self, image_urls, cog, converser_cog, no_retry=False):
super().__init__()
self.cog = cog
self.converser_cog = converser_cog
for x in range(1, len(image_urls)+1):
self.add_item(SaveButton(x, image_urls[x-1]))
self.add_item(VaryButton(x, image_urls[x-1], self.cog, converser_cog=self.converser_cog))
for x in range(1, len(image_urls) + 1):
self.add_item(SaveButton(x, image_urls[x - 1]))
self.add_item(
VaryButton(
x, image_urls[x - 1], self.cog, converser_cog=self.converser_cog
)
)
if not no_retry:
self.add_item(RedoButton(self.cog, converser_cog=self.converser_cog))
class VaryButton(discord.ui.Button):
def __init__(self, number, image_url, cog, converser_cog):
super().__init__(style=discord.ButtonStyle.blurple, label='Vary ' + str(number))
super().__init__(style=discord.ButtonStyle.blurple, label="Vary " + str(number))
self.number = number
self.image_url = image_url
self.cog = cog
@ -202,20 +233,27 @@ class VaryButton(discord.ui.Button):
async def callback(self, interaction: discord.Interaction):
user_id = interaction.user.id
interaction_id = interaction.message.id
print(f"The interactions for the user is {self.converser_cog.users_to_interactions[user_id]}")
print(
f"The interactions for the user is {self.converser_cog.users_to_interactions[user_id]}"
)
print(f"The current interaction message id is {interaction_id}")
print(f"The current interaction ID is {interaction.id}")
if (interaction_id not in self.converser_cog.users_to_interactions[user_id]):
if interaction_id not in self.converser_cog.users_to_interactions[user_id]:
if len(self.converser_cog.users_to_interactions[user_id]) >= 2:
interaction_id2 = interaction.id
if (interaction_id2 not in self.converser_cog.users_to_interactions[user_id]):
if (
interaction_id2
not in self.converser_cog.users_to_interactions[user_id]
):
await interaction.response.send_message(
content="You can not vary images in someone else's chain!", ephemeral=True
content="You can not vary images in someone else's chain!",
ephemeral=True,
)
else:
await interaction.response.send_message(
content="You can only vary for images that you generated yourself!", ephemeral=True
content="You can only vary for images that you generated yourself!",
ephemeral=True,
)
return
@ -223,21 +261,28 @@ class VaryButton(discord.ui.Button):
response_message = await interaction.response.send_message(
content="Varying image number " + str(self.number) + "..."
)
self.converser_cog.users_to_interactions[user_id].append(response_message.message.id)
self.converser_cog.users_to_interactions[user_id].append(response_message.id)
self.converser_cog.users_to_interactions[user_id].append(
response_message.message.id
)
self.converser_cog.users_to_interactions[user_id].append(
response_message.id
)
prompt = redo_users[user_id].prompt
await self.cog.encapsulated_send(prompt, interaction.message, response_message=response_message,
vary=self.image_url, user_id=user_id)
await self.cog.encapsulated_send(
prompt,
interaction.message,
response_message=response_message,
vary=self.image_url,
user_id=user_id,
)
class SaveButton(discord.ui.Button['SaveView']):
class SaveButton(discord.ui.Button["SaveView"]):
def __init__(self, number: int, image_url: str):
super().__init__(style=discord.ButtonStyle.gray, label='Save ' + str(number))
super().__init__(style=discord.ButtonStyle.gray, label="Save " + str(number))
self.number = number
self.image_url = image_url
async def callback(self, interaction: discord.Interaction):
# If the image url doesn't start with "http", then we need to read the file from the URI, and then send the
# file to the user as an attachment.
@ -250,20 +295,22 @@ class SaveButton(discord.ui.Button['SaveView']):
await interaction.response.send_message(
content="Here is your image for download (open original and save)",
file=discord.File(temp_file.name), ephemeral=True
file=discord.File(temp_file.name),
ephemeral=True,
)
else:
await interaction.response.send_message(f'You can directly download this image from {self.image_url}',
ephemeral=True)
await interaction.response.send_message(
f"You can directly download this image from {self.image_url}",
ephemeral=True,
)
except Exception as e:
await interaction.response.send_message(f'Error: {e}', ephemeral=True)
await interaction.response.send_message(f"Error: {e}", ephemeral=True)
traceback.print_exc()
class RedoButton(discord.ui.Button['SaveView']):
class RedoButton(discord.ui.Button["SaveView"]):
def __init__(self, cog, converser_cog):
super().__init__(style=discord.ButtonStyle.danger, label='Retry')
super().__init__(style=discord.ButtonStyle.danger, label="Retry")
self.cog = cog
self.converser_cog = converser_cog
@ -271,9 +318,10 @@ class RedoButton(discord.ui.Button['SaveView']):
user_id = interaction.user.id
interaction_id = interaction.message.id
if (interaction_id not in self.converser_cog.users_to_interactions[user_id]):
if interaction_id not in self.converser_cog.users_to_interactions[user_id]:
await interaction.response.send_message(
content="You can only retry for prompts that you generated yourself!", ephemeral=True
content="You can only retry for prompts that you generated yourself!",
ephemeral=True,
)
return
@ -285,7 +333,9 @@ class RedoButton(discord.ui.Button['SaveView']):
prompt = redo_users[user_id].prompt
response_message = redo_users[user_id].response_message
message = await interaction.response.send_message(
f'Regenerating the image for your original prompt, check the original message.', ephemeral=True)
f"Regenerating the image for your original prompt, check the original message.",
ephemeral=True,
)
self.converser_cog.users_to_interactions[user_id].append(message.id)
await self.cog.encapsulated_send(prompt, message, response_message)

@ -107,7 +107,6 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
)
print(f"Draw service was added")
@commands.command()
async def delete_all_conversation_threads(self, ctx):
# If the user has ADMIN_ROLES
@ -215,17 +214,28 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
# Create a two-column embed to display the settings, use \u200b to create a blank space
embed.add_field(
name="Setting",
value="\n".join([key for key in self.model.__dict__.keys() if key not in self.model._hidden_attributes]),
value="\n".join(
[
key
for key in self.model.__dict__.keys()
if key not in self.model._hidden_attributes
]
),
inline=True,
)
embed.add_field(
name="Value",
value="\n".join([str(value) for key, value in self.model.__dict__.items() if key not in self.model._hidden_attributes]),
value="\n".join(
[
str(value)
for key, value in self.model.__dict__.items()
if key not in self.model._hidden_attributes
]
),
inline=True,
)
await message.channel.send(embed=embed)
async def process_settings_command(self, message):
# Extract the parameter and the value
parameter = message.content[4:].split()[0]
@ -529,7 +539,8 @@ class RedoButtonView(
discord.ui.View
): # Create a class called MyView that subclasses discord.ui.View
@discord.ui.button(
label="Retry", style=discord.ButtonStyle.danger,
label="Retry",
style=discord.ButtonStyle.danger,
) # Create a button with the label "😎 Click me!" with color Blurple
async def button_callback(self, button, interaction):
msg = await interaction.response.send_message(

@ -11,6 +11,7 @@ from models.deletion_service import Deletion
redo_users = {}
class RedoUser:
def __init__(self, prompt, message, response):
self.prompt = prompt
@ -22,7 +23,14 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"):
_OPTIMIZER_PRETEXT = "Optimize the following text for DALL-E image generation to have the most detailed and realistic image possible. Prompt:"
def __init__(
self, bot, usage_service, model, message_queue, deletion_queue, converser_cog, image_service_cog
self,
bot,
usage_service,
model,
message_queue,
deletion_queue,
converser_cog,
image_service_cog,
):
self.bot = bot
self.usage_service = usage_service
@ -43,7 +51,6 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"):
traceback.print_exc()
self.OPTIMIZER_PRETEXT = self._OPTIMIZER_PRETEXT
@commands.command()
async def imgoptimize(self, ctx, *args):
@ -79,10 +86,16 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"):
response_message = await ctx.reply(response_text)
self.converser_cog.users_to_interactions[ctx.message.author.id] = []
self.converser_cog.users_to_interactions[ctx.message.author.id].append(response_message.id)
self.converser_cog.users_to_interactions[ctx.message.author.id].append(
response_message.id
)
redo_users[ctx.author.id] = RedoUser(prompt, ctx.message, response_message)
await response_message.edit(view=OptimizeView(self.converser_cog, self.image_service_cog, self.deletion_queue))
await response_message.edit(
view=OptimizeView(
self.converser_cog, self.image_service_cog, self.deletion_queue
)
)
# Catch the value errors raised by the Model object
except ValueError as e:
@ -107,10 +120,10 @@ class OptimizeView(discord.ui.View):
self.add_item(RedoButton(self.cog, self.image_service_cog, self.deletion_queue))
self.add_item(DrawButton(self.cog, self.image_service_cog, self.deletion_queue))
class DrawButton(discord.ui.Button['OptimizeView']):
class DrawButton(discord.ui.Button["OptimizeView"]):
def __init__(self, converser_cog, image_service_cog, deletion_queue):
super().__init__(style=discord.ButtonStyle.green, label='Draw')
super().__init__(style=discord.ButtonStyle.green, label="Draw")
self.converser_cog = converser_cog
self.image_service_cog = image_service_cog
self.deletion_queue = deletion_queue
@ -122,7 +135,8 @@ class DrawButton(discord.ui.Button['OptimizeView']):
if interaction_id not in self.converser_cog.users_to_interactions[user_id]:
await interaction.response.send_message(
content="You can only draw for prompts that you generated yourself!", ephemeral=True
content="You can only draw for prompts that you generated yourself!",
ephemeral=True,
)
return
@ -130,9 +144,12 @@ class DrawButton(discord.ui.Button['OptimizeView']):
"Drawing this prompt!", ephemeral=False
)
self.converser_cog.users_to_interactions[interaction.user.id].append(msg.id)
self.converser_cog.users_to_interactions[interaction.user.id].append(interaction.id)
self.converser_cog.users_to_interactions[interaction.user.id].append(interaction.message.id)
self.converser_cog.users_to_interactions[interaction.user.id].append(
interaction.id
)
self.converser_cog.users_to_interactions[interaction.user.id].append(
interaction.message.id
)
# get the text content of the message that was interacted with
prompt = interaction.message.content
@ -142,13 +159,14 @@ class DrawButton(discord.ui.Button['OptimizeView']):
prompt = re.sub(r"Output Prompt: ?", "", prompt)
# Call the image service cog to draw the image
await self.image_service_cog.encapsulated_send(prompt, None, msg, True, True, user_id)
await self.image_service_cog.encapsulated_send(
prompt, None, msg, True, True, user_id
)
class RedoButton(discord.ui.Button['OptimizeView']):
class RedoButton(discord.ui.Button["OptimizeView"]):
def __init__(self, converser_cog, image_service_cog, deletion_queue):
super().__init__(style=discord.ButtonStyle.danger, label='Retry')
super().__init__(style=discord.ButtonStyle.danger, label="Retry")
self.converser_cog = converser_cog
self.image_service_cog = image_service_cog
self.deletion_queue = deletion_queue
@ -159,7 +177,8 @@ class RedoButton(discord.ui.Button['OptimizeView']):
if interaction_id not in self.converser_cog.users_to_interactions[user_id]:
await interaction.response.send_message(
content="You can only redo for prompts that you generated yourself!", ephemeral=True
content="You can only redo for prompts that you generated yourself!",
ephemeral=True,
)
return
@ -181,4 +200,6 @@ class RedoButton(discord.ui.Button['OptimizeView']):
message = redo_users[user_id].message
prompt = redo_users[user_id].prompt
response_message = redo_users[user_id].response
await self.converser_cog.encapsulated_send(message, prompt, response_message)
await self.converser_cog.encapsulated_send(
message, prompt, response_message
)

@ -22,6 +22,7 @@ class Models:
DAVINCI = "text-davinci-003"
CURIE = "text-curie-001"
class ImageSize:
LARGE = "1024x1024"
MEDIUM = "512x512"
@ -58,7 +59,13 @@ class Model:
os.makedirs(self.IMAGE_SAVE_PATH)
self.custom_image_path = False
self._hidden_attributes = ["usage_service", "DAVINCI_ROLES", "custom_image_path", "custom_web_root", "_hidden_attributes"]
self._hidden_attributes = [
"usage_service",
"DAVINCI_ROLES",
"custom_image_path",
"custom_web_root",
"_hidden_attributes",
]
openai.api_key = os.getenv("OPENAI_TOKEN")
@ -73,7 +80,9 @@ class Model:
if value in ImageSize.__dict__.values():
self._image_size = value
else:
raise ValueError("Image size must be one of the following: SMALL(256x256), MEDIUM(512x512), LARGE(1024x1024)")
raise ValueError(
"Image size must be one of the following: SMALL(256x256), MEDIUM(512x512), LARGE(1024x1024)"
)
@property
def num_images(self):
@ -320,8 +329,8 @@ class Model:
print(response.__dict__)
image_urls = []
for result in response['data']:
image_urls.append(result['url'])
for result in response["data"]:
image_urls.append(result["url"])
# For each image url, open it as an image object using PIL
images = [Image.open(requests.get(url, stream=True).raw) for url in image_urls]
@ -348,10 +357,10 @@ class Model:
height = max(heights) * num_rows
# Create a transparent image with the same size as the images
transparent = Image.new('RGBA', (max(widths), max(heights)))
transparent = Image.new("RGBA", (max(widths), max(heights)))
# Create a new image with the calculated size
new_im = Image.new('RGBA', (width, height))
new_im = Image.new("RGBA", (width, height))
# Paste the images and transparent segments into the grid
x_offset = y_offset = 0
@ -370,7 +379,6 @@ class Model:
x_offset = 0
y_offset += transparent.size[1]
# Save the new_im to a temporary file and return it as a discord.File
temp_file = tempfile.NamedTemporaryFile(suffix=".png")
new_im.save(temp_file.name)
@ -383,8 +391,12 @@ class Model:
safety_counter = 0
while image_size > 8 or safety_counter >= 2:
safety_counter += 1
print(f"Image size is {image_size}MB, which is too large for discord. Downscaling and trying again")
new_im = new_im.resize((int(new_im.width / 1.05), int(new_im.height / 1.05)))
print(
f"Image size is {image_size}MB, which is too large for discord. Downscaling and trying again"
)
new_im = new_im.resize(
(int(new_im.width / 1.05), int(new_im.height / 1.05))
)
temp_file = tempfile.NamedTemporaryFile(suffix=".png")
new_im.save(temp_file.name)
image_size = os.path.getsize(temp_file.name) / 1000000

Loading…
Cancel
Save