From 1312a864dfc5006acfd74b3af28f8371184fdb3d Mon Sep 17 00:00:00 2001 From: github-actions <${GITHUB_ACTOR}@users.noreply.github.com> Date: Sun, 18 Dec 2022 16:23:38 +0000 Subject: [PATCH] Format Python code with psf/black push --- cogs/draw_image_generation.py | 160 ++++++++++++++++++--------- cogs/gpt_3_commands_and_converser.py | 21 +++- cogs/image_prompt_optimizer.py | 53 ++++++--- models/openai_model.py | 30 +++-- 4 files changed, 179 insertions(+), 85 deletions(-) diff --git a/cogs/draw_image_generation.py b/cogs/draw_image_generation.py index cfbc011..38318cb 100644 --- a/cogs/draw_image_generation.py +++ b/cogs/draw_image_generation.py @@ -24,8 +24,11 @@ class RedoUser: redo_users = {} users_to_interactions = {} + class DrawDallEService(commands.Cog, name="DrawDallEService"): - def __init__(self, bot, usage_service, model, message_queue, deletion_queue, converser_cog): + def __init__( + self, bot, usage_service, model, message_queue, deletion_queue, converser_cog + ): self.bot = bot self.usage_service = usage_service self.model = model @@ -47,8 +50,15 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"): ) print(f"Image prompt optimizer was added") - - async def encapsulated_send(self, prompt, message, response_message=None, vary=None, draw_from_optimizer=None, user_id=None): + async def encapsulated_send( + self, + prompt, + message, + response_message=None, + vary=None, + draw_from_optimizer=None, + user_id=None, + ): # send the prompt to the model file, image_urls = self.model.send_image_request( prompt, vary=vary if not draw_from_optimizer else None @@ -56,7 +66,11 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"): # Start building an embed to send to the user with the results of the image generation embed = discord.Embed( - title="Image Generation Results" if not vary else "Image Generation Results (Varying)" if not draw_from_optimizer else "Image Generation Results (Drawing from Optimizer)", + title="Image Generation Results" + if not vary + else "Image Generation Results (Varying)" + if not draw_from_optimizer + else "Image Generation Results (Drawing from Optimizer)", description=f"{prompt}", color=0xC730C7, ) @@ -64,42 +78,55 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"): # Add the image file to the embed embed.set_image(url=f"attachment://{file.filename}") - if not response_message: # Original generation case # Start an interaction with the user, we also want to send data embed=embed, file=file, view=SaveView(image_urls, self, self.converser_cog) result_message = await message.channel.send( - embed=embed, file=file, view=SaveView(image_urls, self, self.converser_cog) + embed=embed, + file=file, + view=SaveView(image_urls, self, self.converser_cog), ) self.converser_cog.users_to_interactions[message.author.id] = [] - self.converser_cog.users_to_interactions[message.author.id].append(result_message.id) - - redo_users[message.author.id] = RedoUser( - prompt, message, result_message + self.converser_cog.users_to_interactions[message.author.id].append( + result_message.id ) + + redo_users[message.author.id] = RedoUser(prompt, message, result_message) else: if not vary: # Editing case - message = await response_message.edit(embed=embed, file=file, view=SaveView(image_urls, self, self.converser_cog)) + message = await response_message.edit( + embed=embed, + file=file, + view=SaveView(image_urls, self, self.converser_cog), + ) else: # Varying case if not draw_from_optimizer: - result_message = await response_message.edit_original_response(content="Image variation completed!", - embed=embed, file=file, - view=SaveView(image_urls, self, self.converser_cog, True)) + result_message = await response_message.edit_original_response( + content="Image variation completed!", + embed=embed, + file=file, + view=SaveView(image_urls, self, self.converser_cog, True), + ) redo_users[message.author.id] = RedoUser( prompt, message, result_message ) else: result_message = await response_message.edit_original_response( - content="I've drawn the optimized prompt!", embed=embed, file=file, - view=SaveView(image_urls, self, self.converser_cog)) - - redo_users[user_id] = RedoUser( - prompt, message, result_message + content="I've drawn the optimized prompt!", + embed=embed, + file=file, + view=SaveView(image_urls, self, self.converser_cog), ) - if (user_id): - self.converser_cog.users_to_interactions[user_id].append(response_message.id) - self.converser_cog.users_to_interactions[user_id].append(result_message.id) + redo_users[user_id] = RedoUser(prompt, message, result_message) + + if user_id: + self.converser_cog.users_to_interactions[user_id].append( + response_message.id + ) + self.converser_cog.users_to_interactions[user_id].append( + result_message.id + ) @commands.command() async def draw(self, ctx, *args): @@ -110,14 +137,17 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"): # Only allow the bot to be used by people who have the role "Admin" or "GPT" general_user = not any( - role in set(self.converser_cog.DAVINCI_ROLES).union(set(self.converser_cog.CURIE_ROLES)) + role + in set(self.converser_cog.DAVINCI_ROLES).union( + set(self.converser_cog.CURIE_ROLES) + ) for role in message.author.roles ) admin_user = not any( role in self.converser_cog.DAVINCI_ROLES for role in message.author.roles ) - if (not admin_user and not general_user): + if not admin_user and not general_user: return try: @@ -132,9 +162,6 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"): await ctx.reply("Something went wrong. Please try again later.") await ctx.reply(e) - - - @commands.command() async def local_size(self, ctx): # Get the size of the dall-e images folder that we have on the current system. @@ -180,20 +207,24 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"): class SaveView(discord.ui.View): - def __init__(self, image_urls, cog,converser_cog, no_retry=False): + def __init__(self, image_urls, cog, converser_cog, no_retry=False): super().__init__() self.cog = cog self.converser_cog = converser_cog - for x in range(1, len(image_urls)+1): - self.add_item(SaveButton(x, image_urls[x-1])) - self.add_item(VaryButton(x, image_urls[x-1], self.cog, converser_cog=self.converser_cog)) + for x in range(1, len(image_urls) + 1): + self.add_item(SaveButton(x, image_urls[x - 1])) + self.add_item( + VaryButton( + x, image_urls[x - 1], self.cog, converser_cog=self.converser_cog + ) + ) if not no_retry: self.add_item(RedoButton(self.cog, converser_cog=self.converser_cog)) class VaryButton(discord.ui.Button): def __init__(self, number, image_url, cog, converser_cog): - super().__init__(style=discord.ButtonStyle.blurple, label='Vary ' + str(number)) + super().__init__(style=discord.ButtonStyle.blurple, label="Vary " + str(number)) self.number = number self.image_url = image_url self.cog = cog @@ -202,20 +233,27 @@ class VaryButton(discord.ui.Button): async def callback(self, interaction: discord.Interaction): user_id = interaction.user.id interaction_id = interaction.message.id - print(f"The interactions for the user is {self.converser_cog.users_to_interactions[user_id]}") + print( + f"The interactions for the user is {self.converser_cog.users_to_interactions[user_id]}" + ) print(f"The current interaction message id is {interaction_id}") print(f"The current interaction ID is {interaction.id}") - if (interaction_id not in self.converser_cog.users_to_interactions[user_id]): + if interaction_id not in self.converser_cog.users_to_interactions[user_id]: if len(self.converser_cog.users_to_interactions[user_id]) >= 2: interaction_id2 = interaction.id - if (interaction_id2 not in self.converser_cog.users_to_interactions[user_id]): + if ( + interaction_id2 + not in self.converser_cog.users_to_interactions[user_id] + ): await interaction.response.send_message( - content="You can not vary images in someone else's chain!", ephemeral=True + content="You can not vary images in someone else's chain!", + ephemeral=True, ) else: await interaction.response.send_message( - content="You can only vary for images that you generated yourself!", ephemeral=True + content="You can only vary for images that you generated yourself!", + ephemeral=True, ) return @@ -223,21 +261,28 @@ class VaryButton(discord.ui.Button): response_message = await interaction.response.send_message( content="Varying image number " + str(self.number) + "..." ) - self.converser_cog.users_to_interactions[user_id].append(response_message.message.id) - self.converser_cog.users_to_interactions[user_id].append(response_message.id) + self.converser_cog.users_to_interactions[user_id].append( + response_message.message.id + ) + self.converser_cog.users_to_interactions[user_id].append( + response_message.id + ) prompt = redo_users[user_id].prompt - await self.cog.encapsulated_send(prompt, interaction.message, response_message=response_message, - vary=self.image_url, user_id=user_id) - + await self.cog.encapsulated_send( + prompt, + interaction.message, + response_message=response_message, + vary=self.image_url, + user_id=user_id, + ) -class SaveButton(discord.ui.Button['SaveView']): +class SaveButton(discord.ui.Button["SaveView"]): def __init__(self, number: int, image_url: str): - super().__init__(style=discord.ButtonStyle.gray, label='Save ' + str(number)) + super().__init__(style=discord.ButtonStyle.gray, label="Save " + str(number)) self.number = number self.image_url = image_url - async def callback(self, interaction: discord.Interaction): # If the image url doesn't start with "http", then we need to read the file from the URI, and then send the # file to the user as an attachment. @@ -250,20 +295,22 @@ class SaveButton(discord.ui.Button['SaveView']): await interaction.response.send_message( content="Here is your image for download (open original and save)", - file=discord.File(temp_file.name), ephemeral=True + file=discord.File(temp_file.name), + ephemeral=True, ) else: - await interaction.response.send_message(f'You can directly download this image from {self.image_url}', - ephemeral=True) + await interaction.response.send_message( + f"You can directly download this image from {self.image_url}", + ephemeral=True, + ) except Exception as e: - await interaction.response.send_message(f'Error: {e}', ephemeral=True) + await interaction.response.send_message(f"Error: {e}", ephemeral=True) traceback.print_exc() -class RedoButton(discord.ui.Button['SaveView']): - +class RedoButton(discord.ui.Button["SaveView"]): def __init__(self, cog, converser_cog): - super().__init__(style=discord.ButtonStyle.danger, label='Retry') + super().__init__(style=discord.ButtonStyle.danger, label="Retry") self.cog = cog self.converser_cog = converser_cog @@ -271,9 +318,10 @@ class RedoButton(discord.ui.Button['SaveView']): user_id = interaction.user.id interaction_id = interaction.message.id - if (interaction_id not in self.converser_cog.users_to_interactions[user_id]): + if interaction_id not in self.converser_cog.users_to_interactions[user_id]: await interaction.response.send_message( - content="You can only retry for prompts that you generated yourself!", ephemeral=True + content="You can only retry for prompts that you generated yourself!", + ephemeral=True, ) return @@ -285,7 +333,9 @@ class RedoButton(discord.ui.Button['SaveView']): prompt = redo_users[user_id].prompt response_message = redo_users[user_id].response_message message = await interaction.response.send_message( - f'Regenerating the image for your original prompt, check the original message.', ephemeral=True) + f"Regenerating the image for your original prompt, check the original message.", + ephemeral=True, + ) self.converser_cog.users_to_interactions[user_id].append(message.id) await self.cog.encapsulated_send(prompt, message, response_message) diff --git a/cogs/gpt_3_commands_and_converser.py b/cogs/gpt_3_commands_and_converser.py index 130fe7e..4a521f6 100644 --- a/cogs/gpt_3_commands_and_converser.py +++ b/cogs/gpt_3_commands_and_converser.py @@ -107,7 +107,6 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): ) print(f"Draw service was added") - @commands.command() async def delete_all_conversation_threads(self, ctx): # If the user has ADMIN_ROLES @@ -215,17 +214,28 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): # Create a two-column embed to display the settings, use \u200b to create a blank space embed.add_field( name="Setting", - value="\n".join([key for key in self.model.__dict__.keys() if key not in self.model._hidden_attributes]), + value="\n".join( + [ + key + for key in self.model.__dict__.keys() + if key not in self.model._hidden_attributes + ] + ), inline=True, ) embed.add_field( name="Value", - value="\n".join([str(value) for key, value in self.model.__dict__.items() if key not in self.model._hidden_attributes]), + value="\n".join( + [ + str(value) + for key, value in self.model.__dict__.items() + if key not in self.model._hidden_attributes + ] + ), inline=True, ) await message.channel.send(embed=embed) - async def process_settings_command(self, message): # Extract the parameter and the value parameter = message.content[4:].split()[0] @@ -529,7 +539,8 @@ class RedoButtonView( discord.ui.View ): # Create a class called MyView that subclasses discord.ui.View @discord.ui.button( - label="Retry", style=discord.ButtonStyle.danger, + label="Retry", + style=discord.ButtonStyle.danger, ) # Create a button with the label "😎 Click me!" with color Blurple async def button_callback(self, button, interaction): msg = await interaction.response.send_message( diff --git a/cogs/image_prompt_optimizer.py b/cogs/image_prompt_optimizer.py index 325b619..d767e35 100644 --- a/cogs/image_prompt_optimizer.py +++ b/cogs/image_prompt_optimizer.py @@ -11,6 +11,7 @@ from models.deletion_service import Deletion redo_users = {} + class RedoUser: def __init__(self, prompt, message, response): self.prompt = prompt @@ -22,7 +23,14 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"): _OPTIMIZER_PRETEXT = "Optimize the following text for DALL-E image generation to have the most detailed and realistic image possible. Prompt:" def __init__( - self, bot, usage_service, model, message_queue, deletion_queue, converser_cog, image_service_cog + self, + bot, + usage_service, + model, + message_queue, + deletion_queue, + converser_cog, + image_service_cog, ): self.bot = bot self.usage_service = usage_service @@ -43,7 +51,6 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"): traceback.print_exc() self.OPTIMIZER_PRETEXT = self._OPTIMIZER_PRETEXT - @commands.command() async def imgoptimize(self, ctx, *args): @@ -79,10 +86,16 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"): response_message = await ctx.reply(response_text) self.converser_cog.users_to_interactions[ctx.message.author.id] = [] - self.converser_cog.users_to_interactions[ctx.message.author.id].append(response_message.id) + self.converser_cog.users_to_interactions[ctx.message.author.id].append( + response_message.id + ) redo_users[ctx.author.id] = RedoUser(prompt, ctx.message, response_message) - await response_message.edit(view=OptimizeView(self.converser_cog, self.image_service_cog, self.deletion_queue)) + await response_message.edit( + view=OptimizeView( + self.converser_cog, self.image_service_cog, self.deletion_queue + ) + ) # Catch the value errors raised by the Model object except ValueError as e: @@ -107,10 +120,10 @@ class OptimizeView(discord.ui.View): self.add_item(RedoButton(self.cog, self.image_service_cog, self.deletion_queue)) self.add_item(DrawButton(self.cog, self.image_service_cog, self.deletion_queue)) -class DrawButton(discord.ui.Button['OptimizeView']): +class DrawButton(discord.ui.Button["OptimizeView"]): def __init__(self, converser_cog, image_service_cog, deletion_queue): - super().__init__(style=discord.ButtonStyle.green, label='Draw') + super().__init__(style=discord.ButtonStyle.green, label="Draw") self.converser_cog = converser_cog self.image_service_cog = image_service_cog self.deletion_queue = deletion_queue @@ -122,7 +135,8 @@ class DrawButton(discord.ui.Button['OptimizeView']): if interaction_id not in self.converser_cog.users_to_interactions[user_id]: await interaction.response.send_message( - content="You can only draw for prompts that you generated yourself!", ephemeral=True + content="You can only draw for prompts that you generated yourself!", + ephemeral=True, ) return @@ -130,9 +144,12 @@ class DrawButton(discord.ui.Button['OptimizeView']): "Drawing this prompt!", ephemeral=False ) self.converser_cog.users_to_interactions[interaction.user.id].append(msg.id) - self.converser_cog.users_to_interactions[interaction.user.id].append(interaction.id) - self.converser_cog.users_to_interactions[interaction.user.id].append(interaction.message.id) - + self.converser_cog.users_to_interactions[interaction.user.id].append( + interaction.id + ) + self.converser_cog.users_to_interactions[interaction.user.id].append( + interaction.message.id + ) # get the text content of the message that was interacted with prompt = interaction.message.content @@ -142,13 +159,14 @@ class DrawButton(discord.ui.Button['OptimizeView']): prompt = re.sub(r"Output Prompt: ?", "", prompt) # Call the image service cog to draw the image - await self.image_service_cog.encapsulated_send(prompt, None, msg, True, True, user_id) - + await self.image_service_cog.encapsulated_send( + prompt, None, msg, True, True, user_id + ) -class RedoButton(discord.ui.Button['OptimizeView']): +class RedoButton(discord.ui.Button["OptimizeView"]): def __init__(self, converser_cog, image_service_cog, deletion_queue): - super().__init__(style=discord.ButtonStyle.danger, label='Retry') + super().__init__(style=discord.ButtonStyle.danger, label="Retry") self.converser_cog = converser_cog self.image_service_cog = image_service_cog self.deletion_queue = deletion_queue @@ -159,7 +177,8 @@ class RedoButton(discord.ui.Button['OptimizeView']): if interaction_id not in self.converser_cog.users_to_interactions[user_id]: await interaction.response.send_message( - content="You can only redo for prompts that you generated yourself!", ephemeral=True + content="You can only redo for prompts that you generated yourself!", + ephemeral=True, ) return @@ -181,4 +200,6 @@ class RedoButton(discord.ui.Button['OptimizeView']): message = redo_users[user_id].message prompt = redo_users[user_id].prompt response_message = redo_users[user_id].response - await self.converser_cog.encapsulated_send(message, prompt, response_message) \ No newline at end of file + await self.converser_cog.encapsulated_send( + message, prompt, response_message + ) diff --git a/models/openai_model.py b/models/openai_model.py index 7e68442..162cefe 100644 --- a/models/openai_model.py +++ b/models/openai_model.py @@ -22,6 +22,7 @@ class Models: DAVINCI = "text-davinci-003" CURIE = "text-curie-001" + class ImageSize: LARGE = "1024x1024" MEDIUM = "512x512" @@ -58,7 +59,13 @@ class Model: os.makedirs(self.IMAGE_SAVE_PATH) self.custom_image_path = False - self._hidden_attributes = ["usage_service", "DAVINCI_ROLES", "custom_image_path", "custom_web_root", "_hidden_attributes"] + self._hidden_attributes = [ + "usage_service", + "DAVINCI_ROLES", + "custom_image_path", + "custom_web_root", + "_hidden_attributes", + ] openai.api_key = os.getenv("OPENAI_TOKEN") @@ -73,7 +80,9 @@ class Model: if value in ImageSize.__dict__.values(): self._image_size = value else: - raise ValueError("Image size must be one of the following: SMALL(256x256), MEDIUM(512x512), LARGE(1024x1024)") + raise ValueError( + "Image size must be one of the following: SMALL(256x256), MEDIUM(512x512), LARGE(1024x1024)" + ) @property def num_images(self): @@ -320,8 +329,8 @@ class Model: print(response.__dict__) image_urls = [] - for result in response['data']: - image_urls.append(result['url']) + for result in response["data"]: + image_urls.append(result["url"]) # For each image url, open it as an image object using PIL images = [Image.open(requests.get(url, stream=True).raw) for url in image_urls] @@ -348,10 +357,10 @@ class Model: height = max(heights) * num_rows # Create a transparent image with the same size as the images - transparent = Image.new('RGBA', (max(widths), max(heights))) + transparent = Image.new("RGBA", (max(widths), max(heights))) # Create a new image with the calculated size - new_im = Image.new('RGBA', (width, height)) + new_im = Image.new("RGBA", (width, height)) # Paste the images and transparent segments into the grid x_offset = y_offset = 0 @@ -370,7 +379,6 @@ class Model: x_offset = 0 y_offset += transparent.size[1] - # Save the new_im to a temporary file and return it as a discord.File temp_file = tempfile.NamedTemporaryFile(suffix=".png") new_im.save(temp_file.name) @@ -383,8 +391,12 @@ class Model: safety_counter = 0 while image_size > 8 or safety_counter >= 2: safety_counter += 1 - print(f"Image size is {image_size}MB, which is too large for discord. Downscaling and trying again") - new_im = new_im.resize((int(new_im.width / 1.05), int(new_im.height / 1.05))) + print( + f"Image size is {image_size}MB, which is too large for discord. Downscaling and trying again" + ) + new_im = new_im.resize( + (int(new_im.width / 1.05), int(new_im.height / 1.05)) + ) temp_file = tempfile.NamedTemporaryFile(suffix=".png") new_im.save(temp_file.name) image_size = os.path.getsize(temp_file.name) / 1000000