From 3ddb1056c348b6d97c7191a9eac807535b3fdedb Mon Sep 17 00:00:00 2001 From: Rene Teigen Date: Sat, 31 Dec 2022 14:13:05 +0000 Subject: [PATCH] Change way that roles are checked and some formatting --- cogs/draw_image_generation.py | 14 ++++---- cogs/gpt_3_commands_and_converser.py | 54 ++++++++++++++++------------ cogs/image_prompt_optimizer.py | 5 ++- gpt3discord.py | 10 ++++++ models/check_model.py | 22 ++++++++++++ 5 files changed, 72 insertions(+), 33 deletions(-) create mode 100644 models/check_model.py diff --git a/cogs/draw_image_generation.py b/cogs/draw_image_generation.py index 7d39c62..f23a206 100644 --- a/cogs/draw_image_generation.py +++ b/cogs/draw_image_generation.py @@ -12,6 +12,7 @@ from discord.ext import commands # We don't use the converser cog here because we want to be able to redo for the last images and text prompts at the same time from models.env_service_model import EnvService from models.user_model import RedoUser +from models.check_model import Check redo_users = {} users_to_interactions = {} @@ -131,7 +132,10 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"): ) @discord.slash_command( - name="draw", description="Draw an image from a prompt", guild_ids=ALLOWED_GUILDS + name="draw", + description="Draw an image from a prompt", + guild_ids=ALLOWED_GUILDS, + checks=[Check.check_valid_roles()], ) @discord.option(name="prompt", description="The prompt to draw from", required=True) async def draw(self, ctx: discord.ApplicationContext, prompt: str): @@ -142,10 +146,6 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"): if user == self.bot.user: return - # Only allow the bot to be used by people who have the role "Admin" or "GPT" - if not await self.converser_cog.check_valid_roles(ctx.user, ctx): - return - try: asyncio.ensure_future(self.encapsulated_send(user.id, prompt, ctx)) @@ -183,14 +183,12 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"): name="clear-local", description="Clear the local dalleimages folder on system.", guild_ids=ALLOWED_GUILDS, + checks=[Check.check_valid_roles()], ) @discord.guild_only() async def clear_local(self, ctx): await ctx.defer() - if not await self.converser_cog.check_valid_roles(ctx.user, ctx): - return - # Delete all the local images in the images folder. image_path = self.model.IMAGE_SAVE_PATH for dirpath, dirnames, filenames in os.walk(image_path): diff --git a/cogs/gpt_3_commands_and_converser.py b/cogs/gpt_3_commands_and_converser.py index af1e40f..05d5248 100644 --- a/cogs/gpt_3_commands_and_converser.py +++ b/cogs/gpt_3_commands_and_converser.py @@ -11,6 +11,7 @@ from models.deletion_service_model import Deletion from models.env_service_model import EnvService from models.message_model import Message from models.user_model import User, RedoUser +from models.check_model import Check from collections import defaultdict @@ -37,7 +38,6 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): self._last_member_ = None self.conversating_users = {} self.DAVINCI_ROLES = ["admin", "Admin", "GPT", "gpt"] - self.ALLOWED_ROLES = EnvService.get_allowed_roles() self.END_PROMPTS = [ "end", "end conversation", @@ -83,12 +83,6 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): self.message_queue = message_queue self.conversation_threads = {} - async def check_valid_roles(self, user, ctx): - if not any(role.name in self.ALLOWED_ROLES for role in user.roles): - await ctx.respond("You don't have permission to use this.") - return False - return True - @commands.Cog.listener() async def on_member_remove(self, member): pass @@ -101,7 +95,9 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): print(f"The debug channel was acquired") @discord.slash_command( - name="set-usage", description="Set the current OpenAI usage (in dollars)" + name="set-usage", + description="Set the current OpenAI usage (in dollars)", + checks=[Check.check_valid_roles()], ) @discord.option( name="usage_amount", @@ -111,9 +107,6 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): async def set_usage(self, ctx, usage_amount: float): await ctx.defer() - if not await self.check_valid_roles(ctx.user, ctx): - return - # Attempt to convert the input usage value into a float try: usage = float(usage_amount) @@ -126,12 +119,10 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): @discord.slash_command( name="delete-conversation-threads", description="Delete all conversation threads across the bot servers.", + checks=[Check.check_valid_roles()], ) async def delete_all_conversation_threads(self, ctx): await ctx.defer() - # If the user has ADMIN_ROLES - if not await self.check_valid_roles(ctx.user, ctx): - return for guild in self.bot.guilds: for thread in guild.threads: @@ -626,7 +617,10 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): return @discord.slash_command( - name="g", description="Ask GPT3 something!", guild_ids=ALLOWED_GUILDS + name="g", + description="Ask GPT3 something!", + guild_ids=ALLOWED_GUILDS, + checks=[Check.check_valid_roles()], ) @discord.option( name="prompt", description="The prompt to send to GPT3", required=True @@ -638,9 +632,6 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): user = ctx.user prompt = prompt.strip() - if not await self.check_valid_roles(user, ctx): - return - # CONVERSE Checks here TODO # Send the request to the model # If conversing, the prompt to send is the history, otherwise, it's just the prompt @@ -657,14 +648,12 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): name="chat-gpt", description="Have a conversation with GPT3", guild_ids=ALLOWED_GUILDS, + checks=[Check.check_valid_roles()], ) @discord.guild_only() async def chat_gpt(self, ctx: discord.ApplicationContext): await ctx.defer() - if not await self.check_valid_roles(ctx.user, ctx): - return - user = ctx.user if user.id in self.conversating_users: @@ -730,7 +719,28 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): guild_ids=ALLOWED_GUILDS, ) @discord.option( - name="parameter", description="The setting to change", required=False + name="parameter", + description="The setting to change", + required=False, + choices=[ + "mode", + "temp", + "top_p", + "max_tokens", + "presence_penalty", + "frequency_penalty", + "best_of", + "prompt_min_length", + "max_conversation_length", + "model", + "low_usage_mode", + "image_size", + "num_images", + "summarize_conversations", + "summarize_threshold", + "IMAGE_SAVE_PATH", + "openai_key", + ], ) @discord.option( name="value", description="The value to set the setting to", required=False diff --git a/cogs/image_prompt_optimizer.py b/cogs/image_prompt_optimizer.py index 144b417..88d2e92 100644 --- a/cogs/image_prompt_optimizer.py +++ b/cogs/image_prompt_optimizer.py @@ -6,6 +6,7 @@ from discord.ext import commands from models.env_service_model import EnvService from models.user_model import RedoUser +from models.check_model import Check ALLOWED_GUILDS = EnvService.get_allowed_guilds() @@ -50,6 +51,7 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"): name="imgoptimize", description="Optimize a text prompt for DALL-E/MJ/SD image generation.", guild_ids=ALLOWED_GUILDS, + checks=[Check.check_valid_roles()], ) @discord.option( name="prompt", description="The text prompt to optimize.", required=True @@ -58,9 +60,6 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"): async def imgoptimize(self, ctx: discord.ApplicationContext, prompt: str): await ctx.defer() - if not await self.converser_cog.check_valid_roles(ctx.user, ctx): - return - user = ctx.user final_prompt = self.OPTIMIZER_PRETEXT diff --git a/gpt3discord.py b/gpt3discord.py index 0d4c2e6..8ebf9f7 100644 --- a/gpt3discord.py +++ b/gpt3discord.py @@ -46,6 +46,16 @@ async def on_ready(): # I can make self optional by print("We have logged in as {0.user}".format(bot)) +@bot.event +async def on_application_command_error( + ctx: discord.ApplicationContext, error: discord.DiscordException +): + if isinstance(error, discord.CheckFailure): + pass + else: + raise error + + async def main(): data_path = Path(os.environ.get("DATA_DIR", os.getcwd())) debug_guild = int(os.getenv("DEBUG_GUILD")) diff --git a/models/check_model.py b/models/check_model.py new file mode 100644 index 0000000..c184404 --- /dev/null +++ b/models/check_model.py @@ -0,0 +1,22 @@ +import discord +from models.env_service_model import EnvService +from typing import Callable + + +ALLOWED_ROLES = EnvService.get_allowed_roles() + + +class Check: + def check_valid_roles() -> Callable: + async def inner(ctx: discord.ApplicationContext): + if not any(role.name in ALLOWED_ROLES for role in ctx.user.roles): + await ctx.defer(ephemeral=True) + await ctx.send_followup( + "You don't have permission to use this.", + ephemeral=True, + delete_after=10, + ) + return False + return True + + return inner