You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
GPT3Discord/cogs/image_prompt_optimizer.py

58 lines
1.9 KiB

import os
import re
import traceback
from discord.ext import commands
class ImgPromptOptimizer(commands.Cog, name='ImgPromptOptimizer'):
    """Discord cog providing the ``imgoptimize`` command.

    The command prepends an instruction pretext to the user's text, sends it
    to the configured model, and replies with the model's rewritten prompt
    (intended for DALL-E image generation).
    """

    # Built-in instruction prepended to every request; used when the
    # OPTIMIZER_PRETEXT environment variable is unset or empty.
    _OPTIMIZER_PRETEXT = "Optimize the following text for DALL-E image generation to have the most detailed and realistic image possible. Prompt:"

    # Discord user (<@id> / <@!id>), role (<@&id>) and channel (<#id>) mention
    # tokens, plus the @everyone / @here broadcast pings. Model output that
    # contains any of these is refused so user-steered text cannot ping people.
    _MENTION_PATTERN = re.compile(r"<@!?\d+>|<@&\d+>|<#\d+>|@(?:everyone|here)")

    def __init__(self, bot, usage_service, model, message_queue):
        """Store collaborators and resolve the optimizer pretext.

        Args:
            bot: The bot instance this cog belongs to.
            usage_service: Stored on the instance; not used by the code shown here.
            model: Object whose ``send_request(prompt, message)`` returns a
                dict readable as ``response["choices"][0]["text"]``.
            message_queue: Stored on the instance; not used by the code shown here.
        """
        self.bot = bot
        self.usage_service = usage_service
        self.model = model
        self.message_queue = message_queue
        # os.getenv returns None (it does not raise) for a missing variable,
        # so a try/except fallback never triggers. Fall back explicitly when
        # the variable is unset or set to an empty string.
        self.OPTIMIZER_PRETEXT = (
            os.getenv('OPTIMIZER_PRETEXT') or self._OPTIMIZER_PRETEXT
        )

    @commands.command()
    async def imgoptimize(self, ctx, *args):
        """Ask the model to rewrite the given words into a richer image prompt.

        Args:
            ctx: The invocation context; replies are sent through it.
            *args: The words following the command, joined with single spaces.

        Replies with the model's text, or with an explanatory message when the
        input is empty, the output contains a mention, or the request fails.
        """
        if not args:
            await ctx.reply("Please provide a prompt to optimize.")
            return

        # Pretext followed by the user's words, each separated by one space.
        prompt = " ".join([self.OPTIMIZER_PRETEXT, *args])
        print(f"Received an image optimization request for the following prompt: {prompt}")

        try:
            # NOTE(review): if send_request is a blocking call it stalls the
            # event loop while waiting; consider an executor — confirm.
            response = self.model.send_request(prompt, ctx.message)
            response_text = response["choices"][0]["text"]
            # response is a dict (indexed above), which has no __dict__.
            print(f"Received the following response: {response}")

            if self._MENTION_PATTERN.search(response_text):
                await ctx.reply("I'm sorry, I can't mention users, roles, or channels.")
                return

            await ctx.reply(response_text)
        # Value errors from the model are treated as user-facing messages
        # (per the original code's comment, the Model object raises them).
        except ValueError as e:
            # Discord rejects empty messages, so never send an empty string.
            await ctx.reply(str(e) or "Invalid request.")
        # Catch all other errors so the bot keeps running.
        except Exception as e:
            await ctx.reply("Something went wrong, please try again later")
            await ctx.channel.send(str(e) or type(e).__name__)
            traceback.print_exc()