diff --git a/.gitignore b/.gitignore index e979e2b..4244f47 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ __pycache__ /models/__pycache__ #user files .env +.vscode bot.pid usage.txt /dalleimages \ No newline at end of file diff --git a/cogs/draw_image_generation.py b/cogs/draw_image_generation.py index f23a206..96e591d 100644 --- a/cogs/draw_image_generation.py +++ b/cogs/draw_image_generation.py @@ -6,20 +6,19 @@ from io import BytesIO import discord from PIL import Image -from discord.ext import commands +from pycord.multicog import add_to_group # We don't use the converser cog here because we want to be able to redo for the last images and text prompts at the same time from models.env_service_model import EnvService from models.user_model import RedoUser -from models.check_model import Check redo_users = {} users_to_interactions = {} ALLOWED_GUILDS = EnvService.get_allowed_guilds() -class DrawDallEService(commands.Cog, name="DrawDallEService"): +class DrawDallEService(discord.Cog, name="DrawDallEService"): def __init__( self, bot, usage_service, model, message_queue, deletion_queue, converser_cog ): @@ -130,12 +129,11 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"): self.converser_cog.users_to_interactions[user_id].append( result_message.id ) - + @add_to_group("dalle") @discord.slash_command( name="draw", description="Draw an image from a prompt", guild_ids=ALLOWED_GUILDS, - checks=[Check.check_valid_roles()], ) @discord.option(name="prompt", description="The prompt to draw from", required=True) async def draw(self, ctx: discord.ApplicationContext, prompt: str): @@ -155,6 +153,7 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"): await ctx.respond("Something went wrong. Please try again later.") await ctx.send_followup(e) + @add_to_group("admin") @discord.slash_command( name="local-size", description="Get the size of the dall-e images folder that we have on the current system", @@ -165,8 +164,6 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"): await ctx.defer() # Get the size of the dall-e images folder that we have on the current system. # Check if admin user - if not await self.converser_cog.check_valid_roles(ctx.user, ctx): - return image_path = self.model.IMAGE_SAVE_PATH total_size = 0 @@ -179,11 +176,11 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"): total_size = total_size / 1000000 await ctx.respond(f"The size of the local images folder is {total_size} MB.") + @add_to_group("admin") @discord.slash_command( name="clear-local", description="Clear the local dalleimages folder on system.", guild_ids=ALLOWED_GUILDS, - checks=[Check.check_valid_roles()], ) @discord.guild_only() async def clear_local(self, ctx): diff --git a/cogs/gpt_3_commands_and_converser.py b/cogs/gpt_3_commands_and_converser.py index b5c549d..e1c4124 100644 --- a/cogs/gpt_3_commands_and_converser.py +++ b/cogs/gpt_3_commands_and_converser.py @@ -5,7 +5,7 @@ import traceback from pathlib import Path import discord -from discord.ext import commands +from pycord.multicog import add_to_group from models.deletion_service_model import Deletion from models.env_service_model import EnvService @@ -20,7 +20,7 @@ ALLOWED_GUILDS = EnvService.get_allowed_guilds() print("THE ALLOWED GUILDS ARE: ", ALLOWED_GUILDS) -class GPT3ComCon(commands.Cog, name="GPT3ComCon"): +class GPT3ComCon(discord.Cog, name="GPT3ComCon"): def __init__( self, bot, @@ -94,11 +94,28 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): self.message_queue = message_queue self.conversation_threads = {} - @commands.Cog.listener() + # Create slash command groups + admin = discord.SlashCommandGroup(name="admin", + description="Admin settings for the bot", + guild_ids=ALLOWED_GUILDS, + checks=[Check.check_admin_roles()] + ) + dalle = discord.SlashCommandGroup(name="dalle", + description="Dalle related commands", + guild_ids=ALLOWED_GUILDS, + checks=[Check.check_dalle_roles()] + ) + gpt = discord.SlashCommandGroup(name="gpt", + description="GPT related commands", + guild_ids=ALLOWED_GUILDS, + checks=[Check.check_gpt_roles()] + ) + + @discord.Cog.listener() async def on_member_remove(self, member): pass - @commands.Cog.listener() + @discord.Cog.listener() async def on_ready(self): self.debug_channel = self.bot.get_guild(self.DEBUG_GUILD).get_channel( self.DEBUG_CHANNEL @@ -113,18 +130,19 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): delete_existing=True, ) print(f"The debug channel was acquired and commands registered") - + + @add_to_group("admin") @discord.slash_command( name="set-usage", description="Set the current OpenAI usage (in dollars)", - checks=[Check.check_valid_roles()], + guild_ids=ALLOWED_GUILDS, ) @discord.option( name="usage_amount", description="The current usage amount in dollars and cents (e.g 10.24)", type=float, ) - async def set_usage(self, ctx, usage_amount: float): + async def set_usage(self, ctx: discord.ApplicationContext, usage_amount: float): await ctx.defer() # Attempt to convert the input usage value into a float @@ -136,12 +154,13 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): await ctx.respond("The usage value must be a valid float.") return + @add_to_group("admin") @discord.slash_command( name="delete-conversation-threads", description="Delete all conversation threads across the bot servers.", - checks=[Check.check_valid_roles()], + guild_ids=ALLOWED_GUILDS, ) - async def delete_all_conversation_threads(self, ctx): + async def delete_all_conversation_threads(self, ctx: discord.ApplicationContext): await ctx.defer() for guild in self.bot.guilds: @@ -195,12 +214,12 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): title="GPT3Bot Help", description="The current commands", color=0xC730C7 ) embed.add_field( - name="/g ", + name="/chat", value="Ask GPT3 something. Be clear, long, and concise in your prompt. Don't waste tokens.", inline=False, ) embed.add_field( - name="/chat-gpt", value="Start a conversation with GPT3", inline=False + name="/converse", value="Start a conversation with GPT3", inline=False ) embed.add_field( name="/end-chat", @@ -408,7 +427,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): self.conversating_users[message.author.id].history = new_conversation_history # A listener for message edits to redo prompts if they are edited - @commands.Cog.listener() + @discord.Cog.listener() async def on_message_edit(self, before, after): if after.author.id in self.redo_users: if after.id == original_message[after.author.id]: @@ -439,7 +458,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): self.redo_users[after.author.id].prompt = after.content - @commands.Cog.listener() + @discord.Cog.listener() async def on_message(self, message): # Get the message from context @@ -657,17 +676,17 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): await self.end_conversation(ctx) return + @add_to_group("gpt") @discord.slash_command( - name="g", + name="chat", description="Ask GPT3 something!", guild_ids=ALLOWED_GUILDS, - checks=[Check.check_valid_roles()], ) @discord.option( name="prompt", description="The prompt to send to GPT3", required=True ) @discord.guild_only() - async def g(self, ctx: discord.ApplicationContext, prompt: str): + async def chat(self, ctx: discord.ApplicationContext, prompt: str): await ctx.defer() user = ctx.user @@ -679,11 +698,11 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): await self.encapsulated_send(user.id, prompt, ctx, from_g_command=True) + @add_to_group("gpt") @discord.slash_command( - name="chat-gpt", + name="converse", description="Have a conversation with GPT3", guild_ids=ALLOWED_GUILDS, - checks=[Check.check_valid_roles()], ) @discord.option( name="opener", description="Which sentence to start with", required=False @@ -701,7 +720,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): choices=["yes"], ) @discord.guild_only() - async def chat_gpt( + async def converse( self, ctx: discord.ApplicationContext, opener: str, private, minimal ): if private: @@ -781,6 +800,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): self.conversation_threads[user_id_normalized] = thread.id + @add_to_group("gpt") @discord.slash_command( name="end-chat", description="End a conversation with GPT3", @@ -801,6 +821,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): await ctx.defer() await self.send_help_text(ctx) + @add_to_group("admin") @discord.slash_command( name="usage", description="Get usage statistics for GPT3Discord", @@ -811,6 +832,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"): await ctx.defer() await self.send_usage_text(ctx) + @add_to_group("admin") @discord.slash_command( name="settings", description="Get settings for GPT3Discord", diff --git a/cogs/image_prompt_optimizer.py b/cogs/image_prompt_optimizer.py index f6974a8..c7b9b0a 100644 --- a/cogs/image_prompt_optimizer.py +++ b/cogs/image_prompt_optimizer.py @@ -2,16 +2,15 @@ import re import traceback import discord -from discord.ext import commands from models.env_service_model import EnvService from models.user_model import RedoUser -from models.check_model import Check +from pycord.multicog import add_to_group ALLOWED_GUILDS = EnvService.get_allowed_guilds() -class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"): +class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"): _OPTIMIZER_PRETEXT = "Optimize the following text for DALL-E image generation to have the most detailed and realistic image possible. Prompt:" def __init__( @@ -46,12 +45,12 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"): except: traceback.print_exc() self.OPTIMIZER_PRETEXT = self._OPTIMIZER_PRETEXT - + + @add_to_group("dalle") @discord.slash_command( name="imgoptimize", description="Optimize a text prompt for DALL-E/MJ/SD image generation.", guild_ids=ALLOWED_GUILDS, - checks=[Check.check_valid_roles()], ) @discord.option( name="prompt", description="The text prompt to optimize.", required=True diff --git a/gpt3discord.py b/gpt3discord.py index f330e97..6ad704d 100644 --- a/gpt3discord.py +++ b/gpt3discord.py @@ -4,8 +4,8 @@ import traceback from pathlib import Path import discord -from discord.ext import commands from dotenv import load_dotenv +from pycord.multicog import apply_multicog from cogs.draw_image_generation import DrawDallEService from cogs.gpt_3_commands_and_converser import GPT3ComCon @@ -34,7 +34,7 @@ Settings for the bot activity = discord.Activity( type=discord.ActivityType.watching, name="for /help /g, and more!" ) -bot = commands.Bot(intents=discord.Intents.all(), command_prefix="!", activity=activity) +bot = discord.Bot(intents=discord.Intents.all(), command_prefix="!", activity=activity) usage_service = UsageService(Path(os.environ.get("DATA_DIR", os.getcwd()))) model = Model(usage_service) @@ -104,6 +104,8 @@ async def main(): ) ) + apply_multicog(bot) + await bot.start(os.getenv("DISCORD_TOKEN")) diff --git a/models/check_model.py b/models/check_model.py index 3788374..c2dcde5 100644 --- a/models/check_model.py +++ b/models/check_model.py @@ -1,22 +1,50 @@ import discord + from models.env_service_model import EnvService from typing import Callable - -ALLOWED_ROLES = EnvService.get_allowed_roles() +ADMIN_ROLES = EnvService.get_admin_roles() +DALLE_ROLES = EnvService.get_dalle_roles() +GPT_ROLES = EnvService.get_gpt_roles() +ALLOWED_GUILDS = EnvService.get_allowed_guilds() class Check: - def check_valid_roles() -> Callable: + def check_admin_roles() -> Callable: async def inner(ctx: discord.ApplicationContext): - if not any(role.name in ALLOWED_ROLES for role in ctx.user.roles): + if not any(role.name.lower() in ADMIN_ROLES for role in ctx.user.roles): await ctx.defer(ephemeral=True) await ctx.respond( - "You don't have permission to use this.", + "You don't have admin permission to use this.", ephemeral=True, delete_after=10, ) return False return True + return inner + def check_dalle_roles() -> Callable: + async def inner(ctx: discord.ApplicationContext): + if not any(role.name.lower() in DALLE_ROLES for role in ctx.user.roles): + await ctx.defer(ephemeral=True) + await ctx.respond( + "You don't have dalle permission to use this.", + ephemeral=True, + delete_after=10, + ) + return False + return True return inner + + def check_gpt_roles() -> Callable: + async def inner(ctx: discord.ApplicationContext): + if not any(role.name.lower() in GPT_ROLES for role in ctx.user.roles): + await ctx.defer(ephemeral=True) + await ctx.respond( + "You don't have gpt permission to use this.", + ephemeral=True, + delete_after=10, + ) + return False + return True + return inner \ No newline at end of file diff --git a/models/env_service_model.py b/models/env_service_model.py index 3536aa4..2dc3c30 100644 --- a/models/env_service_model.py +++ b/models/env_service_model.py @@ -33,23 +33,67 @@ class EnvService: return allowed_guilds @staticmethod - def get_allowed_roles(): - # ALLOWED_ROLES is a comma separated list of string roles + def get_admin_roles(): + # ADMIN_ROLES is a comma separated list of string roles # It can also just be one role # Read these allowed roles and return as a list of strings try: - allowed_roles = os.getenv("ALLOWED_ROLES") + admin_roles = os.getenv("ADMIN_ROLES") except: - allowed_roles = None + admin_roles = None - if allowed_roles is None: + if admin_roles is None: raise ValueError( - "ALLOWED_ROLES is not defined properly in the environment file!" - "Please copy your server's role and put it into ALLOWED_ROLES in the .env file." - 'For example a line should look like: `ALLOWED_ROLES="Admin"`' + "ADMIN_ROLES is not defined properly in the environment file!" + "Please copy your server's role and put it into ADMIN_ROLES in the .env file." + 'For example a line should look like: `ADMIN_ROLES="Admin"`' ) - allowed_roles = ( - allowed_roles.split(",") if "," in allowed_roles else [allowed_roles] + admin_roles = ( + admin_roles.lower().split(",") if "," in admin_roles else [admin_roles.lower()] ) - return allowed_roles + return admin_roles + + @staticmethod + def get_dalle_roles(): + # DALLE_ROLES is a comma separated list of string roles + # It can also just be one role + # Read these allowed roles and return as a list of strings + try: + dalle_roles = os.getenv("DALLE_ROLES") + except: + dalle_roles = None + + if dalle_roles is None: + raise ValueError( + "DALLE_ROLES is not defined properly in the environment file!" + "Please copy your server's role and put it into DALLE_ROLES in the .env file." + 'For example a line should look like: `DALLE_ROLES="Dalle"`' + ) + + dalle_roles = ( + dalle_roles.lower().split(",") if "," in dalle_roles else [dalle_roles.lower()] + ) + return dalle_roles + + @staticmethod + def get_gpt_roles(): + # GPT_ROLES is a comma separated list of string roles + # It can also just be one role + # Read these allowed roles and return as a list of strings + try: + gpt_roles = os.getenv("GPT_ROLES") + except: + gpt_roles = None + + if gpt_roles is None: + raise ValueError( + "GPT_ROLES is not defined properly in the environment file!" + "Please copy your server's role and put it into GPT_ROLES in the .env file." + 'For example a line should look like: `GPT_ROLES="Gpt"`' + ) + + gpt_roles = ( + gpt_roles.lower().strip().split(",") if "," in gpt_roles else [gpt_roles.lower()] + ) + return gpt_roles diff --git a/pyproject.toml b/pyproject.toml index 9183f36..954e5e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "python-dotenv", "requests", "transformers", + "pycord-multicog" ] dynamic = ["version"] [project.scripts] diff --git a/requirements.txt b/requirements.txt index 848a793..67d6ace 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ py-cord==2.3.2 python-dotenv==0.21.0 requests==2.28.1 transformers==4.25.1 +pycord-multicog==1.0.2 diff --git a/sample.env b/sample.env index 83b3c33..932f206 100644 --- a/sample.env +++ b/sample.env @@ -1,6 +1,10 @@ OPENAI_TOKEN="" DISCORD_TOKEN="" -DEBUG_GUILD="755420092027633774" -DEBUG_CHANNEL="907974109084942396" -ALLOWED_GUILDS="971268468148166697,971268468148166697" -ALLOWED_ROLES="Admin,gpt" \ No newline at end of file +DEBUG_GUILD="" +DEBUG_CHANNEL="" +# make sure not to include an id the bot doesn't have access to +ALLOWED_GUILDS="," +# no spaces, case insensitive +ADMIN_ROLES="admin,owner" +DALLE_ROLES="admin,openai,dalle" +GPT_ROLES="admin,openai,gpt" \ No newline at end of file