You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

401 lines
14 KiB

import asyncio
import random
import tempfile
import traceback
from io import BytesIO
import aiohttp
import discord
from PIL import Image
from models.user_model import RedoUser
class ImageService:
def __init__(self):
pass
@staticmethod
async def encapsulated_send(
image_service_cog,
user_id,
prompt,
ctx,
response_message=None,
vary=None,
draw_from_optimizer=None,
custom_api_key=None,
):
"""service function that takes input and returns an image generation
Args:
image_service_cog (Cog): The cog which contains draw related commands
user_id (int): A discord user id
prompt (string): Prompt for the model
ctx (ApplicationContext): A discord ApplicationContext, from an interaction
response_message (Message, optional): A discord message. Defaults to None.
vary (bool, optional): If the image is a variation of another one. Defaults to None.
draw_from_optimizer (bool, optional): If the prompt is passed from the optimizer command. Defaults to None.
custom_api_key (str, optional): User defined OpenAI API key. Defaults to None.
"""
await asyncio.sleep(0)
# send the prompt to the model
from_context = isinstance(ctx, discord.ApplicationContext)
try:
file, image_urls = await image_service_cog.model.send_image_request(
ctx,
prompt,
vary=vary if not draw_from_optimizer else None,
custom_api_key=custom_api_key,
)
# Error catching for API errors
except aiohttp.ClientResponseError as e:
message = (
f"The API returned an invalid response: **{e.status}: {e.message}**"
)
if not from_context:
await ctx.channel.send(message)
else:
await ctx.respond(message, ephemeral=True)
return
except ValueError as e:
message = f"Error: {e}. Please try again with a different prompt."
if not from_context:
await ctx.channel.send(message)
else:
await ctx.respond(message, ephemeral=True)
return
# Start building an embed to send to the user with the results of the image generation
embed = discord.Embed(
title="Image Generation Results"
if not vary
else "Image Generation Results (Varying)"
if not draw_from_optimizer
else "Image Generation Results (Drawing from Optimizer)",
description=f"{prompt}",
color=0xC730C7,
)
# Add the image file to the embed
embed.set_image(url=f"attachment://{file.filename}")
if not response_message: # Original generation case
# Start an interaction with the user, we also want to send data embed=embed, file=file,
# view=SaveView(image_urls, image_service_cog, image_service_cog.converser_cog)
result_message = (
await ctx.channel.send(
embed=embed,
file=file,
)
if not from_context
else await ctx.respond(embed=embed, file=file)
)
await result_message.edit(
view=SaveView(
ctx,
image_urls,
image_service_cog,
image_service_cog.converser_cog,
result_message,
custom_api_key=custom_api_key,
)
)
image_service_cog.converser_cog.users_to_interactions[user_id] = []
image_service_cog.converser_cog.users_to_interactions[user_id].append(
result_message.id
)
# Get the actual result message object
if from_context:
result_message = await ctx.fetch_message(result_message.id)
image_service_cog.redo_users[user_id] = RedoUser(
prompt=prompt,
message=ctx,
ctx=ctx,
response=result_message,
instruction=None,
paginator=None,
)
else:
if not vary: # Editing case
message = await response_message.edit(
embed=embed,
file=file,
)
await message.edit(
view=SaveView(
ctx,
image_urls,
image_service_cog,
image_service_cog.converser_cog,
message,
custom_api_key=custom_api_key,
)
)
else: # Varying case
if not draw_from_optimizer:
result_message = await response_message.edit_original_response(
content="Image variation completed!",
embed=embed,
file=file,
)
await result_message.edit(
view=SaveView(
ctx,
image_urls,
image_service_cog,
image_service_cog.converser_cog,
result_message,
True,
custom_api_key=custom_api_key,
)
)
else:
result_message = await response_message.edit_original_response(
content="I've drawn the optimized prompt!",
embed=embed,
file=file,
)
await result_message.edit(
view=SaveView(
ctx,
image_urls,
image_service_cog,
image_service_cog.converser_cog,
result_message,
custom_api_key=custom_api_key,
)
)
image_service_cog.redo_users[user_id] = RedoUser(
prompt=prompt,
message=ctx,
ctx=ctx,
response=result_message,
instruction=None,
paginator=None,
)
image_service_cog.converser_cog.users_to_interactions[user_id].append(
response_message.id
)
image_service_cog.converser_cog.users_to_interactions[user_id].append(
result_message.id
)
class SaveView(discord.ui.View):
def __init__(
self,
ctx,
image_urls,
cog,
converser_cog,
message,
no_retry=False,
only_save=None,
custom_api_key=None,
):
super().__init__(
timeout=3600 if not only_save else None
) # 1 hour timeout for Retry, Save
self.ctx = ctx
self.image_urls = image_urls
self.cog = cog
self.no_retry = no_retry
self.converser_cog = converser_cog
self.message = message
self.custom_api_key = custom_api_key
for x in range(1, len(image_urls) + 1):
self.add_item(SaveButton(x, image_urls[x - 1]))
if not only_save:
if not no_retry:
self.add_item(
RedoButton(
self.cog,
converser_cog=self.converser_cog,
custom_api_key=self.custom_api_key,
)
)
for x in range(1, len(image_urls) + 1):
self.add_item(
VaryButton(
x,
image_urls[x - 1],
self.cog,
converser_cog=self.converser_cog,
custom_api_key=self.custom_api_key,
)
)
# On the timeout event, override it and we want to clear the items.
async def on_timeout(self):
# Save all the SaveButton items, then clear all the items, then add back the SaveButton items, then
# update the message
self.clear_items()
# Create a new view with the same params as this one, but pass only_save=True
new_view = SaveView(
self.ctx,
self.image_urls,
self.cog,
self.converser_cog,
self.message,
self.no_retry,
only_save=True,
)
# Set the view of the message to the new view
await self.ctx.edit(view=new_view)
class VaryButton(discord.ui.Button):
def __init__(self, number, image_url, cog, converser_cog, custom_api_key):
super().__init__(
style=discord.ButtonStyle.blurple,
label="Vary " + str(number),
custom_id="vary_button" + str(random.randint(10000000, 99999999)),
)
self.number = number
self.image_url = image_url
self.cog = cog
self.converser_cog = converser_cog
self.custom_api_key = custom_api_key
async def callback(self, interaction: discord.Interaction):
user_id = interaction.user.id
interaction_id = interaction.message.id
if interaction_id not in self.converser_cog.users_to_interactions[user_id]:
if len(self.converser_cog.users_to_interactions[user_id]) >= 2:
interaction_id2 = interaction.id
if (
interaction_id2
not in self.converser_cog.users_to_interactions[user_id]
):
await interaction.response.send_message(
content="You can not vary images in someone else's chain!",
ephemeral=True,
)
else:
await interaction.response.send_message(
content="You can only vary for images that you generated yourself!",
ephemeral=True,
)
return
if user_id in self.cog.redo_users:
response_message = await interaction.response.send_message(
content="Varying image number " + str(self.number) + "..."
)
self.converser_cog.users_to_interactions[user_id].append(
response_message.message.id
)
self.converser_cog.users_to_interactions[user_id].append(
response_message.id
)
prompt = self.cog.redo_users[user_id].prompt
asyncio.ensure_future(
ImageService.encapsulated_send(
self.cog,
user_id,
prompt,
interaction.message,
response_message=response_message,
vary=self.image_url,
custom_api_key=self.custom_api_key,
)
)
class SaveButton(discord.ui.Button["SaveView"]):
def __init__(self, number: int, image_url: str):
super().__init__(
style=discord.ButtonStyle.gray,
label="Save " + str(number),
custom_id="save_button" + str(random.randint(1000000, 9999999)),
)
self.number = number
self.image_url = image_url
async def callback(self, interaction: discord.Interaction):
# If the image url doesn't start with "http", then we need to read the file from the URI, and then send the
# file to the user as an attachment.
try:
if not self.image_url.startswith("http"):
with open(self.image_url, "rb") as f:
image = Image.open(BytesIO(f.read()))
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
image.save(temp_file.name)
await interaction.response.send_message(
content="Here is your image for download (open original and save)",
file=discord.File(temp_file.name),
ephemeral=True,
)
else:
await interaction.response.send_message(
f"You can directly download this image from {self.image_url}",
ephemeral=True,
)
except Exception as e:
await interaction.response.send_message(f"Error: {e}", ephemeral=True)
traceback.print_exc()
class RedoButton(discord.ui.Button["SaveView"]):
def __init__(self, cog, converser_cog, custom_api_key):
super().__init__(
style=discord.ButtonStyle.danger,
label="Retry",
custom_id="redo_button_draw_main",
)
self.cog = cog
self.converser_cog = converser_cog
self.custom_api_key = custom_api_key
async def callback(self, interaction: discord.Interaction):
user_id = interaction.user.id
interaction_id = interaction.message.id
if interaction_id not in self.converser_cog.users_to_interactions[user_id]:
await interaction.response.send_message(
content="You can only retry for prompts that you generated yourself!",
ephemeral=True,
)
return
# We have passed the intial check of if the interaction belongs to the user
if user_id in self.cog.redo_users:
# Get the message and the prompt and call encapsulated_send
ctx = self.cog.redo_users[user_id].ctx
prompt = self.cog.redo_users[user_id].prompt
response_message = self.cog.redo_users[user_id].response
message = await interaction.response.send_message(
"Regenerating the image for your original prompt, check the original message.",
ephemeral=True,
)
self.converser_cog.users_to_interactions[user_id].append(message.id)
asyncio.ensure_future(
ImageService.encapsulated_send(
self.cog,
user_id,
prompt,
ctx,
response_message,
custom_api_key=self.custom_api_key,
)
)