From 0065f4918a31cb6f7510df52d8a9ab4e00406332 Mon Sep 17 00:00:00 2001 From: Kaveen Kumarasinghe Date: Sun, 18 Dec 2022 11:22:57 -0500 Subject: [PATCH] Add full DALL-E Integration --- cogs/draw_image_generation.py | 291 +++++++++++++++++++++++++++ cogs/gpt_3_commands_and_converser.py | 44 +++- cogs/image_prompt_optimizer.py | 88 ++++++-- conversation_starter_pretext.txt | 3 +- image_optimizer_pretext.txt | 1 - models/openai_model.py | 150 ++++++++++++++ models/usage_service_model.py | 20 ++ requirements.txt | 2 + 8 files changed, 573 insertions(+), 26 deletions(-) create mode 100644 cogs/draw_image_generation.py diff --git a/cogs/draw_image_generation.py b/cogs/draw_image_generation.py new file mode 100644 index 0000000..cfbc011 --- /dev/null +++ b/cogs/draw_image_generation.py @@ -0,0 +1,291 @@ +import datetime +import os +import re +import tempfile +import traceback +import uuid +from collections import defaultdict +from io import BytesIO + +import discord +from PIL import Image +from discord.ext import commands + +from cogs.image_prompt_optimizer import ImgPromptOptimizer + + +class RedoUser: + def __init__(self, prompt, message, response_message): + self.prompt = prompt + self.message = message + self.response_message = response_message + + +redo_users = {} +users_to_interactions = {} + +class DrawDallEService(commands.Cog, name="DrawDallEService"): + def __init__(self, bot, usage_service, model, message_queue, deletion_queue, converser_cog): + self.bot = bot + self.usage_service = usage_service + self.model = model + self.message_queue = message_queue + self.deletion_queue = deletion_queue + self.converser_cog = converser_cog + + print("Draw service init") + self.bot.add_cog( + ImgPromptOptimizer( + self.bot, + self.usage_service, + self.model, + self.message_queue, + self.deletion_queue, + self.converser_cog, + self, + ) + ) + print(f"Image prompt optimizer was added") + + + async def encapsulated_send(self, prompt, message, response_message=None, vary=None, draw_from_optimizer=None, user_id=None): + # send the prompt to the model + file, image_urls = self.model.send_image_request( + prompt, vary=vary if not draw_from_optimizer else None + ) + + # Start building an embed to send to the user with the results of the image generation + embed = discord.Embed( + title="Image Generation Results" if not vary else "Image Generation Results (Varying)" if not draw_from_optimizer else "Image Generation Results (Drawing from Optimizer)", + description=f"{prompt}", + color=0xC730C7, + ) + + # Add the image file to the embed + embed.set_image(url=f"attachment://{file.filename}") + + + if not response_message: # Original generation case + # Start an interaction with the user, we also want to send data embed=embed, file=file, view=SaveView(image_urls, self, self.converser_cog) + result_message = await message.channel.send( + embed=embed, file=file, view=SaveView(image_urls, self, self.converser_cog) + ) + + self.converser_cog.users_to_interactions[message.author.id] = [] + self.converser_cog.users_to_interactions[message.author.id].append(result_message.id) + + redo_users[message.author.id] = RedoUser( + prompt, message, result_message + ) + else: + if not vary: # Editing case + message = await response_message.edit(embed=embed, file=file, view=SaveView(image_urls, self, self.converser_cog)) + else: # Varying case + if not draw_from_optimizer: + result_message = await response_message.edit_original_response(content="Image variation completed!", + embed=embed, file=file, + view=SaveView(image_urls, 
self, self.converser_cog, True)) + redo_users[message.author.id] = RedoUser( + prompt, message, result_message + ) + else: + result_message = await response_message.edit_original_response( + content="I've drawn the optimized prompt!", embed=embed, file=file, + view=SaveView(image_urls, self, self.converser_cog)) + + redo_users[user_id] = RedoUser( + prompt, message, result_message + ) + + if (user_id): + self.converser_cog.users_to_interactions[user_id].append(response_message.id) + self.converser_cog.users_to_interactions[user_id].append(result_message.id) + + @commands.command() + async def draw(self, ctx, *args): + message = ctx.message + + if message.author == self.bot.user: + return + + # Only allow the bot to be used by people who have the role "Admin" or "GPT" + general_user = not any( + role in set(self.converser_cog.DAVINCI_ROLES).union(set(self.converser_cog.CURIE_ROLES)) + for role in message.author.roles + ) + admin_user = not any( + role in self.converser_cog.DAVINCI_ROLES for role in message.author.roles + ) + + if (not admin_user and not general_user): + return + try: + + # The image prompt is everything after the command + prompt = " ".join(args) + + await self.encapsulated_send(prompt, message) + + except Exception as e: + print(e) + traceback.print_exc() + await ctx.reply("Something went wrong. Please try again later.") + await ctx.reply(e) + + + + + @commands.command() + async def local_size(self, ctx): + # Get the size of the dall-e images folder that we have on the current system. + # Check if admin user + message = ctx.message + admin_user = not any( + role in self.converser_cog.DAVINCI_ROLES for role in message.author.roles + ) + if not admin_user: + return + + image_path = self.model.IMAGE_SAVE_PATH + total_size = 0 + for dirpath, dirnames, filenames in os.walk(image_path): + for f in filenames: + fp = os.path.join(dirpath, f) + total_size += os.path.getsize(fp) + + # Format the size to be in MB and send. + total_size = total_size / 1000000 + await ctx.send(f"The size of the local images folder is {total_size} MB.") + + @commands.command() + async def clear_local(self, ctx): + message = ctx.message + admin_user = not any( + role in self.converser_cog.DAVINCI_ROLES for role in message.author.roles + ) + if not admin_user: + return + + # Delete all the local images in the images folder. 
+ image_path = self.model.IMAGE_SAVE_PATH + for dirpath, dirnames, filenames in os.walk(image_path): + for f in filenames: + try: + fp = os.path.join(dirpath, f) + os.remove(fp) + except Exception as e: + print(e) + + await ctx.send("Local images cleared.") + + +class SaveView(discord.ui.View): + def __init__(self, image_urls, cog,converser_cog, no_retry=False): + super().__init__() + self.cog = cog + self.converser_cog = converser_cog + for x in range(1, len(image_urls)+1): + self.add_item(SaveButton(x, image_urls[x-1])) + self.add_item(VaryButton(x, image_urls[x-1], self.cog, converser_cog=self.converser_cog)) + if not no_retry: + self.add_item(RedoButton(self.cog, converser_cog=self.converser_cog)) + + +class VaryButton(discord.ui.Button): + def __init__(self, number, image_url, cog, converser_cog): + super().__init__(style=discord.ButtonStyle.blurple, label='Vary ' + str(number)) + self.number = number + self.image_url = image_url + self.cog = cog + self.converser_cog = converser_cog + + async def callback(self, interaction: discord.Interaction): + user_id = interaction.user.id + interaction_id = interaction.message.id + print(f"The interactions for the user is {self.converser_cog.users_to_interactions[user_id]}") + print(f"The current interaction message id is {interaction_id}") + print(f"The current interaction ID is {interaction.id}") + + if (interaction_id not in self.converser_cog.users_to_interactions[user_id]): + if len(self.converser_cog.users_to_interactions[user_id]) >= 2: + interaction_id2 = interaction.id + if (interaction_id2 not in self.converser_cog.users_to_interactions[user_id]): + await interaction.response.send_message( + content="You can not vary images in someone else's chain!", ephemeral=True + ) + else: + await interaction.response.send_message( + content="You can only vary for images that you generated yourself!", ephemeral=True + ) + return + + if user_id in redo_users: + response_message = await interaction.response.send_message( + content="Varying image number " + str(self.number) + "..." + ) + self.converser_cog.users_to_interactions[user_id].append(response_message.message.id) + self.converser_cog.users_to_interactions[user_id].append(response_message.id) + prompt = redo_users[user_id].prompt + await self.cog.encapsulated_send(prompt, interaction.message, response_message=response_message, + vary=self.image_url, user_id=user_id) + + +class SaveButton(discord.ui.Button['SaveView']): + + def __init__(self, number: int, image_url: str): + super().__init__(style=discord.ButtonStyle.gray, label='Save ' + str(number)) + self.number = number + self.image_url = image_url + + + async def callback(self, interaction: discord.Interaction): + # If the image url doesn't start with "http", then we need to read the file from the URI, and then send the + # file to the user as an attachment. 
+ try: + if not self.image_url.startswith("http"): + with open(self.image_url, "rb") as f: + image = Image.open(BytesIO(f.read())) + temp_file = tempfile.NamedTemporaryFile(suffix=".png") + image.save(temp_file.name) + + await interaction.response.send_message( + content="Here is your image for download (open original and save)", + file=discord.File(temp_file.name), ephemeral=True + ) + else: + await interaction.response.send_message(f'You can directly download this image from {self.image_url}', + ephemeral=True) + except Exception as e: + await interaction.response.send_message(f'Error: {e}', ephemeral=True) + traceback.print_exc() + + +class RedoButton(discord.ui.Button['SaveView']): + + def __init__(self, cog, converser_cog): + super().__init__(style=discord.ButtonStyle.danger, label='Retry') + self.cog = cog + self.converser_cog = converser_cog + + async def callback(self, interaction: discord.Interaction): + user_id = interaction.user.id + interaction_id = interaction.message.id + + if (interaction_id not in self.converser_cog.users_to_interactions[user_id]): + await interaction.response.send_message( + content="You can only retry for prompts that you generated yourself!", ephemeral=True + ) + return + + # We have passed the intial check of if the interaction belongs to the user + + if user_id in redo_users: + # Get the message and the prompt and call encapsulated_send + message = redo_users[user_id].message + prompt = redo_users[user_id].prompt + response_message = redo_users[user_id].response_message + message = await interaction.response.send_message( + f'Regenerating the image for your original prompt, check the original message.', ephemeral=True) + self.converser_cog.users_to_interactions[user_id].append(message.id) + + await self.cog.encapsulated_send(prompt, message, response_message) diff --git a/cogs/gpt_3_commands_and_converser.py b/cogs/gpt_3_commands_and_converser.py index d732175..130fe7e 100644 --- a/cogs/gpt_3_commands_and_converser.py +++ b/cogs/gpt_3_commands_and_converser.py @@ -8,6 +8,7 @@ import traceback import discord from discord.ext import commands +from cogs.draw_image_generation import DrawDallEService from cogs.image_prompt_optimizer import ImgPromptOptimizer from models.deletion_service import Deletion from models.message_model import Message @@ -55,6 +56,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): self.usage_service = usage_service self.model = model self.deletion_queue = deletion_queue + self.users_to_interactions = defaultdict(list) try: # Attempt to read a conversation starter text string from the file. @@ -91,8 +93,10 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): self.DEBUG_CHANNEL ) print(f"The debug channel was acquired") + + # Add the draw service to the bot. 
self.bot.add_cog( - ImgPromptOptimizer( + DrawDallEService( self.bot, self.usage_service, self.model, @@ -101,7 +105,8 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): self, ) ) - print(f"Image prompt optimizer was added") + print(f"Draw service was added") + @commands.command() async def delete_all_conversation_threads(self, ctx): @@ -126,7 +131,14 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): and message.channel.id == self.conversation_threads[message.author.id] ) - return cond1 or cond2 + # If the trimmed message starts with a Tilde, then we want to not contribute this to the conversation + try: + cond3 = not message.content.strip().startswith("~") + except Exception as e: + print(e) + cond3 = False + + return (cond1 or cond2) and cond3 async def end_conversation(self, message): self.conversating_users.pop(message.author.id) @@ -200,9 +212,19 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): description="The current settings of the model", color=0x00FF00, ) - for key, value in self.model.__dict__.items(): - embed.add_field(name=key, value=value, inline=False) - await message.reply(embed=embed) + # Create a two-column embed to display the settings, use \u200b to create a blank space + embed.add_field( + name="Setting", + value="\n".join([key for key in self.model.__dict__.keys() if key not in self.model._hidden_attributes]), + inline=True, + ) + embed.add_field( + name="Value", + value="\n".join([str(value) for key, value in self.model.__dict__.items() if key not in self.model._hidden_attributes]), + inline=True, + ) + await message.channel.send(embed=embed) + async def process_settings_command(self, message): # Extract the parameter and the value @@ -461,12 +483,12 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): await thread.send( "<@" + str(message.author.id) - + "> You are now conversing with GPT3. End the conversation with !g end or just say end" + + "> You are now conversing with GPT3. *Say hi to start!*\n End the conversation by saying `end`.\n\n If you want GPT3 to ignore your messages, start your messages with `~`\n\nYour conversation will remain active even if you leave this thread and talk in other GPT supported channels, unless you end the conversation!" ) self.conversation_threads[message.author.id] = thread.id else: await message.reply( - "You are now conversing with GPT3. End the conversation with !g end or just say end" + "You are now conversing with GPT3. *Say hi to start!*\n End the conversation by saying `end`.\n\n If you want GPT3 to ignore your messages, start your messages with `~`\n\nYour conversation will remain active even if you leave this thread and talk in other GPT supported channels, unless you end the conversation!" ) return @@ -475,7 +497,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): # If the user is not conversating, don't let them end the conversation if message.author.id not in self.conversating_users: await message.reply( - "You are not conversing with GPT3. Start a conversation with !g converse" + "**You are not conversing with GPT3.** Start a conversation with `!g converse`" ) return @@ -507,11 +529,11 @@ class RedoButtonView( discord.ui.View ): # Create a class called MyView that subclasses discord.ui.View @discord.ui.button( - label="", style=discord.ButtonStyle.primary, emoji="🔄" + label="Retry", style=discord.ButtonStyle.danger, ) # Create a button with the label "😎 Click me!" 
with color Blurple async def button_callback(self, button, interaction): msg = await interaction.response.send_message( - "Redoing your original request...", ephemeral=True + "Retrying your original request...", ephemeral=True ) # Put the message into the deletion queue with a timestamp of 10 seconds from now to be deleted diff --git a/cogs/image_prompt_optimizer.py b/cogs/image_prompt_optimizer.py index 0da7cbd..325b619 100644 --- a/cogs/image_prompt_optimizer.py +++ b/cogs/image_prompt_optimizer.py @@ -2,6 +2,7 @@ import datetime import os import re import traceback +from collections import defaultdict import discord from discord.ext import commands @@ -10,7 +11,6 @@ from models.deletion_service import Deletion redo_users = {} - class RedoUser: def __init__(self, prompt, message, response): self.prompt = prompt @@ -22,7 +22,7 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"): _OPTIMIZER_PRETEXT = "Optimize the following text for DALL-E image generation to have the most detailed and realistic image possible. Prompt:" def __init__( - self, bot, usage_service, model, message_queue, deletion_queue, converser_cog + self, bot, usage_service, model, message_queue, deletion_queue, converser_cog, image_service_cog ): self.bot = bot self.usage_service = usage_service @@ -30,6 +30,8 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"): self.message_queue = message_queue self.OPTIMIZER_PRETEXT = self._OPTIMIZER_PRETEXT self.converser_cog = converser_cog + self.image_service_cog = image_service_cog + self.deletion_queue = deletion_queue try: # Try to read the image optimizer pretext from @@ -41,6 +43,7 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"): traceback.print_exc() self.OPTIMIZER_PRETEXT = self._OPTIMIZER_PRETEXT + @commands.command() async def imgoptimize(self, ctx, *args): @@ -75,10 +78,11 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"): return response_message = await ctx.reply(response_text) + self.converser_cog.users_to_interactions[ctx.message.author.id] = [] + self.converser_cog.users_to_interactions[ctx.message.author.id].append(response_message.id) redo_users[ctx.author.id] = RedoUser(prompt, ctx.message, response_message) - RedoButtonView.bot = self.converser_cog - await response_message.edit(view=RedoButtonView()) + await response_message.edit(view=OptimizeView(self.converser_cog, self.image_service_cog, self.deletion_queue)) # Catch the value errors raised by the Model object except ValueError as e: @@ -94,13 +98,71 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"): return -class RedoButtonView( - discord.ui.View -): # Create a class called MyView that subclasses discord.ui.View - @discord.ui.button( - label="", style=discord.ButtonStyle.primary, emoji="🔄" - ) # Create a button with the label "😎 Click me!" 
with color Blurple - async def button_callback(self, button, interaction): +class OptimizeView(discord.ui.View): + def __init__(self, converser_cog, image_service_cog, deletion_queue): + super().__init__() + self.cog = converser_cog + self.image_service_cog = image_service_cog + self.deletion_queue = deletion_queue + self.add_item(RedoButton(self.cog, self.image_service_cog, self.deletion_queue)) + self.add_item(DrawButton(self.cog, self.image_service_cog, self.deletion_queue)) + +class DrawButton(discord.ui.Button['OptimizeView']): + + def __init__(self, converser_cog, image_service_cog, deletion_queue): + super().__init__(style=discord.ButtonStyle.green, label='Draw') + self.converser_cog = converser_cog + self.image_service_cog = image_service_cog + self.deletion_queue = deletion_queue + + async def callback(self, interaction: discord.Interaction): + + user_id = interaction.user.id + interaction_id = interaction.message.id + + if interaction_id not in self.converser_cog.users_to_interactions[user_id]: + await interaction.response.send_message( + content="You can only draw for prompts that you generated yourself!", ephemeral=True + ) + return + + msg = await interaction.response.send_message( + "Drawing this prompt!", ephemeral=False + ) + self.converser_cog.users_to_interactions[interaction.user.id].append(msg.id) + self.converser_cog.users_to_interactions[interaction.user.id].append(interaction.id) + self.converser_cog.users_to_interactions[interaction.user.id].append(interaction.message.id) + + + # get the text content of the message that was interacted with + prompt = interaction.message.content + + # Use regex to replace "Output Prompt:" loosely with nothing. + # This is to ensure that the prompt is formatted correctly + prompt = re.sub(r"Output Prompt: ?", "", prompt) + + # Call the image service cog to draw the image + await self.image_service_cog.encapsulated_send(prompt, None, msg, True, True, user_id) + + +class RedoButton(discord.ui.Button['OptimizeView']): + + def __init__(self, converser_cog, image_service_cog, deletion_queue): + super().__init__(style=discord.ButtonStyle.danger, label='Retry') + self.converser_cog = converser_cog + self.image_service_cog = image_service_cog + self.deletion_queue = deletion_queue + + async def callback(self, interaction: discord.Interaction): + user_id = interaction.user.id + interaction_id = interaction.message.id + + if interaction_id not in self.converser_cog.users_to_interactions[user_id]: + await interaction.response.send_message( + content="You can only redo for prompts that you generated yourself!", ephemeral=True + ) + return + msg = await interaction.response.send_message( "Redoing your original request...", ephemeral=True ) @@ -109,7 +171,7 @@ class RedoButtonView( deletion = Deletion( msg, (datetime.datetime.now() + datetime.timedelta(seconds=10)).timestamp() ) - await self.bot.deletion_queue.put(deletion) + await self.deletion_queue.put(deletion) # Get the user user_id = interaction.user.id @@ -119,4 +181,4 @@ class RedoButtonView( message = redo_users[user_id].message prompt = redo_users[user_id].prompt response_message = redo_users[user_id].response - await self.bot.encapsulated_send(message, prompt, response_message) + await self.converser_cog.encapsulated_send(message, prompt, response_message) \ No newline at end of file diff --git a/conversation_starter_pretext.txt b/conversation_starter_pretext.txt index ea1ce8b..9c2efce 100644 --- a/conversation_starter_pretext.txt +++ b/conversation_starter_pretext.txt @@ -14,6 +14,7 @@ 
You are a software engineer. You type very casually and use emojis. You don't repeat yourself and don't sound boring. You always type very casually and use emojis. +Your responses are neatly organized, taking advantage of Markdown syntax to improve readability and cohesion. You talk in internet shorthand like hbu and lol and lmao and nvm and haha. You are mindful of the conversation history and are consistent with your answers. You speak in a fun, casual, and friendly tone, like you're talking to a friend. @@ -30,4 +31,4 @@ Human: [MESSAGE 2] GPTie: [RESPONSE TO MESSAGE 2] ... -Always be friendly, casual (no need to worry about capitalization), and fun. Use emojis in your responses in a way that makes sense. Always respond in nice, markdown formatting. +Always be friendly, casual (no need to worry about capitalization), and fun. Use emojis in your responses in a way that makes sense. diff --git a/image_optimizer_pretext.txt b/image_optimizer_pretext.txt index 81281e9..a4050df 100644 --- a/image_optimizer_pretext.txt +++ b/image_optimizer_pretext.txt @@ -129,5 +129,4 @@ Pay careful to attention to the words that you use in the optimized prompt, the Use all of the information, and also branch out and be creative and infer to optimize the prompt given. Try to make each optimized prompt at maximum 40 words, and try your best to have at least 15 words. Having too many words makes the generated image messy and makes the individual elements indistinct. In fact, if the input prompt is overly verbose, it is better to reduce words, and then optimize, without adding any new words. Moreover, do not add extra words to an already suitable prompt. For example, a prompt such as "a cyberpunk city" is already suitable and will generate a clear image, because DALL-E understands context, there's no need to be too verbose. Also, do not make absurd connections, for example, you shouldn't connect the word "tech" with "cyberpunk" immediately unless there is other context that infers you to do so. 
- Input Prompt: \ No newline at end of file diff --git a/models/openai_model.py b/models/openai_model.py index fe9b99f..7e68442 100644 --- a/models/openai_model.py +++ b/models/openai_model.py @@ -1,8 +1,18 @@ +import math import os +import tempfile +import uuid +from typing import Tuple, List, Any +import discord import openai # An enum of two modes, TOP_P or TEMPERATURE +import requests +from PIL import Image +from discord import File + + class Mode: TOP_P = "top_p" TEMPERATURE = "temperature" @@ -12,6 +22,11 @@ class Models: DAVINCI = "text-davinci-003" CURIE = "text-curie-001" +class ImageSize: + LARGE = "1024x1024" + MEDIUM = "512x512" + SMALL = "256x256" + class Model: def __init__(self, usage_service): @@ -30,11 +45,47 @@ class Model: self._low_usage_mode = False self.usage_service = usage_service self.DAVINCI_ROLES = ["admin", "Admin", "GPT", "gpt"] + self._image_size = ImageSize.MEDIUM + self._num_images = 1 + + try: + self.IMAGE_SAVE_PATH = os.environ["IMAGE_SAVE_PATH"] + self.custom_image_path = True + except: + self.IMAGE_SAVE_PATH = "dalleimages" + # Try to make this folder called images/ in the local directory if it doesnt exist + if not os.path.exists(self.IMAGE_SAVE_PATH): + os.makedirs(self.IMAGE_SAVE_PATH) + self.custom_image_path = False + + self._hidden_attributes = ["usage_service", "DAVINCI_ROLES", "custom_image_path", "custom_web_root", "_hidden_attributes"] openai.api_key = os.getenv("OPENAI_TOKEN") # Use the @property and @setter decorators for all the self fields to provide value checking + @property + def image_size(self): + return self._image_size + + @image_size.setter + def image_size(self, value): + if value in ImageSize.__dict__.values(): + self._image_size = value + else: + raise ValueError("Image size must be one of the following: SMALL(256x256), MEDIUM(512x512), LARGE(1024x1024)") + + @property + def num_images(self): + return self._num_images + + @num_images.setter + def num_images(self, value): + value = int(value) + if value > 4 or value <= 0: + raise ValueError("num_images must be less than 4 and at least 1.") + self._num_images = value + @property def low_usage_mode(self): return self._low_usage_mode @@ -241,3 +292,102 @@ class Model: self.usage_service.update_usage(tokens_used) return response + + def send_image_request(self, prompt, vary=None) -> tuple[File, list[Any]]: + # Validate that all the parameters are in a good state before we send the request + words = len(prompt.split(" ")) + if words < 3 or words > 75: + raise ValueError( + "Prompt must be greater than 3 words and less than 75, it is currently " + + str(words) + ) + + print("The prompt about to be sent is " + prompt) + self.usage_service.update_usage_image(self.image_size) + + if not vary: + response = openai.Image.create( + prompt=prompt, + n=self.num_images, + size=self.image_size, + ) + else: + response = openai.Image.create_variation( + image=open(vary, "rb"), + n=self.num_images, + size=self.image_size, + ) + print(response.__dict__) + + image_urls = [] + for result in response['data']: + image_urls.append(result['url']) + + # For each image url, open it as an image object using PIL + images = [Image.open(requests.get(url, stream=True).raw) for url in image_urls] + + # Save all the images with a random name to self.IMAGE_SAVE_PATH + image_names = [f"{uuid.uuid4()}.png" for _ in range(len(images))] + for image, name in zip(images, image_names): + image.save(f"{self.IMAGE_SAVE_PATH}/{name}") + + # Update image_urls to include the local path to these new images + image_urls = 
[f"{self.IMAGE_SAVE_PATH}/{name}" for name in image_names] + + widths, heights = zip(*(i.size for i in images)) + + # Calculate the number of rows and columns needed for the grid + num_rows = num_cols = int(math.ceil(math.sqrt(len(images)))) + + # If there are only 2 images, set the number of rows to 1 + if len(images) == 2: + num_rows = 1 + + # Calculate the size of the combined image + width = max(widths) * num_cols + height = max(heights) * num_rows + + # Create a transparent image with the same size as the images + transparent = Image.new('RGBA', (max(widths), max(heights))) + + # Create a new image with the calculated size + new_im = Image.new('RGBA', (width, height)) + + # Paste the images and transparent segments into the grid + x_offset = y_offset = 0 + for im in images: + new_im.paste(im, (x_offset, y_offset)) + x_offset += im.size[0] + if x_offset >= width: + x_offset = 0 + y_offset += im.size[1] + + # Fill the remaining cells with transparent segments + while y_offset < height: + while x_offset < width: + new_im.paste(transparent, (x_offset, y_offset)) + x_offset += transparent.size[0] + x_offset = 0 + y_offset += transparent.size[1] + + + # Save the new_im to a temporary file and return it as a discord.File + temp_file = tempfile.NamedTemporaryFile(suffix=".png") + new_im.save(temp_file.name) + + # Print the filesize of new_im, in mega bytes + image_size = os.path.getsize(temp_file.name) / 1000000 + + # If the image size is greater than 8MB, we can't return this to the user, so we will need to downscale the + # image and try again + safety_counter = 0 + while image_size > 8 or safety_counter >= 2: + safety_counter += 1 + print(f"Image size is {image_size}MB, which is too large for discord. Downscaling and trying again") + new_im = new_im.resize((int(new_im.width / 1.05), int(new_im.height / 1.05))) + temp_file = tempfile.NamedTemporaryFile(suffix=".png") + new_im.save(temp_file.name) + image_size = os.path.getsize(temp_file.name) / 1000000 + print(f"New image size is {image_size}MB") + + return (discord.File(temp_file.name), image_urls) diff --git a/models/usage_service_model.py b/models/usage_service_model.py index 36e23e5..b6763b3 100644 --- a/models/usage_service_model.py +++ b/models/usage_service_model.py @@ -31,3 +31,23 @@ class UsageService: def count_tokens(self, input): res = self.tokenizer(input)["input_ids"] return len(res) + + def update_usage_image(self, image_size): + # 1024×1024 $0.020 / image + # 512×512 $0.018 / image + # 256×256 $0.016 / image + + if image_size == "1024x1024": + price = 0.02 + elif image_size == "512x512": + price = 0.018 + elif image_size == "256x256": + price = 0.016 + else: + raise ValueError("Invalid image size") + + usage = self.get_usage() + + with open("usage.txt", "w") as f: + f.write(str(usage + float(price))) + f.close() diff --git a/requirements.txt b/requirements.txt index d183583..999ed7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ py-cord==2.3.2 openai==0.25.0 +Pillow==9.3.0 python-dotenv==0.21.0 +requests==2.28.1 transformers==4.25.1 \ No newline at end of file