From 1bd8ead17cb8d239a2292051c823821de3aaae38 Mon Sep 17 00:00:00 2001 From: Kaveen Kumarasinghe Date: Sat, 17 Dec 2022 00:33:09 -0500 Subject: [PATCH] Add a command to optimize image text prompts before putting into dall-e! --- cogs/gpt_3_commands_and_converser.py | 4 +- cogs/image_prompt_optimizer.py | 57 ++++++++++++++++++++++++++++ main.py | 6 ++- 3 files changed, 62 insertions(+), 5 deletions(-) create mode 100644 cogs/image_prompt_optimizer.py diff --git a/cogs/gpt_3_commands_and_converser.py b/cogs/gpt_3_commands_and_converser.py index ecc62ea..1a33b75 100644 --- a/cogs/gpt_3_commands_and_converser.py +++ b/cogs/gpt_3_commands_and_converser.py @@ -21,8 +21,6 @@ class RedoUser: redo_users = {} - - class GPT3ComCon(commands.Cog, name='GPT3ComCon'): def __init__(self, bot, usage_service, model, message_queue, DEBUG_GUILD, DEBUG_CHANNEL): @@ -52,7 +50,7 @@ class GPT3ComCon(commands.Cog, name='GPT3ComCon'): self.DEBUG_GUILD = DEBUG_GUILD self.DEBUG_CHANNEL = DEBUG_CHANNEL - print(f"The debuf channel and guild IDs are {self.DEBUG_GUILD} and {self.DEBUG_CHANNEL}") + print(f"The debug channel and guild IDs are {self.DEBUG_GUILD} and {self.DEBUG_CHANNEL}") self.TEXT_CUTOFF = 1900 self.message_queue = message_queue self.conversation_threads = {} diff --git a/cogs/image_prompt_optimizer.py b/cogs/image_prompt_optimizer.py new file mode 100644 index 0000000..94137a6 --- /dev/null +++ b/cogs/image_prompt_optimizer.py @@ -0,0 +1,57 @@ +import os +import re +import traceback + +from discord.ext import commands + + +class ImgPromptOptimizer(commands.Cog, name='ImgPromptOptimizer'): + + _OPTIMIZER_PRETEXT = "Optimize the following text for DALL-E image generation to have the most detailed and realistic image possible. Prompt:" + + def __init__(self, bot, usage_service, model, message_queue): + self.bot = bot + self.usage_service = usage_service + self.model = model + self.message_queue = message_queue + + try: + self.OPTIMIZER_PRETEXT = os.getenv('OPTIMIZER_PRETEXT') + except: + self.OPTIMIZER_PRETEXT = self._OPTIMIZER_PRETEXT + + @commands.command() + async def imgoptimize(self, ctx, *args): + + prompt = self.OPTIMIZER_PRETEXT + # Add everything except the command to the prompt + for arg in args: + prompt += arg + " " + + print(f"Received an image optimization request for the following prompt: {prompt}") + + try: + response = self.model.send_request(prompt, ctx.message) + response_text = response["choices"][0]["text"] + + print(f"Received the following response: {response.__dict__}") + + if re.search(r"<@!?\d+>|<@&\d+>|<#\d+>", response_text): + await ctx.reply("I'm sorry, I can't mention users, roles, or channels.") + return + + response_message = await ctx.reply(response_text) + + + # Catch the value errors raised by the Model object + except ValueError as e: + await ctx.reply(e) + return + + # Catch all other errors, we want this to keep going if it errors out. + except Exception as e: + await ctx.reply("Something went wrong, please try again later") + await ctx.channel.send(e) + # print a stack trace + traceback.print_exc() + return diff --git a/main.py b/main.py index 2b03bc9..af0de0e 100644 --- a/main.py +++ b/main.py @@ -5,6 +5,7 @@ from discord.ext import commands from dotenv import load_dotenv from cogs.gpt_3_commands_and_converser import GPT3ComCon +from cogs.image_prompt_optimizer import ImgPromptOptimizer from models.message_model import Message from models.openai_model import Model from models.usage_service_model import UsageService @@ -22,7 +23,7 @@ asyncio.ensure_future(Message.process_message_queue(message_queue, 1.5, 5)) """ Settings for the bot """ -bot = commands.Bot(intents=discord.Intents.all(), command_prefix="'") +bot = commands.Bot(intents=discord.Intents.all(), command_prefix="!") usage_service = UsageService() model = Model(usage_service) @@ -40,8 +41,9 @@ async def main(): debug_guild = int(os.getenv('DEBUG_GUILD')) debug_channel = int(os.getenv('DEBUG_CHANNEL')) - # Load te main GPT3 Bot service + # Load the main GPT3 Bot service bot.add_cog(GPT3ComCon(bot, usage_service, model, message_queue, debug_guild, debug_channel)) + bot.add_cog(ImgPromptOptimizer(bot, usage_service, model, message_queue)) await bot.start(os.getenv('DISCORD_TOKEN'))