|
|
|
import re
|
|
|
|
import traceback
|
|
|
|
|
|
|
|
import discord
|
|
|
|
|
|
|
|
from models.env_service_model import EnvService
|
|
|
|
from models.user_model import RedoUser
|
|
|
|
from pycord.multicog import add_to_group
|
|
|
|
|
|
|
|
# Guild IDs that this module's slash commands are registered to (passed as
# guild_ids= to the /dalle optimize command). Read once at import time.
ALLOWED_GUILDS = EnvService.get_allowed_guilds()
|
|
|
|
|
|
|
|
|
|
|
|
class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"):
    """Cog providing ``/dalle optimize``, which rewrites a text prompt for image generation.

    The user's prompt is prefixed with an optimizer pretext and sent to the
    language model. The model's answer is posted back with an ``OptimizeView``
    (Retry / Draw buttons) attached.
    """

    # Built-in pretext, used when image_optimizer_pretext.txt cannot be read.
    _OPTIMIZER_PRETEXT = "Optimize the following text for DALL-E image generation to have the most detailed and realistic image possible. Prompt:"

    # Discord user / role / channel mention syntax, plus the @everyone and
    # @here mass pings. Model output matching any of these is never echoed.
    _MENTION_PATTERN = re.compile(r"<@!?\d+>|<@&\d+>|<#\d+>|@everyone|@here")

    # Labels the model may prefix its answer with; removed (in this order)
    # before the answer is posted.
    _RESPONSE_LABELS = ("Optimized Prompt:", "Output Prompt:", "Output:")

    def __init__(
        self,
        bot,
        usage_service,
        model,
        message_queue,
        deletion_queue,
        converser_cog,
        image_service_cog,
    ):
        super().__init__()
        self.bot = bot
        self.usage_service = usage_service
        self.model = model
        self.message_queue = message_queue
        self.converser_cog = converser_cog
        self.image_service_cog = image_service_cog
        self.deletion_queue = deletion_queue

        # Start from the built-in pretext; replaced below if the data
        # directory provides a custom one.
        self.OPTIMIZER_PRETEXT = self._OPTIMIZER_PRETEXT

        try:
            # Try to read the image optimizer pretext from the file system.
            image_pretext_path = (
                self.converser_cog.data_path / "image_optimizer_pretext.txt"
            )
            with image_pretext_path.open("r", encoding="utf-8") as file:
                self.OPTIMIZER_PRETEXT = file.read()
            print(f"Loaded image optimizer pretext from {image_pretext_path}")
        except Exception:
            # Best effort: log why the file was unusable and keep the default.
            traceback.print_exc()
            self.OPTIMIZER_PRETEXT = self._OPTIMIZER_PRETEXT

    @add_to_group("dalle")
    @discord.slash_command(
        name="optimize",
        description="Optimize a text prompt for DALL-E/MJ/SD image generation.",
        guild_ids=ALLOWED_GUILDS,
    )
    @discord.option(
        name="prompt", description="The text prompt to optimize.", required=True
    )
    @discord.guild_only()
    async def optimize(self, ctx: discord.ApplicationContext, prompt: str):
        """Send ``prompt`` behind the optimizer pretext to the model and post the result.

        Args:
            ctx: The slash-command context. It is deferred immediately, so all
                replies below go out via ``ctx.respond``.
            prompt: The user's raw image prompt.

        On success the posted message is recorded in the converser cog's
        ``users_to_interactions`` and ``redo_users`` tables so the attached
        Retry/Draw buttons can verify ownership. A ``ValueError`` is reported
        to the user verbatim; any other error gets a generic apology followed
        by the error text, and its traceback is printed.
        """
        await ctx.defer()

        user = ctx.user

        final_prompt = self.OPTIMIZER_PRETEXT + prompt
        # If the prompt doesn't end in a period, terminate it.
        if not final_prompt.endswith("."):
            final_prompt += "."

        try:
            # NOTE: this costs more than a normal request. best_of_override=2
            # makes the query run twice and keep the best completion, trading
            # some extra tokens for more analysis by the model.
            response = await self.model.send_request(
                final_prompt,
                tokens=70,
                top_p_override=1.0,
                temp_override=0.9,
                presence_penalty_override=0.5,
                best_of_override=2,
                max_tokens_override=80,
            )
            response_text = response["choices"][0]["text"]

            # Never echo model output that would ping users, roles, channels,
            # or the whole server.
            if self._MENTION_PATTERN.search(response_text):
                await ctx.respond(
                    "I'm sorry, I can't mention users, roles, or channels."
                )
                return

            for label in self._RESPONSE_LABELS:
                response_text = response_text.replace(label, "")
            response_message = await ctx.respond(response_text)

            # Reset this user's tracked messages to just the new prompt message.
            self.converser_cog.users_to_interactions[user.id] = [
                response_message.id
            ]

            self.converser_cog.redo_users[user.id] = RedoUser(
                final_prompt, ctx, ctx, response_message
            )
            self.converser_cog.redo_users[user.id].add_interaction(response_message.id)

            await response_message.edit(
                view=OptimizeView(
                    self.converser_cog, self.image_service_cog, self.deletion_queue
                )
            )

        # Catch the value errors raised by the Model object.
        except ValueError as e:
            await ctx.respond(str(e))
            return

        # Catch all other errors; the bot should keep running if this fails.
        except Exception as e:
            await ctx.respond("Something went wrong, please try again later")
            await ctx.send_followup(str(e))
            traceback.print_exc()
            return
|
|
|
|
|
|
|
|
|
|
|
|
class OptimizeView(discord.ui.View):
    """Button row attached to an optimized prompt: a Retry button followed by a Draw button."""

    def __init__(self, converser_cog, image_service_cog, deletion_queue):
        # timeout=None: the view has no inactivity timeout.
        super().__init__(timeout=None)
        self.cog = converser_cog
        self.image_service_cog = image_service_cog
        self.deletion_queue = deletion_queue

        # Both buttons receive the same collaborators; order defines the layout.
        button_args = (self.cog, self.image_service_cog, self.deletion_queue)
        for button_type in (RedoButton, DrawButton):
            self.add_item(button_type(*button_args))
|
|
|
|
|
|
|
|
|
|
|
|
class DrawButton(discord.ui.Button["OptimizeView"]):
    """Green "Draw" button that sends the optimized prompt to the image service cog."""

    def __init__(self, converser_cog, image_service_cog, deletion_queue):
        super().__init__(style=discord.ButtonStyle.green, label="Draw")
        self.converser_cog = converser_cog
        self.image_service_cog = image_service_cog
        self.deletion_queue = deletion_queue

    def _is_owner(self, user_id, message_id):
        """Return True if ``message_id`` is tracked as a prompt message of ``user_id``.

        A user with no entry in either tracking table (e.g. someone who never
        ran the optimizer) is treated as a non-owner instead of raising KeyError.
        """
        cog = self.converser_cog
        return (
            user_id in cog.users_to_interactions
            and message_id in cog.users_to_interactions[user_id]
            and user_id in cog.redo_users
            and message_id in cog.redo_users[user_id].interactions
        )

    async def callback(self, interaction: discord.Interaction):
        """Draw the prompt in the clicked message, if the clicker owns that prompt."""
        user_id = interaction.user.id
        interaction_id = interaction.message.id

        if not self._is_owner(user_id, interaction_id):
            await interaction.response.send_message(
                content="You can only draw for prompts that you generated yourself!",
                ephemeral=True,
            )
            return

        msg = await interaction.response.send_message(
            "Drawing this prompt...", ephemeral=False
        )
        # Track the status reply, this interaction, and the prompt message
        # as belonging to the user.
        tracked = self.converser_cog.users_to_interactions[user_id]
        tracked.append(msg.id)
        tracked.append(interaction.id)
        tracked.append(interaction.message.id)

        # Get the text content of the message that was interacted with and
        # strip a leftover "Optimized Prompt:" label (plus one space) if present.
        prompt = interaction.message.content
        prompt = re.sub(r"Optimized Prompt: ?", "", prompt)

        # Call the image service cog to draw the image. The two trailing
        # positional True flags are passed through unchanged.
        await self.image_service_cog.encapsulated_send(
            user_id,
            prompt,
            interaction,
            msg,
            True,
            True,
        )
|
|
|
|
|
|
|
|
|
|
|
|
class RedoButton(discord.ui.Button["OptimizeView"]):
    """Red "Retry" button that re-sends the user's stored optimize prompt."""

    def __init__(self, converser_cog, image_service_cog, deletion_queue):
        super().__init__(style=discord.ButtonStyle.danger, label="Retry")
        self.converser_cog = converser_cog
        self.image_service_cog = image_service_cog
        self.deletion_queue = deletion_queue

    async def callback(self, interaction: discord.Interaction):
        """Re-run the stored request if the clicked message belongs to the clicker."""
        interaction_id = interaction.message.id

        # Get the user
        user_id = interaction.user.id
        redo_users = self.converser_cog.redo_users

        if user_id in redo_users and redo_users[user_id].in_interaction(
            interaction_id
        ):
            # Capture the stored request before awaiting anything, so a
            # concurrent /optimize by the same user cannot swap it mid-redo.
            redo_user = redo_users[user_id]
            ctx = redo_user.ctx
            prompt = redo_user.prompt
            response_message = redo_user.response

            await interaction.response.send_message(
                "Redoing your original request...", ephemeral=True, delete_after=20
            )
            await self.converser_cog.encapsulated_send(
                user_id=user_id,
                prompt=prompt,
                ctx=ctx,
                response_message=response_message,
            )
        else:
            await interaction.response.send_message(
                content="You can only redo for prompts that you generated yourself!",
                ephemeral=True,
                delete_after=10,
            )
|