You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
GPT3Discord/cogs/image_prompt_optimizer.py

275 lines
10 KiB

import re
import traceback
import discord
from sqlitedict import SqliteDict
from cogs.gpt_3_commands_and_converser import GPT3ComCon
from models.env_service_model import EnvService
from models.user_model import RedoUser
from pycord.multicog import add_to_group
# Environment-driven configuration, resolved once at import time.
ALLOWED_GUILDS = EnvService.get_allowed_guilds()
USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys()

# The per-user key store is only opened when user-supplied API keys are
# enabled; otherwise it stays None.
USER_KEY_DB = SqliteDict("user_key_db.sqlite") if USER_INPUT_API_KEYS else None
class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"):
    """Cog providing the ``optimize`` slash command in the ``dalle`` group.

    The command prepends an "optimizer pretext" to the user's prompt, sends
    the result to the model, and replies with the model's rewritten prompt
    together with an ``OptimizeView`` (Retry / Draw buttons).
    """

    # Fallback pretext, used when no pretext file can be loaded.
    _OPTIMIZER_PRETEXT = "Optimize the following text for DALL-E image generation to have the most detailed and realistic image possible. Prompt:"

    # Labels removed from the model output before it is shown to the user.
    # Applied in this order.
    _RESPONSE_LABELS = ("Optimized Prompt:", "Output Prompt:", "Output:")

    def __init__(
        self,
        bot,
        usage_service,
        model,
        message_queue,
        deletion_queue,
        converser_cog,
        image_service_cog,
    ):
        """Store collaborators and load the optimizer pretext.

        The pretext is read from ``image_optimizer_pretext.txt`` (located via
        ``EnvService.find_shared_file``). If locating or reading the file
        fails, the traceback is printed and the built-in default is used.
        """
        super().__init__()
        self.bot = bot
        self.usage_service = usage_service
        self.model = model
        self.message_queue = message_queue
        self.OPTIMIZER_PRETEXT = self._OPTIMIZER_PRETEXT
        self.converser_cog = converser_cog
        self.image_service_cog = image_service_cog
        self.deletion_queue = deletion_queue

        try:
            image_pretext_path = EnvService.find_shared_file(
                "image_optimizer_pretext.txt"
            )
            # Try to read the image optimizer pretext from
            # the file system
            with image_pretext_path.open("r") as file:
                self.OPTIMIZER_PRETEXT = file.read()
            print(f"Loaded image optimizer pretext from {image_pretext_path}")
        except Exception:
            # Best effort: report the problem and fall back to the default.
            # ``Exception`` (not a bare except) so KeyboardInterrupt and
            # SystemExit still propagate.
            traceback.print_exc()
            self.OPTIMIZER_PRETEXT = self._OPTIMIZER_PRETEXT

    def _clean_response_text(self, response_text):
        """Shorten an over-long model response and strip known labels.

        Responses longer than 75 whitespace-separated words are reduced to
        their last 70 words (re-joined with single spaces); afterwards every
        label in ``_RESPONSE_LABELS`` is removed.
        """
        # TODO Temporary workaround until prompt is adjusted to make the optimized prompts shorter.
        words = response_text.split()
        if len(words) > 75:
            response_text = " ".join(words[-70:])

        for label in self._RESPONSE_LABELS:
            response_text = response_text.replace(label, "")
        return response_text

    @add_to_group("dalle")
    @discord.slash_command(
        name="optimize",
        description="Optimize a text prompt for DALL-E/MJ/SD image generation.",
        guild_ids=ALLOWED_GUILDS,
    )
    @discord.option(
        name="prompt", description="The text prompt to optimize.", required=True
    )
    @discord.guild_only()
    async def optimize(self, ctx: discord.ApplicationContext, prompt: str):
        """Optimize ``prompt`` with the model and reply with the result.

        When user-supplied API keys are enabled and no key is returned for
        the invoking user, the command returns without responding here.
        Errors during the request or reply are reported back to the user.
        """
        user_api_key = None
        if USER_INPUT_API_KEYS:
            user_api_key = await GPT3ComCon.get_user_api_key(ctx.user.id, ctx)
            if not user_api_key:
                return

        await ctx.defer()

        user = ctx.user

        final_prompt = self.OPTIMIZER_PRETEXT
        # replace mentions with nicknames for the prompt
        final_prompt += await self.converser_cog.mention_to_username(ctx, prompt)

        # If the prompt doesn't end in a period, terminate it.
        if not final_prompt.endswith("."):
            final_prompt += "."

        # Get the token amount for the prompt
        # NOTE(review): the result is not used below (the request passes a
        # fixed tokens=70). Kept because count_tokens is defined elsewhere
        # and may have side effects -- confirm and remove if it is pure.
        tokens = self.usage_service.count_tokens(final_prompt)

        try:
            # THIS USES MORE TOKENS THAN A NORMAL REQUEST! This will use roughly 4000 tokens, and will repeat the query
            # twice because of the best_of_override=2 parameter. This is to ensure that the model does a lot of analysis, but is
            # also relatively cost-effective
            response = await self.model.send_request(
                final_prompt,
                tokens=70,
                top_p_override=1.0,
                temp_override=0.9,
                presence_penalty_override=0.5,
                best_of_override=2,
                max_tokens_override=80,
                custom_api_key=user_api_key,
            )

            response_text = response["choices"][0]["text"]
            # escape any mentions
            response_text = discord.utils.escape_mentions(response_text)

            response_message = await ctx.respond(
                self._clean_response_text(response_text)
            )

            # Record which message belongs to this user so the Retry / Draw
            # buttons can verify ownership later.
            self.converser_cog.users_to_interactions[user.id] = [response_message.id]

            self.converser_cog.redo_users[user.id] = RedoUser(
                prompt=final_prompt,
                message=ctx,
                ctx=ctx,
                response=response_message,
                instruction=None,
                codex=False,
            )
            self.converser_cog.redo_users[user.id].add_interaction(response_message.id)

            await response_message.edit(
                view=OptimizeView(
                    self.converser_cog,
                    self.image_service_cog,
                    self.deletion_queue,
                    custom_api_key=user_api_key,
                )
            )

        # Catch the value errors raised by the Model object
        except ValueError as e:
            # Message content must be non-empty text; fall back to repr()
            # for exceptions that carry no message.
            await ctx.respond(str(e) or repr(e))
            return

        # Catch all other errors, we want this to keep going if it errors out.
        except Exception as e:
            # Print the stack trace first so it is not lost if reporting the
            # failure to Discord raises as well.
            traceback.print_exc()
            await ctx.respond("Something went wrong, please try again later")
            await ctx.send_followup(str(e) or repr(e))
            return
class OptimizeView(discord.ui.View):
    """View attached to an optimized prompt: a Retry button and a Draw button."""

    def __init__(
        self, converser_cog, image_service_cog, deletion_queue, custom_api_key=None
    ):
        """Keep references to the collaborators and add both buttons.

        The view is created with ``timeout=None``.
        """
        super().__init__(timeout=None)
        self.cog = converser_cog
        self.image_service_cog = image_service_cog
        self.deletion_queue = deletion_queue
        self.custom_api_key = custom_api_key

        # Both buttons take the same constructor arguments; Retry is added
        # before Draw.
        for button_type in (RedoButton, DrawButton):
            button = button_type(
                self.cog,
                self.image_service_cog,
                self.deletion_queue,
                custom_api_key=self.custom_api_key,
            )
            self.add_item(button)
class DrawButton(discord.ui.Button["OptimizeView"]):
    """Green "Draw" button that sends the optimized prompt to the image service."""

    def __init__(
        self, converser_cog, image_service_cog, deletion_queue, custom_api_key=None
    ):
        """Store the collaborators needed by :meth:`callback`.

        ``custom_api_key`` defaults to ``None``, matching ``RedoButton`` and
        ``OptimizeView``.
        """
        super().__init__(style=discord.ButtonStyle.green, label="Draw")
        self.converser_cog = converser_cog
        self.image_service_cog = image_service_cog
        self.deletion_queue = deletion_queue
        self.custom_api_key = custom_api_key

    async def callback(self, interaction: discord.Interaction):
        """Draw the prompt in the clicked message, if the clicker owns it.

        The click is accepted only when the message id is recorded for the
        clicking user in both ``users_to_interactions`` and that user's
        ``redo_users`` entry; otherwise an ephemeral refusal is sent.
        """
        user_id = interaction.user.id
        interaction_id = interaction.message.id

        # Look the user up without indexing, so a user with no recorded
        # interactions gets the refusal message instead of a KeyError.
        user_interactions = self.converser_cog.users_to_interactions.get(user_id, ())
        redo_user = self.converser_cog.redo_users.get(user_id)

        if (
            interaction_id not in user_interactions
            or redo_user is None
            or interaction_id not in redo_user.interactions
        ):
            await interaction.response.send_message(
                content="You can only draw for prompts that you generated yourself!",
                ephemeral=True,
            )
            return

        msg = await interaction.response.send_message(
            "Drawing this prompt...", ephemeral=False
        )

        # Track the ids created by / involved in this click for the user.
        self.converser_cog.users_to_interactions[user_id].extend(
            (msg.id, interaction.id, interaction.message.id)
        )

        # get the text content of the message that was interacted with
        prompt = interaction.message.content

        # Strip a leading-style "Optimized Prompt:" label (with an optional
        # trailing space) wherever it occurs, so only the prompt text is drawn.
        prompt = re.sub(r"Optimized Prompt: ?", "", prompt)

        # Call the image service cog to draw the image.
        # NOTE(review): the two positional True flags are defined by
        # image_service_cog.encapsulated_send, which is not visible here.
        await self.image_service_cog.encapsulated_send(
            user_id,
            prompt,
            interaction,
            msg,
            True,
            True,
            custom_api_key=self.custom_api_key,
        )
class RedoButton(discord.ui.Button["OptimizeView"]):
    """Red "Retry" button that re-sends the user's original optimize request."""

    def __init__(
        self, converser_cog, image_service_cog, deletion_queue, custom_api_key=None
    ):
        """Store the collaborators needed by :meth:`callback`."""
        super().__init__(style=discord.ButtonStyle.danger, label="Retry")
        self.converser_cog = converser_cog
        self.image_service_cog = image_service_cog
        self.deletion_queue = deletion_queue
        self.custom_api_key = custom_api_key

    async def callback(self, interaction: discord.Interaction):
        """Redo the stored request if the clicked message belongs to the clicker.

        A click is accepted only when the user has a ``redo_users`` entry
        whose ``in_interaction`` check passes for the clicked message id;
        otherwise an ephemeral refusal is sent.
        """
        interaction_id = interaction.message.id
        user_id = interaction.user.id
        redo_users = self.converser_cog.redo_users

        # Guard clause: refuse clicks on prompts the user did not generate.
        if user_id not in redo_users or not redo_users[user_id].in_interaction(
            interaction_id
        ):
            await interaction.response.send_message(
                content="You can only redo for prompts that you generated yourself!",
                ephemeral=True,
                delete_after=10,
            )
            return

        # Read everything needed from the stored request before awaiting.
        redo_user = redo_users[user_id]
        ctx = redo_user.ctx
        prompt = redo_user.prompt
        response_message = redo_user.response

        await interaction.response.send_message(
            "Redoing your original request...", ephemeral=True, delete_after=20
        )

        await self.converser_cog.encapsulated_send(
            id=user_id,
            prompt=prompt,
            ctx=ctx,
            response_message=response_message,
            custom_api_key=self.custom_api_key,
        )