import asyncio
import os
import tempfile
import traceback
from io import BytesIO

import discord
from PIL import Image
from pycord.multicog import add_to_group


# We don't use the converser cog here because we want to be able to redo for the last images and text prompts at the same time
from sqlitedict import SqliteDict

from cogs.gpt_3_commands_and_converser import GPT3ComCon
from models.env_service_model import EnvService
from models.user_model import RedoUser

# Most recent generation per Discord user id. Values are RedoUser objects
# written by DrawDallEService.encapsulated_send and read by the Retry/Vary
# buttons further down in this module.
redo_users = {}
# Module-level mapping; the classes visible in this file use the converser
# cog's own `users_to_interactions` attribute rather than this one.
users_to_interactions = {}
# Guild ids passed as `guild_ids` to every slash command registered below.
ALLOWED_GUILDS = EnvService.get_allowed_guilds()

# When per-user API keys are enabled, open the sqlite-backed key store;
# otherwise leave the handle unset.
USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys()
USER_KEY_DB = SqliteDict("user_key_db.sqlite") if USER_INPUT_API_KEYS else None


class DrawDallEService(discord.Cog, name="DrawDallEService"):
    """Cog providing image generation and local image-folder maintenance.

    Slash commands registered here:
      * ``dalle draw``        - generate an image from a text prompt.
      * ``system local-size`` - report the size of the local image folder.
      * ``system clear-local``- delete every file in the local image folder.
    """

    def __init__(
        self, bot, usage_service, model, message_queue, deletion_queue, converser_cog
    ):
        """Store the collaborators handed over by the bot's setup code.

        Args:
            bot: The running bot; ``bot.user`` is compared against invokers.
            usage_service: Stored on the instance; not used in this class.
            model: Object exposing ``send_image_request`` and ``IMAGE_SAVE_PATH``.
            message_queue: Stored on the instance; not used in this class.
            deletion_queue: Stored on the instance; not used in this class.
            converser_cog: Cog owning the ``users_to_interactions`` mapping
                that tracks which message ids belong to which user.
        """
        super().__init__()
        self.bot = bot
        self.usage_service = usage_service
        self.model = model
        self.message_queue = message_queue
        self.deletion_queue = deletion_queue
        self.converser_cog = converser_cog
        print("Draw service initialized")

    async def encapsulated_send(
        self,
        user_id,
        prompt,
        ctx,
        response_message=None,
        vary=None,
        draw_from_optimizer=None,
        custom_api_key=None,
    ):
        """Request an image from the model and publish it with action buttons.

        Three flows are handled, selected by the arguments:
          * ``response_message`` is falsy: a fresh generation; a new message
            is sent and remembered as the user's redo target.
          * ``response_message`` set and ``vary`` falsy: a retry; the existing
            message is edited in place.
          * ``response_message`` set and ``vary`` truthy: a variation (or an
            optimizer drawing when ``draw_from_optimizer`` is truthy); the
            original interaction response is edited.

        Args:
            user_id: Discord id of the user the result belongs to.
            prompt: Text prompt; also used as the embed description.
            ctx: Either a ``discord.ApplicationContext`` (answered through
                ``ctx.respond``) or any object with ``.channel.send``.
            response_message: Existing message/interaction to update, if any.
            vary: Forwarded to the model as the variation source unless
                ``draw_from_optimizer`` is truthy.
            draw_from_optimizer: Truthy when the request comes from the
                prompt optimizer flow.
            custom_api_key: Optional per-user API key forwarded to the model.

        A ``ValueError`` raised by the model is reported to the user and
        swallowed; any other exception propagates to the caller.
        """
        await asyncio.sleep(0)
        # send the prompt to the model
        from_context = isinstance(ctx, discord.ApplicationContext)

        try:
            file, image_urls = await self.model.send_image_request(
                ctx,
                prompt,
                vary=vary if not draw_from_optimizer else None,
                custom_api_key=custom_api_key,
            )
        except ValueError as e:
            error_text = f"Error: {e}. Please try again with a different prompt."
            if from_context:
                await ctx.respond(error_text)
            else:
                await ctx.channel.send(error_text)
            return

        # Start building an embed to send to the user with the results of the image generation
        if not vary:
            title = "Image Generation Results"
        elif not draw_from_optimizer:
            title = "Image Generation Results (Varying)"
        else:
            title = "Image Generation Results (Drawing from Optimizer)"
        embed = discord.Embed(
            title=title,
            description=f"{prompt}",
            color=0xC730C7,
        )

        # Add the image file to the embed
        embed.set_image(url=f"attachment://{file.filename}")

        if not response_message:  # Original generation case
            if from_context:
                result_message = await ctx.respond(embed=embed, file=file)
            else:
                result_message = await ctx.channel.send(embed=embed, file=file)

            # The view needs the message it is attached to, so it is added in
            # a second step once the message exists.
            await result_message.edit(
                view=SaveView(
                    ctx,
                    image_urls,
                    self,
                    self.converser_cog,
                    result_message,
                    custom_api_key=custom_api_key,
                )
            )

            # A fresh generation resets the list of messages the user owns.
            self.converser_cog.users_to_interactions[user_id] = [result_message.id]

            # Get the actual result message object
            if from_context:
                result_message = await ctx.fetch_message(result_message.id)

            redo_users[user_id] = RedoUser(prompt, ctx, ctx, result_message)
            return

        if not vary:  # Editing case
            message = await response_message.edit(
                embed=embed,
                file=file,
            )
            await message.edit(
                view=SaveView(
                    ctx,
                    image_urls,
                    self,
                    self.converser_cog,
                    message,
                    custom_api_key=custom_api_key,
                )
            )
            return

        # Varying case (plain variation, or a drawing of an optimized prompt)
        if draw_from_optimizer:
            content = "I've drawn the optimized prompt!"
        else:
            content = "Image variation completed!"
        result_message = await response_message.edit_original_response(
            content=content,
            embed=embed,
            file=file,
        )
        await result_message.edit(
            view=SaveView(
                ctx,
                image_urls,
                self,
                self.converser_cog,
                result_message,
                # Plain variations get no Retry button; optimizer drawings do.
                no_retry=not draw_from_optimizer,
                custom_api_key=custom_api_key,
            )
        )

        if draw_from_optimizer:
            redo_users[user_id] = RedoUser(prompt, ctx, ctx, result_message)

        owned_ids = self.converser_cog.users_to_interactions[user_id]
        owned_ids.append(response_message.id)
        owned_ids.append(result_message.id)

    @add_to_group("dalle")
    @discord.slash_command(
        name="draw",
        description="Draw an image from a prompt",
        guild_ids=ALLOWED_GUILDS,
    )
    @discord.option(name="prompt", description="The prompt to draw from", required=True)
    async def draw(self, ctx: discord.ApplicationContext, prompt: str):
        """Slash command: generate an image for ``prompt`` and post it.

        When per-user API keys are enabled and none is available for the
        invoking user, the command returns without drawing.
        """
        user_api_key = None
        if USER_INPUT_API_KEYS:
            user_api_key = await GPT3ComCon.get_user_api_key(ctx.user.id, ctx)
            if not user_api_key:
                return

        await ctx.defer()

        user = ctx.user

        if user == self.bot.user:
            return

        try:
            # Awaited directly (rather than scheduled as a background task) so
            # that failures during generation reach the handler below instead
            # of being lost inside an unobserved task.
            await self.encapsulated_send(
                user.id, prompt, ctx, custom_api_key=user_api_key
            )
        except Exception as e:
            print(e)
            traceback.print_exc()
            await ctx.respond("Something went wrong. Please try again later.")
            details = str(e)
            if details:  # an empty message cannot be sent
                await ctx.send_followup(details)

    @add_to_group("system")
    @discord.slash_command(
        name="local-size",
        description="Get the size of the dall-e images folder that we have on the current system",
        guild_ids=ALLOWED_GUILDS,
    )
    @discord.guild_only()
    async def local_size(self, ctx: discord.ApplicationContext):
        """Slash command: report the total size of the local image folder in MB."""
        await ctx.defer()
        # Get the size of the dall-e images folder that we have on the current system.

        image_path = self.model.IMAGE_SAVE_PATH
        total_size = 0
        for dirpath, _dirnames, filenames in os.walk(image_path):
            for filename in filenames:
                try:
                    total_size += os.path.getsize(os.path.join(dirpath, filename))
                except OSError:
                    # The entry disappeared while walking (e.g. a concurrent
                    # clear-local) or is a broken symlink; leave it out.
                    continue

        # Format the size to be in MB and send.
        total_size = total_size / 1000000
        await ctx.respond(f"The size of the local images folder is {total_size} MB.")

    @add_to_group("system")
    @discord.slash_command(
        name="clear-local",
        description="Clear the local dalleimages folder on system.",
        guild_ids=ALLOWED_GUILDS,
    )
    @discord.guild_only()
    async def clear_local(self, ctx):
        """Slash command: delete every file under the local image folder.

        Deletion is best-effort: a file that cannot be removed is logged and
        skipped. Directories themselves are left in place.
        """
        await ctx.defer()

        # Delete all the local images in the images folder.
        image_path = self.model.IMAGE_SAVE_PATH
        for dirpath, _dirnames, filenames in os.walk(image_path):
            for filename in filenames:
                try:
                    os.remove(os.path.join(dirpath, filename))
                except OSError as e:
                    print(e)

        await ctx.respond("Local images cleared.")


class SaveView(discord.ui.View):
    """Buttons shown under a generated image: Save N, optional Retry, Vary N.

    Args:
        ctx: Context (or message) the generation was started from.
        image_urls: One entry per generated image; each gets a Save button
            and, unless ``only_save`` is truthy, a Vary button.
        cog: The draw cog, handed to the Retry/Vary buttons.
        converser_cog: Cog owning ``users_to_interactions``, handed to the
            Retry/Vary buttons.
        message: The message this view is attached to.
        no_retry: When truthy, the Retry button is omitted.
        only_save: When truthy, only Save buttons are added and the view
            never times out.
        custom_api_key: Optional per-user API key handed to Retry/Vary.
    """

    def __init__(
        self,
        ctx,
        image_urls,
        cog,
        converser_cog,
        message,
        no_retry=False,
        only_save=None,
        custom_api_key=None,
    ):
        super().__init__(
            timeout=3600 if not only_save else None
        )  # 1 hour timeout for Retry, Save
        self.ctx = ctx
        self.image_urls = image_urls
        self.cog = cog
        self.no_retry = no_retry
        self.converser_cog = converser_cog
        self.message = message
        self.custom_api_key = custom_api_key

        # Buttons are numbered from 1 to match what users see.
        for number, image_url in enumerate(image_urls, start=1):
            self.add_item(SaveButton(number, image_url))

        if only_save:
            return

        if not no_retry:
            self.add_item(
                RedoButton(
                    self.cog,
                    converser_cog=self.converser_cog,
                    custom_api_key=self.custom_api_key,
                )
            )
        for number, image_url in enumerate(image_urls, start=1):
            self.add_item(
                VaryButton(
                    number,
                    image_url,
                    self.cog,
                    converser_cog=self.converser_cog,
                    custom_api_key=self.custom_api_key,
                )
            )

    # On the timeout event, override it and we want to clear the items.
    async def on_timeout(self):
        """Replace this view with a Save-only view once it times out."""
        self.clear_items()

        # Create a new view with the same params as this one, but pass only_save=True
        new_view = SaveView(
            self.ctx,
            self.image_urls,
            self.cog,
            self.converser_cog,
            self.message,
            self.no_retry,
            only_save=True,
            custom_api_key=self.custom_api_key,
        )

        # Edit the message the view is attached to. `ctx` is not always that
        # message: for variations it is the message whose Vary button was
        # clicked, so editing it would rewrite the wrong message's buttons.
        # Fall back to `ctx` only when no message was supplied.
        # NOTE(review): if `message` is backed by an interaction token, this
        # edit presumably requires the token to still be valid - confirm.
        target = self.message if self.message is not None else self.ctx
        await target.edit(view=new_view)


class VaryButton(discord.ui.Button):
    """Blurple "Vary N" button that requests a variation of one image."""

    def __init__(self, number, image_url, cog, converser_cog, custom_api_key):
        super().__init__(
            style=discord.ButtonStyle.blurple,
            label=f"Vary {number!s}",
        )
        self.number = number
        self.image_url = image_url
        self.cog = cog
        self.converser_cog = converser_cog
        self.custom_api_key = custom_api_key

    async def callback(self, interaction: discord.Interaction):
        """Handle a click: check ownership, acknowledge, then start a variation.

        A click on a message not recorded for the clicking user is rejected
        with an ephemeral notice. When the user has no entry in
        ``redo_users`` nothing further happens.
        """
        clicker_id = interaction.user.id
        clicked_message_id = interaction.message.id

        owned_ids = self.converser_cog.users_to_interactions[clicker_id]
        if clicked_message_id not in owned_ids:
            if len(owned_ids) < 2:
                await interaction.response.send_message(
                    content="You can only vary for images that you generated yourself!",
                    ephemeral=True,
                )
            elif interaction.id not in owned_ids:
                await interaction.response.send_message(
                    content="You can not vary images in someone else's chain!",
                    ephemeral=True,
                )
            return

        if user_has_no_history := clicker_id not in redo_users:
            return

        acknowledgement = await interaction.response.send_message(
            content=f"Varying image number {self.number!s}..."
        )

        # Record both ids so follow-up clicks in this chain pass the
        # ownership check above.
        tracked_ids = self.converser_cog.users_to_interactions[clicker_id]
        tracked_ids.append(acknowledgement.message.id)
        tracked_ids.append(acknowledgement.id)

        last_prompt = redo_users[clicker_id].prompt

        # Generation runs in the background; the acknowledgement is edited
        # with the result once it is ready.
        variation = self.cog.encapsulated_send(
            clicker_id,
            last_prompt,
            interaction.message,
            response_message=acknowledgement,
            vary=self.image_url,
            custom_api_key=self.custom_api_key,
        )
        asyncio.ensure_future(variation)


class SaveButton(discord.ui.Button["SaveView"]):
    """Gray "Save N" button that hands one generated image to the user."""

    def __init__(self, number: int, image_url: str):
        super().__init__(style=discord.ButtonStyle.gray, label="Save " + str(number))
        self.number = number
        self.image_url = image_url

    async def callback(self, interaction: discord.Interaction):
        """Send the image (or its URL) to the clicking user, ephemerally.

        If ``image_url`` starts with ``"http"`` the link itself is sent.
        Otherwise it is treated as a local file path: the image is re-saved
        as a temporary PNG, uploaded as an attachment, and the temporary file
        is removed afterwards. Any failure is printed and reported to the
        user as an ephemeral error message.
        """
        try:
            if self.image_url.startswith("http"):
                await interaction.response.send_message(
                    f"You can directly download this image from {self.image_url}",
                    ephemeral=True,
                )
                return

            # Read everything into memory so the source file can be closed
            # before the upload starts.
            with open(self.image_url, "rb") as f:
                raw_bytes = f.read()

            # Only the path is needed; close the handle straight away so the
            # descriptor is not leaked and the path can be reopened for writing.
            temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
            temp_file.close()
            try:
                with Image.open(BytesIO(raw_bytes)) as image:
                    image.save(temp_file.name)

                await interaction.response.send_message(
                    content="Here is your image for download (open original and save)",
                    file=discord.File(temp_file.name),
                    ephemeral=True,
                )
            finally:
                # delete=False means nothing else cleans this file up.
                try:
                    os.remove(temp_file.name)
                except OSError as cleanup_error:
                    print(cleanup_error)
        except Exception as e:
            # Log first so the traceback survives even if notifying fails.
            traceback.print_exc()
            error_text = f"Error: {e}"
            # An interaction can only be answered once; if the failure came
            # after the initial response, report through a follow-up instead.
            if interaction.response.is_done():
                await interaction.followup.send(error_text, ephemeral=True)
            else:
                await interaction.response.send_message(error_text, ephemeral=True)


class RedoButton(discord.ui.Button["SaveView"]):
    """Red "Retry" button that regenerates the user's last prompt in place."""

    def __init__(self, cog, converser_cog, custom_api_key):
        super().__init__(label="Retry", style=discord.ButtonStyle.danger)
        self.cog = cog
        self.converser_cog = converser_cog
        self.custom_api_key = custom_api_key

    async def callback(self, interaction: discord.Interaction):
        """Handle a click: check ownership, acknowledge, then regenerate.

        A click on a message not recorded for the clicking user is rejected
        with an ephemeral notice. When the user has no entry in
        ``redo_users`` nothing further happens.
        """
        requester_id = interaction.user.id
        clicked_message_id = interaction.message.id

        owned_ids = self.converser_cog.users_to_interactions[requester_id]
        if clicked_message_id not in owned_ids:
            await interaction.response.send_message(
                content="You can only retry for prompts that you generated yourself!",
                ephemeral=True,
            )
            return

        # We have passed the initial check of if the interaction belongs to the user
        if requester_id not in redo_users:
            return

        # Pull everything needed to replay the request from the stored record.
        previous = redo_users[requester_id]
        original_ctx = previous.ctx
        original_prompt = previous.prompt
        target_message = previous.response

        acknowledgement = await interaction.response.send_message(
            "Regenerating the image for your original prompt, check the original message.",
            ephemeral=True,
        )
        self.converser_cog.users_to_interactions[requester_id].append(
            acknowledgement.id
        )

        # Regeneration runs in the background and edits the original message.
        regeneration = self.cog.encapsulated_send(
            requester_id,
            original_prompt,
            original_ctx,
            target_message,
            custom_api_key=self.custom_api_key,
        )
        asyncio.ensure_future(regeneration)