Merge pull request #48 from Hikari-Haru/slashgroups-refactor

Slashgroups refactor and separation of commands based on roles
Kaveen Kumarasinghe 2 years ago committed by GitHub
commit dd6be231d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

1
.gitignore vendored

@ -4,6 +4,7 @@ __pycache__
/models/__pycache__ /models/__pycache__
#user files #user files
.env .env
.vscode
bot.pid bot.pid
usage.txt usage.txt
/dalleimages /dalleimages

@ -6,20 +6,19 @@ from io import BytesIO
import discord import discord
from PIL import Image from PIL import Image
from discord.ext import commands from pycord.multicog import add_to_group
# We don't use the converser cog here because we want to be able to redo for the last images and text prompts at the same time # We don't use the converser cog here because we want to be able to redo for the last images and text prompts at the same time
from models.env_service_model import EnvService from models.env_service_model import EnvService
from models.user_model import RedoUser from models.user_model import RedoUser
from models.check_model import Check
redo_users = {} redo_users = {}
users_to_interactions = {} users_to_interactions = {}
ALLOWED_GUILDS = EnvService.get_allowed_guilds() ALLOWED_GUILDS = EnvService.get_allowed_guilds()
class DrawDallEService(commands.Cog, name="DrawDallEService"): class DrawDallEService(discord.Cog, name="DrawDallEService"):
def __init__( def __init__(
self, bot, usage_service, model, message_queue, deletion_queue, converser_cog self, bot, usage_service, model, message_queue, deletion_queue, converser_cog
): ):
@ -147,12 +146,11 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
self.converser_cog.users_to_interactions[user_id].append( self.converser_cog.users_to_interactions[user_id].append(
result_message.id result_message.id
) )
@add_to_group("dalle")
@discord.slash_command( @discord.slash_command(
name="draw", name="draw",
description="Draw an image from a prompt", description="Draw an image from a prompt",
guild_ids=ALLOWED_GUILDS, guild_ids=ALLOWED_GUILDS,
checks=[Check.check_valid_roles()],
) )
@discord.option(name="prompt", description="The prompt to draw from", required=True) @discord.option(name="prompt", description="The prompt to draw from", required=True)
async def draw(self, ctx: discord.ApplicationContext, prompt: str): async def draw(self, ctx: discord.ApplicationContext, prompt: str):
@ -172,6 +170,7 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
await ctx.respond("Something went wrong. Please try again later.") await ctx.respond("Something went wrong. Please try again later.")
await ctx.send_followup(e) await ctx.send_followup(e)
@add_to_group("system")
@discord.slash_command( @discord.slash_command(
name="local-size", name="local-size",
description="Get the size of the dall-e images folder that we have on the current system", description="Get the size of the dall-e images folder that we have on the current system",
@ -181,9 +180,6 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
async def local_size(self, ctx: discord.ApplicationContext): async def local_size(self, ctx: discord.ApplicationContext):
await ctx.defer() await ctx.defer()
# Get the size of the dall-e images folder that we have on the current system. # Get the size of the dall-e images folder that we have on the current system.
# Check if admin user
if not await self.converser_cog.check_valid_roles(ctx.user, ctx):
return
image_path = self.model.IMAGE_SAVE_PATH image_path = self.model.IMAGE_SAVE_PATH
total_size = 0 total_size = 0
@ -196,11 +192,11 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
total_size = total_size / 1000000 total_size = total_size / 1000000
await ctx.respond(f"The size of the local images folder is {total_size} MB.") await ctx.respond(f"The size of the local images folder is {total_size} MB.")
@add_to_group("system")
@discord.slash_command( @discord.slash_command(
name="clear-local", name="clear-local",
description="Clear the local dalleimages folder on system.", description="Clear the local dalleimages folder on system.",
guild_ids=ALLOWED_GUILDS, guild_ids=ALLOWED_GUILDS,
checks=[Check.check_valid_roles()],
) )
@discord.guild_only() @discord.guild_only()
async def clear_local(self, ctx): async def clear_local(self, ctx):

@ -5,7 +5,7 @@ import traceback
from pathlib import Path from pathlib import Path
import discord import discord
from discord.ext import commands from pycord.multicog import add_to_group
from models.deletion_service_model import Deletion from models.deletion_service_model import Deletion
from models.env_service_model import EnvService from models.env_service_model import EnvService
@ -19,7 +19,7 @@ original_message = {}
ALLOWED_GUILDS = EnvService.get_allowed_guilds() ALLOWED_GUILDS = EnvService.get_allowed_guilds()
class GPT3ComCon(commands.Cog, name="GPT3ComCon"): class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
def __init__( def __init__(
self, self,
bot, bot,
@ -93,6 +93,23 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
self.message_queue = message_queue self.message_queue = message_queue
self.conversation_threads = {} self.conversation_threads = {}
# Create slash command groups
dalle = discord.SlashCommandGroup(name="dalle",
description="Dalle related commands",
guild_ids=ALLOWED_GUILDS,
checks=[Check.check_dalle_roles()]
)
gpt = discord.SlashCommandGroup(name="gpt",
description="GPT related commands",
guild_ids=ALLOWED_GUILDS,
checks=[Check.check_gpt_roles()]
)
system = discord.SlashCommandGroup(name="system",
description="Admin/System settings for the bot",
guild_ids=ALLOWED_GUILDS,
checks=[Check.check_admin_roles()]
)
@commands.Cog.listener() @commands.Cog.listener()
async def on_member_join(self, member): async def on_member_join(self, member):
if self.model.welcome_message_enabled: if self.model.welcome_message_enabled:
@ -122,7 +139,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
async def on_member_remove(self, member): async def on_member_remove(self, member):
pass pass
@commands.Cog.listener() @discord.Cog.listener()
async def on_ready(self): async def on_ready(self):
self.debug_channel = self.bot.get_guild(self.DEBUG_GUILD).get_channel( self.debug_channel = self.bot.get_guild(self.DEBUG_GUILD).get_channel(
self.DEBUG_CHANNEL self.DEBUG_CHANNEL
@ -138,17 +155,18 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
) )
print(f"The debug channel was acquired and commands registered") print(f"The debug channel was acquired and commands registered")
@add_to_group("system")
@discord.slash_command( @discord.slash_command(
name="set-usage", name="set-usage",
description="Set the current OpenAI usage (in dollars)", description="Set the current OpenAI usage (in dollars)",
checks=[Check.check_valid_roles()], guild_ids=ALLOWED_GUILDS,
) )
@discord.option( @discord.option(
name="usage_amount", name="usage_amount",
description="The current usage amount in dollars and cents (e.g 10.24)", description="The current usage amount in dollars and cents (e.g 10.24)",
type=float, type=float,
) )
async def set_usage(self, ctx, usage_amount: float): async def set_usage(self, ctx: discord.ApplicationContext, usage_amount: float):
await ctx.defer() await ctx.defer()
# Attempt to convert the input usage value into a float # Attempt to convert the input usage value into a float
@ -160,12 +178,13 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
await ctx.respond("The usage value must be a valid float.") await ctx.respond("The usage value must be a valid float.")
return return
@add_to_group("system")
@discord.slash_command( @discord.slash_command(
name="delete-conversation-threads", name="delete-conversation-threads",
description="Delete all conversation threads across the bot servers.", description="Delete all conversation threads across the bot servers.",
checks=[Check.check_valid_roles()], guild_ids=ALLOWED_GUILDS,
) )
async def delete_all_conversation_threads(self, ctx): async def delete_all_conversation_threads(self, ctx: discord.ApplicationContext):
await ctx.defer() await ctx.defer()
for guild in self.bot.guilds: for guild in self.bot.guilds:
@ -219,12 +238,12 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
title="GPT3Bot Help", description="The current commands", color=0xC730C7 title="GPT3Bot Help", description="The current commands", color=0xC730C7
) )
embed.add_field( embed.add_field(
name="/g <prompt>", name="/ask",
value="Ask GPT3 something. Be clear, long, and concise in your prompt. Don't waste tokens.", value="Ask GPT3 something. Be clear, long, and concise in your prompt. Don't waste tokens.",
inline=False, inline=False,
) )
embed.add_field( embed.add_field(
name="/chat-gpt", value="Start a conversation with GPT3", inline=False name="/converse", value="Start a conversation with GPT3", inline=False
) )
embed.add_field( embed.add_field(
name="/end-chat", name="/end-chat",
@ -247,7 +266,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
inline=False, inline=False,
) )
embed.add_field( embed.add_field(
name="/imgoptimize <image prompt>", name="/optimize <image prompt>",
value="Optimize an image prompt for use with DALL-E2, Midjourney, SD, etc.", value="Optimize an image prompt for use with DALL-E2, Midjourney, SD, etc.",
inline=False, inline=False,
) )
@ -432,7 +451,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
self.conversating_users[message.author.id].history = new_conversation_history self.conversating_users[message.author.id].history = new_conversation_history
# A listener for message edits to redo prompts if they are edited # A listener for message edits to redo prompts if they are edited
@commands.Cog.listener() @discord.Cog.listener()
async def on_message_edit(self, before, after): async def on_message_edit(self, before, after):
if after.author.id in self.redo_users: if after.author.id in self.redo_users:
if after.id == original_message[after.author.id]: if after.id == original_message[after.author.id]:
@ -463,7 +482,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
self.redo_users[after.author.id].prompt = after.content self.redo_users[after.author.id].prompt = after.content
@commands.Cog.listener() @discord.Cog.listener()
async def on_message(self, message): async def on_message(self, message):
# Get the message from context # Get the message from context
@ -681,17 +700,17 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
await self.end_conversation(ctx) await self.end_conversation(ctx)
return return
@add_to_group("gpt")
@discord.slash_command( @discord.slash_command(
name="g", name="ask",
description="Ask GPT3 something!", description="Ask GPT3 something!",
guild_ids=ALLOWED_GUILDS, guild_ids=ALLOWED_GUILDS,
checks=[Check.check_valid_roles()],
) )
@discord.option( @discord.option(
name="prompt", description="The prompt to send to GPT3", required=True name="prompt", description="The prompt to send to GPT3", required=True
) )
@discord.guild_only() @discord.guild_only()
async def g(self, ctx: discord.ApplicationContext, prompt: str): async def ask(self, ctx: discord.ApplicationContext, prompt: str):
await ctx.defer() await ctx.defer()
user = ctx.user user = ctx.user
@ -703,11 +722,11 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
await self.encapsulated_send(user.id, prompt, ctx, from_g_command=True) await self.encapsulated_send(user.id, prompt, ctx, from_g_command=True)
@add_to_group("gpt")
@discord.slash_command( @discord.slash_command(
name="chat-gpt", name="converse",
description="Have a conversation with GPT3", description="Have a conversation with GPT3",
guild_ids=ALLOWED_GUILDS, guild_ids=ALLOWED_GUILDS,
checks=[Check.check_valid_roles()],
) )
@discord.option( @discord.option(
name="opener", description="Which sentence to start with", required=False name="opener", description="Which sentence to start with", required=False
@ -725,7 +744,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
choices=["yes"], choices=["yes"],
) )
@discord.guild_only() @discord.guild_only()
async def chat_gpt( async def converse(
self, ctx: discord.ApplicationContext, opener: str, private, minimal self, ctx: discord.ApplicationContext, opener: str, private, minimal
): ):
if private: if private:
@ -805,6 +824,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
self.conversation_threads[user_id_normalized] = thread.id self.conversation_threads[user_id_normalized] = thread.id
@add_to_group("gpt")
@discord.slash_command( @discord.slash_command(
name="end-chat", name="end-chat",
description="End a conversation with GPT3", description="End a conversation with GPT3",
@ -825,6 +845,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
await ctx.defer() await ctx.defer()
await self.send_help_text(ctx) await self.send_help_text(ctx)
@add_to_group("system")
@discord.slash_command( @discord.slash_command(
name="usage", name="usage",
description="Get usage statistics for GPT3Discord", description="Get usage statistics for GPT3Discord",
@ -835,6 +856,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
await ctx.defer() await ctx.defer()
await self.send_usage_text(ctx) await self.send_usage_text(ctx)
@add_to_group("system")
@discord.slash_command( @discord.slash_command(
name="settings", name="settings",
description="Get settings for GPT3Discord", description="Get settings for GPT3Discord",

@ -2,16 +2,15 @@ import re
import traceback import traceback
import discord import discord
from discord.ext import commands
from models.env_service_model import EnvService from models.env_service_model import EnvService
from models.user_model import RedoUser from models.user_model import RedoUser
from models.check_model import Check from pycord.multicog import add_to_group
ALLOWED_GUILDS = EnvService.get_allowed_guilds() ALLOWED_GUILDS = EnvService.get_allowed_guilds()
class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"): class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"):
_OPTIMIZER_PRETEXT = "Optimize the following text for DALL-E image generation to have the most detailed and realistic image possible. Prompt:" _OPTIMIZER_PRETEXT = "Optimize the following text for DALL-E image generation to have the most detailed and realistic image possible. Prompt:"
def __init__( def __init__(
@ -47,17 +46,17 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"):
traceback.print_exc() traceback.print_exc()
self.OPTIMIZER_PRETEXT = self._OPTIMIZER_PRETEXT self.OPTIMIZER_PRETEXT = self._OPTIMIZER_PRETEXT
@add_to_group("dalle")
@discord.slash_command( @discord.slash_command(
name="imgoptimize", name="imgoptimize",
description="Optimize a text prompt for DALL-E/MJ/SD image generation.", description="Optimize a text prompt for DALL-E/MJ/SD image generation.",
guild_ids=ALLOWED_GUILDS, guild_ids=ALLOWED_GUILDS,
checks=[Check.check_valid_roles()],
) )
@discord.option( @discord.option(
name="prompt", description="The text prompt to optimize.", required=True name="prompt", description="The text prompt to optimize.", required=True
) )
@discord.guild_only() @discord.guild_only()
async def imgoptimize(self, ctx: discord.ApplicationContext, prompt: str): async def optimize(self, ctx: discord.ApplicationContext, prompt: str):
await ctx.defer() await ctx.defer()
user = ctx.user user = ctx.user

@ -4,8 +4,8 @@ import traceback
from pathlib import Path from pathlib import Path
import discord import discord
from discord.ext import commands
from dotenv import load_dotenv from dotenv import load_dotenv
from pycord.multicog import apply_multicog
import os import os
if sys.platform == "win32": if sys.platform == "win32":
@ -41,7 +41,7 @@ Settings for the bot
activity = discord.Activity( activity = discord.Activity(
type=discord.ActivityType.watching, name="for /help /g, and more!" type=discord.ActivityType.watching, name="for /help /g, and more!"
) )
bot = commands.Bot(intents=discord.Intents.all(), command_prefix="!", activity=activity) bot = discord.Bot(intents=discord.Intents.all(), command_prefix="!", activity=activity)
usage_service = UsageService(Path(os.environ.get("DATA_DIR", os.getcwd()))) usage_service = UsageService(Path(os.environ.get("DATA_DIR", os.getcwd())))
model = Model(usage_service) model = Model(usage_service)
@ -111,6 +111,8 @@ async def main():
) )
) )
apply_multicog(bot)
await bot.start(os.getenv("DISCORD_TOKEN")) await bot.start(os.getenv("DISCORD_TOKEN"))

@ -1,15 +1,31 @@
import discord import discord
from models.env_service_model import EnvService from models.env_service_model import EnvService
from typing import Callable from typing import Callable
ADMIN_ROLES = EnvService.get_admin_roles()
ALLOWED_ROLES = EnvService.get_allowed_roles() DALLE_ROLES = EnvService.get_dalle_roles()
GPT_ROLES = EnvService.get_gpt_roles()
ALLOWED_GUILDS = EnvService.get_allowed_guilds()
class Check: class Check:
def check_valid_roles() -> Callable: def check_admin_roles() -> Callable:
async def inner(ctx: discord.ApplicationContext):
if not any(role.name.lower() in ADMIN_ROLES for role in ctx.user.roles):
await ctx.defer(ephemeral=True)
await ctx.respond(
f"You don't have permission to use this.",
ephemeral=True,
delete_after=10,
)
return False
return True
return inner
def check_dalle_roles() -> Callable:
async def inner(ctx: discord.ApplicationContext): async def inner(ctx: discord.ApplicationContext):
if not any(role.name in ALLOWED_ROLES for role in ctx.user.roles): if not any(role.name.lower() in DALLE_ROLES for role in ctx.user.roles):
await ctx.defer(ephemeral=True) await ctx.defer(ephemeral=True)
await ctx.respond( await ctx.respond(
"You don't have permission to use this.", "You don't have permission to use this.",
@ -18,5 +34,17 @@ class Check:
) )
return False return False
return True return True
return inner
def check_gpt_roles() -> Callable:
async def inner(ctx: discord.ApplicationContext):
if not any(role.name.lower() in GPT_ROLES for role in ctx.user.roles):
await ctx.defer(ephemeral=True)
await ctx.respond(
"You don't have permission to use this.",
ephemeral=True,
delete_after=10,
)
return False
return True
return inner return inner

@ -33,26 +33,71 @@ class EnvService:
return allowed_guilds return allowed_guilds
@staticmethod @staticmethod
def get_allowed_roles(): def get_admin_roles():
# ALLOWED_ROLES is a comma separated list of string roles # ADMIN_ROLES is a comma separated list of string roles
# It can also just be one role # It can also just be one role
# Read these allowed roles and return as a list of strings # Read these allowed roles and return as a list of strings
try: try:
allowed_roles = os.getenv("ALLOWED_ROLES") admin_roles = os.getenv("ADMIN_ROLES")
except: except:
allowed_roles = None admin_roles = None
if allowed_roles is None: if admin_roles is None:
raise ValueError( raise ValueError(
"ALLOWED_ROLES is not defined properly in the environment file!" "ADMIN_ROLES is not defined properly in the environment file!"
"Please copy your server's role and put it into ALLOWED_ROLES in the .env file." "Please copy your server's role and put it into ADMIN_ROLES in the .env file."
'For example a line should look like: `ALLOWED_ROLES="Admin"`' 'For example a line should look like: `ADMIN_ROLES="Admin"`'
) )
allowed_roles = ( admin_roles = (
allowed_roles.split(",") if "," in allowed_roles else [allowed_roles] admin_roles.lower().split(",") if "," in admin_roles else [admin_roles.lower()]
) )
return allowed_roles return admin_roles
@staticmethod
def get_dalle_roles():
# DALLE_ROLES is a comma separated list of string roles
# It can also just be one role
# Read these allowed roles and return as a list of strings
try:
dalle_roles = os.getenv("DALLE_ROLES")
except:
dalle_roles = None
if dalle_roles is None:
raise ValueError(
"DALLE_ROLES is not defined properly in the environment file!"
"Please copy your server's role and put it into DALLE_ROLES in the .env file."
'For example a line should look like: `DALLE_ROLES="Dalle"`'
)
dalle_roles = (
dalle_roles.lower().split(",") if "," in dalle_roles else [dalle_roles.lower()]
)
return dalle_roles
@staticmethod
def get_gpt_roles():
# GPT_ROLES is a comma separated list of string roles
# It can also just be one role
# Read these allowed roles and return as a list of strings
try:
gpt_roles = os.getenv("GPT_ROLES")
except:
gpt_roles = None
if gpt_roles is None:
raise ValueError(
"GPT_ROLES is not defined properly in the environment file!"
"Please copy your server's role and put it into GPT_ROLES in the .env file."
'For example a line should look like: `GPT_ROLES="Gpt"`'
)
gpt_roles = (
gpt_roles.lower().strip().split(",") if "," in gpt_roles else [gpt_roles.lower()]
)
return gpt_roles
@staticmethod @staticmethod
def get_welcome_message(): def get_welcome_message():

@ -23,6 +23,7 @@ dependencies = [
"python-dotenv", "python-dotenv",
"requests", "requests",
"transformers", "transformers",
"pycord-multicog"
] ]
dynamic = ["version"] dynamic = ["version"]
[project.scripts] [project.scripts]

@ -3,3 +3,4 @@ py-cord==2.3.2
python-dotenv==0.21.0 python-dotenv==0.21.0
requests==2.28.1 requests==2.28.1
transformers==4.25.1 transformers==4.25.1
pycord-multicog==1.0.2

@ -1,7 +1,11 @@
OPENAI_TOKEN="<openai_api_token>" OPENAI_TOKEN="<openai_api_token>"
DISCORD_TOKEN="<discord_bot_token>" DISCORD_TOKEN="<discord_bot_token>"
DEBUG_GUILD="755420092027633774" DEBUG_GUILD="<debug_guild_id>"
DEBUG_CHANNEL="907974109084942396" DEBUG_CHANNEL="<debug_channel_id>"
ALLOWED_GUILDS="971268468148166697,971268468148166697" # make sure not to include an id the bot doesn't have access to
ALLOWED_ROLES="Admin,gpt" ALLOWED_GUILDS="<guild_id>,<guild_id>"
# no spaces, case sensitive, these define which roles have access to what. E.g if GPT_ROLES="gpt", then anyone with the "gpt" role can use GPT3 commands.
ADMIN_ROLES="admin,owner"
DALLE_ROLES="admin,openai,dalle"
GPT_ROLES="admin,openai,gpt"
WELCOME_MESSAGE="Hi There! Welcome to our Discord server. We hope you'll enjoy our server and we look forward to engaging with you!" WELCOME_MESSAGE="Hi There! Welcome to our Discord server. We hope you'll enjoy our server and we look forward to engaging with you!"

Loading…
Cancel
Save