From e5eeb035c1960c7e5de48611437828c03a2b8593 Mon Sep 17 00:00:00 2001 From: Rene Teigen Date: Sun, 1 Jan 2023 19:38:58 +0000 Subject: [PATCH] Implemented slash command grouping and more ways to configure who can use what. Refactored /g into /chat and /gpt-chat into /converse Both have the grouping "gpt" so the full is /gpt chat, and /gpt converse Added pycord-multicog to the project for easier slash grouping across cogs Made the role check case insensitive and added some comments to the sample.env Removed the rest of the discord.ext imports, it's not really used, ymmv Removed another admin check i had missed in local-size --- .gitignore | 1 + cogs/draw_image_generation.py | 13 +++--- cogs/gpt_3_commands_and_converser.py | 60 +++++++++++++++++-------- cogs/image_prompt_optimizer.py | 9 ++-- gpt3discord.py | 6 ++- models/check_model.py | 38 +++++++++++++--- models/env_service_model.py | 66 +++++++++++++++++++++++----- pyproject.toml | 1 + requirements.txt | 1 + sample.env | 12 +++-- 10 files changed, 153 insertions(+), 54 deletions(-) diff --git a/.gitignore b/.gitignore index e979e2b..4244f47 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ __pycache__ /models/__pycache__ #user files .env +.vscode bot.pid usage.txt /dalleimages \ No newline at end of file diff --git a/cogs/draw_image_generation.py b/cogs/draw_image_generation.py index f23a206..96e591d 100644 --- a/cogs/draw_image_generation.py +++ b/cogs/draw_image_generation.py @@ -6,20 +6,19 @@ from io import BytesIO import discord from PIL import Image -from discord.ext import commands +from pycord.multicog import add_to_group # We don't use the converser cog here because we want to be able to redo for the last images and text prompts at the same time from models.env_service_model import EnvService from models.user_model import RedoUser -from models.check_model import Check redo_users = {} users_to_interactions = {} ALLOWED_GUILDS = EnvService.get_allowed_guilds() -class DrawDallEService(commands.Cog, name="DrawDallEService"): +class DrawDallEService(discord.Cog, name="DrawDallEService"): def __init__( self, bot, usage_service, model, message_queue, deletion_queue, converser_cog ): @@ -130,12 +129,11 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"): self.converser_cog.users_to_interactions[user_id].append( result_message.id ) - + @add_to_group("dalle") @discord.slash_command( name="draw", description="Draw an image from a prompt", guild_ids=ALLOWED_GUILDS, - checks=[Check.check_valid_roles()], ) @discord.option(name="prompt", description="The prompt to draw from", required=True) async def draw(self, ctx: discord.ApplicationContext, prompt: str): @@ -155,6 +153,7 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"): await ctx.respond("Something went wrong. Please try again later.") await ctx.send_followup(e) + @add_to_group("admin") @discord.slash_command( name="local-size", description="Get the size of the dall-e images folder that we have on the current system", @@ -165,8 +164,6 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"): await ctx.defer() # Get the size of the dall-e images folder that we have on the current system. # Check if admin user - if not await self.converser_cog.check_valid_roles(ctx.user, ctx): - return image_path = self.model.IMAGE_SAVE_PATH total_size = 0 @@ -179,11 +176,11 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"): total_size = total_size / 1000000 await ctx.respond(f"The size of the local images folder is {total_size} MB.") + @add_to_group("admin") @discord.slash_command( name="clear-local", description="Clear the local dalleimages folder on system.", guild_ids=ALLOWED_GUILDS, - checks=[Check.check_valid_roles()], ) @discord.guild_only() async def clear_local(self, ctx): diff --git a/cogs/gpt_3_commands_and_converser.py b/cogs/gpt_3_commands_and_converser.py index b5c549d..e1c4124 100644 --- a/cogs/gpt_3_commands_and_converser.py +++ b/cogs/gpt_3_commands_and_converser.py @@ -5,7 +5,7 @@ import traceback from pathlib import Path import discord -from discord.ext import commands +from pycord.multicog import add_to_group from models.deletion_service_model import Deletion from models.env_service_model import EnvService @@ -20,7 +20,7 @@ ALLOWED_GUILDS = EnvService.get_allowed_guilds() print("THE ALLOWED GUILDS ARE: ", ALLOWED_GUILDS) -class GPT3ComCon(commands.Cog, name="GPT3ComCon"): +class GPT3ComCon(discord.Cog, name="GPT3ComCon"): def __init__( self, bot, @@ -94,11 +94,28 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): self.message_queue = message_queue self.conversation_threads = {} - @commands.Cog.listener() + # Create slash command groups + admin = discord.SlashCommandGroup(name="admin", + description="Admin settings for the bot", + guild_ids=ALLOWED_GUILDS, + checks=[Check.check_admin_roles()] + ) + dalle = discord.SlashCommandGroup(name="dalle", + description="Dalle related commands", + guild_ids=ALLOWED_GUILDS, + checks=[Check.check_dalle_roles()] + ) + gpt = discord.SlashCommandGroup(name="gpt", + description="GPT related commands", + guild_ids=ALLOWED_GUILDS, + checks=[Check.check_gpt_roles()] + ) + + @discord.Cog.listener() async def on_member_remove(self, member): pass - @commands.Cog.listener() + @discord.Cog.listener() async def on_ready(self): self.debug_channel = self.bot.get_guild(self.DEBUG_GUILD).get_channel( self.DEBUG_CHANNEL @@ -113,18 +130,19 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): delete_existing=True, ) print(f"The debug channel was acquired and commands registered") - + + @add_to_group("admin") @discord.slash_command( name="set-usage", description="Set the current OpenAI usage (in dollars)", - checks=[Check.check_valid_roles()], + guild_ids=ALLOWED_GUILDS, ) @discord.option( name="usage_amount", description="The current usage amount in dollars and cents (e.g 10.24)", type=float, ) - async def set_usage(self, ctx, usage_amount: float): + async def set_usage(self, ctx: discord.ApplicationContext, usage_amount: float): await ctx.defer() # Attempt to convert the input usage value into a float @@ -136,12 +154,13 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): await ctx.respond("The usage value must be a valid float.") return + @add_to_group("admin") @discord.slash_command( name="delete-conversation-threads", description="Delete all conversation threads across the bot servers.", - checks=[Check.check_valid_roles()], + guild_ids=ALLOWED_GUILDS, ) - async def delete_all_conversation_threads(self, ctx): + async def delete_all_conversation_threads(self, ctx: discord.ApplicationContext): await ctx.defer() for guild in self.bot.guilds: @@ -195,12 +214,12 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): title="GPT3Bot Help", description="The current commands", color=0xC730C7 ) embed.add_field( - name="/g ", + name="/chat", value="Ask GPT3 something. Be clear, long, and concise in your prompt. Don't waste tokens.", inline=False, ) embed.add_field( - name="/chat-gpt", value="Start a conversation with GPT3", inline=False + name="/converse", value="Start a conversation with GPT3", inline=False ) embed.add_field( name="/end-chat", @@ -408,7 +427,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): self.conversating_users[message.author.id].history = new_conversation_history # A listener for message edits to redo prompts if they are edited - @commands.Cog.listener() + @discord.Cog.listener() async def on_message_edit(self, before, after): if after.author.id in self.redo_users: if after.id == original_message[after.author.id]: @@ -439,7 +458,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): self.redo_users[after.author.id].prompt = after.content - @commands.Cog.listener() + @discord.Cog.listener() async def on_message(self, message): # Get the message from context @@ -657,17 +676,17 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): await self.end_conversation(ctx) return + @add_to_group("gpt") @discord.slash_command( - name="g", + name="chat", description="Ask GPT3 something!", guild_ids=ALLOWED_GUILDS, - checks=[Check.check_valid_roles()], ) @discord.option( name="prompt", description="The prompt to send to GPT3", required=True ) @discord.guild_only() - async def g(self, ctx: discord.ApplicationContext, prompt: str): + async def chat(self, ctx: discord.ApplicationContext, prompt: str): await ctx.defer() user = ctx.user @@ -679,11 +698,11 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): await self.encapsulated_send(user.id, prompt, ctx, from_g_command=True) + @add_to_group("gpt") @discord.slash_command( - name="chat-gpt", + name="converse", description="Have a conversation with GPT3", guild_ids=ALLOWED_GUILDS, - checks=[Check.check_valid_roles()], ) @discord.option( name="opener", description="Which sentence to start with", required=False @@ -701,7 +720,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): choices=["yes"], ) @discord.guild_only() - async def chat_gpt( + async def converse( self, ctx: discord.ApplicationContext, opener: str, private, minimal ): if private: @@ -781,6 +800,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): self.conversation_threads[user_id_normalized] = thread.id + @add_to_group("gpt") @discord.slash_command( name="end-chat", description="End a conversation with GPT3", @@ -801,6 +821,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): await ctx.defer() await self.send_help_text(ctx) + @add_to_group("admin") @discord.slash_command( name="usage", description="Get usage statistics for GPT3Discord", @@ -811,6 +832,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): await ctx.defer() await self.send_usage_text(ctx) + @add_to_group("admin") @discord.slash_command( name="settings", description="Get settings for GPT3Discord", diff --git a/cogs/image_prompt_optimizer.py b/cogs/image_prompt_optimizer.py index f6974a8..c7b9b0a 100644 --- a/cogs/image_prompt_optimizer.py +++ b/cogs/image_prompt_optimizer.py @@ -2,16 +2,15 @@ import re import traceback import discord -from discord.ext import commands from models.env_service_model import EnvService from models.user_model import RedoUser -from models.check_model import Check +from pycord.multicog import add_to_group ALLOWED_GUILDS = EnvService.get_allowed_guilds() -class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"): +class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"): _OPTIMIZER_PRETEXT = "Optimize the following text for DALL-E image generation to have the most detailed and realistic image possible. Prompt:" def __init__( @@ -46,12 +45,12 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"): except: traceback.print_exc() self.OPTIMIZER_PRETEXT = self._OPTIMIZER_PRETEXT - + + @add_to_group("dalle") @discord.slash_command( name="imgoptimize", description="Optimize a text prompt for DALL-E/MJ/SD image generation.", guild_ids=ALLOWED_GUILDS, - checks=[Check.check_valid_roles()], ) @discord.option( name="prompt", description="The text prompt to optimize.", required=True diff --git a/gpt3discord.py b/gpt3discord.py index f330e97..6ad704d 100644 --- a/gpt3discord.py +++ b/gpt3discord.py @@ -4,8 +4,8 @@ import traceback from pathlib import Path import discord -from discord.ext import commands from dotenv import load_dotenv +from pycord.multicog import apply_multicog from cogs.draw_image_generation import DrawDallEService from cogs.gpt_3_commands_and_converser import GPT3ComCon @@ -34,7 +34,7 @@ Settings for the bot activity = discord.Activity( type=discord.ActivityType.watching, name="for /help /g, and more!" ) -bot = commands.Bot(intents=discord.Intents.all(), command_prefix="!", activity=activity) +bot = discord.Bot(intents=discord.Intents.all(), command_prefix="!", activity=activity) usage_service = UsageService(Path(os.environ.get("DATA_DIR", os.getcwd()))) model = Model(usage_service) @@ -104,6 +104,8 @@ async def main(): ) ) + apply_multicog(bot) + await bot.start(os.getenv("DISCORD_TOKEN")) diff --git a/models/check_model.py b/models/check_model.py index 3788374..c2dcde5 100644 --- a/models/check_model.py +++ b/models/check_model.py @@ -1,22 +1,50 @@ import discord + from models.env_service_model import EnvService from typing import Callable - -ALLOWED_ROLES = EnvService.get_allowed_roles() +ADMIN_ROLES = EnvService.get_admin_roles() +DALLE_ROLES = EnvService.get_dalle_roles() +GPT_ROLES = EnvService.get_gpt_roles() +ALLOWED_GUILDS = EnvService.get_allowed_guilds() class Check: - def check_valid_roles() -> Callable: + def check_admin_roles() -> Callable: async def inner(ctx: discord.ApplicationContext): - if not any(role.name in ALLOWED_ROLES for role in ctx.user.roles): + if not any(role.name.lower() in ADMIN_ROLES for role in ctx.user.roles): await ctx.defer(ephemeral=True) await ctx.respond( - "You don't have permission to use this.", + "You don't have admin permission to use this.", ephemeral=True, delete_after=10, ) return False return True + return inner + def check_dalle_roles() -> Callable: + async def inner(ctx: discord.ApplicationContext): + if not any(role.name.lower() in DALLE_ROLES for role in ctx.user.roles): + await ctx.defer(ephemeral=True) + await ctx.respond( + "You don't have dalle permission to use this.", + ephemeral=True, + delete_after=10, + ) + return False + return True return inner + + def check_gpt_roles() -> Callable: + async def inner(ctx: discord.ApplicationContext): + if not any(role.name.lower() in GPT_ROLES for role in ctx.user.roles): + await ctx.defer(ephemeral=True) + await ctx.respond( + "You don't have gpt permission to use this.", + ephemeral=True, + delete_after=10, + ) + return False + return True + return inner \ No newline at end of file diff --git a/models/env_service_model.py b/models/env_service_model.py index 3536aa4..2dc3c30 100644 --- a/models/env_service_model.py +++ b/models/env_service_model.py @@ -33,23 +33,67 @@ class EnvService: return allowed_guilds @staticmethod - def get_allowed_roles(): - # ALLOWED_ROLES is a comma separated list of string roles + def get_admin_roles(): + # ADMIN_ROLES is a comma separated list of string roles # It can also just be one role # Read these allowed roles and return as a list of strings try: - allowed_roles = os.getenv("ALLOWED_ROLES") + admin_roles = os.getenv("ADMIN_ROLES") except: - allowed_roles = None + admin_roles = None - if allowed_roles is None: + if admin_roles is None: raise ValueError( - "ALLOWED_ROLES is not defined properly in the environment file!" - "Please copy your server's role and put it into ALLOWED_ROLES in the .env file." - 'For example a line should look like: `ALLOWED_ROLES="Admin"`' + "ADMIN_ROLES is not defined properly in the environment file!" + "Please copy your server's role and put it into ADMIN_ROLES in the .env file." + 'For example a line should look like: `ADMIN_ROLES="Admin"`' ) - allowed_roles = ( - allowed_roles.split(",") if "," in allowed_roles else [allowed_roles] + admin_roles = ( + admin_roles.lower().split(",") if "," in admin_roles else [admin_roles.lower()] ) - return allowed_roles + return admin_roles + + @staticmethod + def get_dalle_roles(): + # DALLE_ROLES is a comma separated list of string roles + # It can also just be one role + # Read these allowed roles and return as a list of strings + try: + dalle_roles = os.getenv("DALLE_ROLES") + except: + dalle_roles = None + + if dalle_roles is None: + raise ValueError( + "DALLE_ROLES is not defined properly in the environment file!" + "Please copy your server's role and put it into DALLE_ROLES in the .env file." + 'For example a line should look like: `DALLE_ROLES="Dalle"`' + ) + + dalle_roles = ( + dalle_roles.lower().split(",") if "," in dalle_roles else [dalle_roles.lower()] + ) + return dalle_roles + + @staticmethod + def get_gpt_roles(): + # GPT_ROLES is a comma separated list of string roles + # It can also just be one role + # Read these allowed roles and return as a list of strings + try: + gpt_roles = os.getenv("GPT_ROLES") + except: + gpt_roles = None + + if gpt_roles is None: + raise ValueError( + "GPT_ROLES is not defined properly in the environment file!" + "Please copy your server's role and put it into GPT_ROLES in the .env file." + 'For example a line should look like: `GPT_ROLES="Gpt"`' + ) + + gpt_roles = ( + gpt_roles.lower().strip().split(",") if "," in gpt_roles else [gpt_roles.lower()] + ) + return gpt_roles diff --git a/pyproject.toml b/pyproject.toml index 9183f36..954e5e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "python-dotenv", "requests", "transformers", + "pycord-multicog" ] dynamic = ["version"] [project.scripts] diff --git a/requirements.txt b/requirements.txt index 848a793..67d6ace 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ py-cord==2.3.2 python-dotenv==0.21.0 requests==2.28.1 transformers==4.25.1 +pycord-multicog==1.0.2 diff --git a/sample.env b/sample.env index 83b3c33..932f206 100644 --- a/sample.env +++ b/sample.env @@ -1,6 +1,10 @@ OPENAI_TOKEN="" DISCORD_TOKEN="" -DEBUG_GUILD="755420092027633774" -DEBUG_CHANNEL="907974109084942396" -ALLOWED_GUILDS="971268468148166697,971268468148166697" -ALLOWED_ROLES="Admin,gpt" \ No newline at end of file +DEBUG_GUILD="" +DEBUG_CHANNEL="" +# make sure not to include an id the bot doesn't have access to +ALLOWED_GUILDS="," +# no spaces, case insensitive +ADMIN_ROLES="admin,owner" +DALLE_ROLES="admin,openai,dalle" +GPT_ROLES="admin,openai,gpt" \ No newline at end of file