Format Python code with psf/black push

github-actions 2 years ago
parent c6ccfd94dc
commit b47d52f217

@ -56,7 +56,10 @@ class DrawDallEService(discord.Cog, name="DrawDallEService"):
try: try:
file, image_urls = await self.model.send_image_request( file, image_urls = await self.model.send_image_request(
ctx, prompt, vary=vary if not draw_from_optimizer else None, custom_api_key=custom_api_key ctx,
prompt,
vary=vary if not draw_from_optimizer else None,
custom_api_key=custom_api_key,
) )
except ValueError as e: except ValueError as e:
( (
@ -96,7 +99,14 @@ class DrawDallEService(discord.Cog, name="DrawDallEService"):
) )
await result_message.edit( await result_message.edit(
view=SaveView(ctx, image_urls, self, self.converser_cog, result_message, custom_api_key=custom_api_key) view=SaveView(
ctx,
image_urls,
self,
self.converser_cog,
result_message,
custom_api_key=custom_api_key,
)
) )
self.converser_cog.users_to_interactions[user_id] = [] self.converser_cog.users_to_interactions[user_id] = []
@ -115,7 +125,14 @@ class DrawDallEService(discord.Cog, name="DrawDallEService"):
file=file, file=file,
) )
await message.edit( await message.edit(
view=SaveView(ctx, image_urls, self, self.converser_cog, message, custom_api_key=custom_api_key) view=SaveView(
ctx,
image_urls,
self,
self.converser_cog,
message,
custom_api_key=custom_api_key,
)
) )
else: # Varying case else: # Varying case
if not draw_from_optimizer: if not draw_from_optimizer:
@ -144,7 +161,12 @@ class DrawDallEService(discord.Cog, name="DrawDallEService"):
) )
await result_message.edit( await result_message.edit(
view=SaveView( view=SaveView(
ctx, image_urls, self, self.converser_cog, result_message, custom_api_key=custom_api_key ctx,
image_urls,
self,
self.converser_cog,
result_message,
custom_api_key=custom_api_key,
) )
) )
@ -179,7 +201,11 @@ class DrawDallEService(discord.Cog, name="DrawDallEService"):
return return
try: try:
asyncio.ensure_future(self.encapsulated_send(user.id, prompt, ctx, custom_api_key=user_api_key)) asyncio.ensure_future(
self.encapsulated_send(
user.id, prompt, ctx, custom_api_key=user_api_key
)
)
except Exception as e: except Exception as e:
print(e) print(e)
@ -258,11 +284,21 @@ class SaveView(discord.ui.View):
self.add_item(SaveButton(x, image_urls[x - 1])) self.add_item(SaveButton(x, image_urls[x - 1]))
if not only_save: if not only_save:
if not no_retry: if not no_retry:
self.add_item(RedoButton(self.cog, converser_cog=self.converser_cog, custom_api_key=self.custom_api_key)) self.add_item(
RedoButton(
self.cog,
converser_cog=self.converser_cog,
custom_api_key=self.custom_api_key,
)
)
for x in range(1, len(image_urls) + 1): for x in range(1, len(image_urls) + 1):
self.add_item( self.add_item(
VaryButton( VaryButton(
x, image_urls[x - 1], self.cog, converser_cog=self.converser_cog, custom_api_key=self.custom_api_key x,
image_urls[x - 1],
self.cog,
converser_cog=self.converser_cog,
custom_api_key=self.custom_api_key,
) )
) )
@ -404,5 +440,11 @@ class RedoButton(discord.ui.Button["SaveView"]):
self.converser_cog.users_to_interactions[user_id].append(message.id) self.converser_cog.users_to_interactions[user_id].append(message.id)
asyncio.ensure_future( asyncio.ensure_future(
self.cog.encapsulated_send(user_id, prompt, ctx, response_message, custom_api_key=self.custom_api_key) self.cog.encapsulated_send(
user_id,
prompt,
ctx,
response_message,
custom_api_key=self.custom_api_key,
)
) )

@ -32,12 +32,13 @@ else:
USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys() USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys()
USER_KEY_DB = None USER_KEY_DB = None
if USER_INPUT_API_KEYS: if USER_INPUT_API_KEYS:
print("This server was configured to enforce user input API keys. Doing the required database setup now") print(
"This server was configured to enforce user input API keys. Doing the required database setup now"
)
USER_KEY_DB = SqliteDict("user_key_db.sqlite") USER_KEY_DB = SqliteDict("user_key_db.sqlite")
print("Retrieved/created the user key database") print("Retrieved/created the user key database")
class GPT3ComCon(discord.Cog, name="GPT3ComCon"): class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
def __init__( def __init__(
self, self,
@ -148,9 +149,13 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
modal = SetupModal(title="API Key Setup") modal = SetupModal(title="API Key Setup")
if isinstance(ctx, discord.ApplicationContext): if isinstance(ctx, discord.ApplicationContext):
await ctx.send_modal(modal) await ctx.send_modal(modal)
await ctx.send_followup("You must set up your API key before using this command.") await ctx.send_followup(
"You must set up your API key before using this command."
)
else: else:
await ctx.reply("You must set up your API key before typing in a GPT3 powered channel, type `/setup` to enter your API key.") await ctx.reply(
"You must set up your API key before typing in a GPT3 powered channel, type `/setup` to enter your API key."
)
return user_api_key return user_api_key
async def load_file(self, file, ctx): async def load_file(self, file, ctx):
@ -199,7 +204,9 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
self.DEBUG_CHANNEL self.DEBUG_CHANNEL
) )
if USER_INPUT_API_KEYS: if USER_INPUT_API_KEYS:
print("This bot was set to use user input API keys. Doing the required SQLite setup now") print(
"This bot was set to use user input API keys. Doing the required SQLite setup now"
)
await self.bot.sync_commands( await self.bot.sync_commands(
commands=None, commands=None,
@ -644,7 +651,9 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
# Extract all the text after the !g and use it as the prompt. # Extract all the text after the !g and use it as the prompt.
user_api_key = None user_api_key = None
if USER_INPUT_API_KEYS: if USER_INPUT_API_KEYS:
user_api_key = await GPT3ComCon.get_user_api_key(message.author.id, message) user_api_key = await GPT3ComCon.get_user_api_key(
message.author.id, message
)
if not user_api_key: if not user_api_key:
return return
@ -790,7 +799,11 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
# Create and upsert the embedding for the conversation id, prompt, timestamp # Create and upsert the embedding for the conversation id, prompt, timestamp
embedding = await self.pinecone_service.upsert_conversation_embedding( embedding = await self.pinecone_service.upsert_conversation_embedding(
self.model, conversation_id, new_prompt, timestamp, custom_api_key=custom_api_key, self.model,
conversation_id,
new_prompt,
timestamp,
custom_api_key=custom_api_key,
) )
embedding_prompt_less_author = await self.model.send_embedding_request( embedding_prompt_less_author = await self.model.send_embedding_request(
@ -953,7 +966,11 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
# Create and upsert the embedding for the conversation id, prompt, timestamp # Create and upsert the embedding for the conversation id, prompt, timestamp
embedding = await self.pinecone_service.upsert_conversation_embedding( embedding = await self.pinecone_service.upsert_conversation_embedding(
self.model, conversation_id, response_text, timestamp, custom_api_key=custom_api_key self.model,
conversation_id,
response_text,
timestamp,
custom_api_key=custom_api_key,
) )
# Cleanse # Cleanse
@ -967,12 +984,16 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
response_message = ( response_message = (
await ctx.respond( await ctx.respond(
response_text, response_text,
view=ConversationView(ctx, self, ctx.channel.id, custom_api_key=custom_api_key), view=ConversationView(
ctx, self, ctx.channel.id, custom_api_key=custom_api_key
),
) )
if from_context if from_context
else await ctx.reply( else await ctx.reply(
response_text, response_text,
view=ConversationView(ctx, self, ctx.channel.id, custom_api_key=custom_api_key), view=ConversationView(
ctx, self, ctx.channel.id, custom_api_key=custom_api_key
),
) )
) )
@ -1368,12 +1389,18 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
await self.send_help_text(ctx) await self.send_help_text(ctx)
@discord.slash_command( @discord.slash_command(
name="setup", description="Setup your API key for use with GPT3Discord", guild_ids=ALLOWED_GUILDS name="setup",
description="Setup your API key for use with GPT3Discord",
guild_ids=ALLOWED_GUILDS,
) )
@discord.guild_only() @discord.guild_only()
async def setup(self, ctx: discord.ApplicationContext): async def setup(self, ctx: discord.ApplicationContext):
if not USER_INPUT_API_KEYS: if not USER_INPUT_API_KEYS:
await ctx.respond("This server doesn't support user input API keys.", ephemeral=True, delete_after=30) await ctx.respond(
"This server doesn't support user input API keys.",
ephemeral=True,
delete_after=30,
)
modal = SetupModal(title="API Key Setup") modal = SetupModal(title="API Key Setup")
await ctx.send_modal(modal) await ctx.send_modal(modal)
@ -1437,8 +1464,10 @@ class ConversationView(discord.ui.View):
super().__init__(timeout=3600) # 1 hour interval to redo. super().__init__(timeout=3600) # 1 hour interval to redo.
self.converser_cog = converser_cog self.converser_cog = converser_cog
self.ctx = ctx self.ctx = ctx
self.custom_api_key= custom_api_key self.custom_api_key = custom_api_key
self.add_item(RedoButton(self.converser_cog, custom_api_key=self.custom_api_key)) self.add_item(
RedoButton(self.converser_cog, custom_api_key=self.custom_api_key)
)
if id in self.converser_cog.conversation_threads: if id in self.converser_cog.conversation_threads:
self.add_item(EndConvoButton(self.converser_cog)) self.add_item(EndConvoButton(self.converser_cog))
@ -1511,7 +1540,11 @@ class RedoButton(discord.ui.Button["ConversationView"]):
) )
await self.converser_cog.encapsulated_send( await self.converser_cog.encapsulated_send(
id=user_id, prompt=prompt, ctx=ctx, response_message=response_message, custom_api_key=self.custom_api_key id=user_id,
prompt=prompt,
ctx=ctx,
response_message=response_message,
custom_api_key=self.custom_api_key,
) )
else: else:
await interaction.response.send_message( await interaction.response.send_message(
@ -1520,37 +1553,63 @@ class RedoButton(discord.ui.Button["ConversationView"]):
delete_after=10, delete_after=10,
) )
class SetupModal(discord.ui.Modal): class SetupModal(discord.ui.Modal):
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.add_item(discord.ui.InputText(label="OpenAI API Key", placeholder="sk--......", )) self.add_item(
discord.ui.InputText(
label="OpenAI API Key",
placeholder="sk--......",
)
)
async def callback(self, interaction: discord.Interaction): async def callback(self, interaction: discord.Interaction):
user = interaction.user user = interaction.user
api_key = self.children[0].value api_key = self.children[0].value
# Validate that api_key is indeed in this format # Validate that api_key is indeed in this format
if not re.match(r"sk-[a-zA-Z0-9]{32}", api_key): if not re.match(r"sk-[a-zA-Z0-9]{32}", api_key):
await interaction.response.send_message("Your API key looks invalid, please check that it is correct before proceeding. Please run the /setup command to set your key.", ephemeral=True, delete_after=100) await interaction.response.send_message(
"Your API key looks invalid, please check that it is correct before proceeding. Please run the /setup command to set your key.",
ephemeral=True,
delete_after=100,
)
else: else:
# We can save the key for the user to the database. # We can save the key for the user to the database.
# Make a test request using the api key to ensure that it is valid. # Make a test request using the api key to ensure that it is valid.
try: try:
await Model.send_test_request(api_key) await Model.send_test_request(api_key)
await interaction.response.send_message("Your API key was successfully validated.", ephemeral=True, delete_after=10) await interaction.response.send_message(
"Your API key was successfully validated.",
ephemeral=True,
delete_after=10,
)
except Exception as e: except Exception as e:
await interaction.response.send_message(f"Your API key looks invalid, the API returned: {e}. Please check that your API key is correct before proceeding", ephemeral=True, delete_after=30) await interaction.response.send_message(
f"Your API key looks invalid, the API returned: {e}. Please check that your API key is correct before proceeding",
ephemeral=True,
delete_after=30,
)
return return
# Save the key to the database # Save the key to the database
try: try:
USER_KEY_DB[user.id] = api_key USER_KEY_DB[user.id] = api_key
USER_KEY_DB.commit() USER_KEY_DB.commit()
await interaction.followup.send("Your API key was successfully saved.", ephemeral=True, delete_after=10) await interaction.followup.send(
"Your API key was successfully saved.",
ephemeral=True,
delete_after=10,
)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
await interaction.followup.send("There was an error saving your API key.", ephemeral=True, delete_after=30) await interaction.followup.send(
"There was an error saving your API key.",
ephemeral=True,
delete_after=30,
)
return return
pass pass

@ -15,6 +15,7 @@ USER_KEY_DB = None
if USER_INPUT_API_KEYS: if USER_INPUT_API_KEYS:
USER_KEY_DB = SqliteDict("user_key_db.sqlite") USER_KEY_DB = SqliteDict("user_key_db.sqlite")
class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"): class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"):
_OPTIMIZER_PRETEXT = "Optimize the following text for DALL-E image generation to have the most detailed and realistic image possible. Prompt:" _OPTIMIZER_PRETEXT = "Optimize the following text for DALL-E image generation to have the most detailed and realistic image possible. Prompt:"
@ -123,7 +124,10 @@ class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"):
self.converser_cog.redo_users[user.id].add_interaction(response_message.id) self.converser_cog.redo_users[user.id].add_interaction(response_message.id)
await response_message.edit( await response_message.edit(
view=OptimizeView( view=OptimizeView(
self.converser_cog, self.image_service_cog, self.deletion_queue, custom_api_key=user_api_key, self.converser_cog,
self.image_service_cog,
self.deletion_queue,
custom_api_key=user_api_key,
) )
) )
@ -142,18 +146,36 @@ class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"):
class OptimizeView(discord.ui.View): class OptimizeView(discord.ui.View):
def __init__(self, converser_cog, image_service_cog, deletion_queue, custom_api_key=None): def __init__(
self, converser_cog, image_service_cog, deletion_queue, custom_api_key=None
):
super().__init__(timeout=None) super().__init__(timeout=None)
self.cog = converser_cog self.cog = converser_cog
self.image_service_cog = image_service_cog self.image_service_cog = image_service_cog
self.deletion_queue = deletion_queue self.deletion_queue = deletion_queue
self.custom_api_key = custom_api_key self.custom_api_key = custom_api_key
self.add_item(RedoButton(self.cog, self.image_service_cog, self.deletion_queue, custom_api_key=self.custom_api_key)) self.add_item(
self.add_item(DrawButton(self.cog, self.image_service_cog, self.deletion_queue, custom_api_key=self.custom_api_key)) RedoButton(
self.cog,
self.image_service_cog,
self.deletion_queue,
custom_api_key=self.custom_api_key,
)
)
self.add_item(
DrawButton(
self.cog,
self.image_service_cog,
self.deletion_queue,
custom_api_key=self.custom_api_key,
)
)
class DrawButton(discord.ui.Button["OptimizeView"]): class DrawButton(discord.ui.Button["OptimizeView"]):
def __init__(self, converser_cog, image_service_cog, deletion_queue, custom_api_key): def __init__(
self, converser_cog, image_service_cog, deletion_queue, custom_api_key
):
super().__init__(style=discord.ButtonStyle.green, label="Draw") super().__init__(style=discord.ButtonStyle.green, label="Draw")
self.converser_cog = converser_cog self.converser_cog = converser_cog
self.image_service_cog = image_service_cog self.image_service_cog = image_service_cog
@ -206,7 +228,9 @@ class DrawButton(discord.ui.Button["OptimizeView"]):
class RedoButton(discord.ui.Button["OptimizeView"]): class RedoButton(discord.ui.Button["OptimizeView"]):
def __init__(self, converser_cog, image_service_cog, deletion_queue, custom_api_key=None): def __init__(
self, converser_cog, image_service_cog, deletion_queue, custom_api_key=None
):
super().__init__(style=discord.ButtonStyle.danger, label="Retry") super().__init__(style=discord.ButtonStyle.danger, label="Retry")
self.converser_cog = converser_cog self.converser_cog = converser_cog
self.image_service_cog = image_service_cog self.image_service_cog = image_service_cog

@ -474,7 +474,9 @@ class Model:
else frequency_penalty_override, else frequency_penalty_override,
"best_of": self.best_of if not best_of_override else best_of_override, "best_of": self.best_of if not best_of_override else best_of_override,
} }
headers = {"Authorization": f"Bearer {self.openai_key if not custom_api_key else custom_api_key}"} headers = {
"Authorization": f"Bearer {self.openai_key if not custom_api_key else custom_api_key}"
}
async with session.post( async with session.post(
"https://api.openai.com/v1/completions", json=payload, headers=headers "https://api.openai.com/v1/completions", json=payload, headers=headers
) as resp: ) as resp:
@ -499,7 +501,7 @@ class Model:
} }
headers = {"Authorization": f"Bearer {api_key}"} headers = {"Authorization": f"Bearer {api_key}"}
async with session.post( async with session.post(
"https://api.openai.com/v1/completions", json=payload, headers=headers "https://api.openai.com/v1/completions", json=payload, headers=headers
) as resp: ) as resp:
response = await resp.json() response = await resp.json()
try: try:
@ -550,9 +552,9 @@ class Model:
async with session.post( async with session.post(
"https://api.openai.com/v1/images/variations", "https://api.openai.com/v1/images/variations",
headers={ headers={
"Authorization": f"Bearer {self.openai_key if not custom_api_key else custom_api_key}", "Authorization": f"Bearer {self.openai_key if not custom_api_key else custom_api_key}",
}, },
data=data, data=data,
) as resp: ) as resp:
response = await resp.json() response = await resp.json()

@ -26,7 +26,9 @@ class PineconeService:
print("The split chunk is ", chunk) print("The split chunk is ", chunk)
# Create an embedding for the split chunk # Create an embedding for the split chunk
embedding = await model.send_embedding_request(chunk, custom_api_key=custom_api_key) embedding = await model.send_embedding_request(
chunk, custom_api_key=custom_api_key
)
if not first_embedding: if not first_embedding:
first_embedding = embedding first_embedding = embedding
self.index.upsert( self.index.upsert(
@ -38,7 +40,9 @@ class PineconeService:
) )
return first_embedding return first_embedding
else: else:
embedding = await model.send_embedding_request(text, custom_api_key=custom_api_key) embedding = await model.send_embedding_request(
text, custom_api_key=custom_api_key
)
self.index.upsert( self.index.upsert(
[ [
( (

Loading…
Cancel
Save