Implemented slash command grouping and more ways to configure who can use what.

Refactored /g into /chat and /gpt-chat into /converse
Both have the grouping "gpt" so the full is /gpt chat, and /gpt converse
Added pycord-multicog to the project for easier slash grouping across cogs
Made the role check case insensitive and added some comments to the sample.env
Removed the rest of the discord.ext imports, it's not really used, ymmv
Removed another admin check i had missed in local-size
Rene Teigen 1 year ago
parent 583b766264
commit e5eeb035c1

1
.gitignore vendored

@ -4,6 +4,7 @@ __pycache__
/models/__pycache__ /models/__pycache__
#user files #user files
.env .env
.vscode
bot.pid bot.pid
usage.txt usage.txt
/dalleimages /dalleimages

@ -6,20 +6,19 @@ from io import BytesIO
import discord import discord
from PIL import Image from PIL import Image
from discord.ext import commands from pycord.multicog import add_to_group
# We don't use the converser cog here because we want to be able to redo for the last images and text prompts at the same time # We don't use the converser cog here because we want to be able to redo for the last images and text prompts at the same time
from models.env_service_model import EnvService from models.env_service_model import EnvService
from models.user_model import RedoUser from models.user_model import RedoUser
from models.check_model import Check
redo_users = {} redo_users = {}
users_to_interactions = {} users_to_interactions = {}
ALLOWED_GUILDS = EnvService.get_allowed_guilds() ALLOWED_GUILDS = EnvService.get_allowed_guilds()
class DrawDallEService(commands.Cog, name="DrawDallEService"): class DrawDallEService(discord.Cog, name="DrawDallEService"):
def __init__( def __init__(
self, bot, usage_service, model, message_queue, deletion_queue, converser_cog self, bot, usage_service, model, message_queue, deletion_queue, converser_cog
): ):
@ -130,12 +129,11 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
self.converser_cog.users_to_interactions[user_id].append( self.converser_cog.users_to_interactions[user_id].append(
result_message.id result_message.id
) )
@add_to_group("dalle")
@discord.slash_command( @discord.slash_command(
name="draw", name="draw",
description="Draw an image from a prompt", description="Draw an image from a prompt",
guild_ids=ALLOWED_GUILDS, guild_ids=ALLOWED_GUILDS,
checks=[Check.check_valid_roles()],
) )
@discord.option(name="prompt", description="The prompt to draw from", required=True) @discord.option(name="prompt", description="The prompt to draw from", required=True)
async def draw(self, ctx: discord.ApplicationContext, prompt: str): async def draw(self, ctx: discord.ApplicationContext, prompt: str):
@ -155,6 +153,7 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
await ctx.respond("Something went wrong. Please try again later.") await ctx.respond("Something went wrong. Please try again later.")
await ctx.send_followup(e) await ctx.send_followup(e)
@add_to_group("admin")
@discord.slash_command( @discord.slash_command(
name="local-size", name="local-size",
description="Get the size of the dall-e images folder that we have on the current system", description="Get the size of the dall-e images folder that we have on the current system",
@ -165,8 +164,6 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
await ctx.defer() await ctx.defer()
# Get the size of the dall-e images folder that we have on the current system. # Get the size of the dall-e images folder that we have on the current system.
# Check if admin user # Check if admin user
if not await self.converser_cog.check_valid_roles(ctx.user, ctx):
return
image_path = self.model.IMAGE_SAVE_PATH image_path = self.model.IMAGE_SAVE_PATH
total_size = 0 total_size = 0
@ -179,11 +176,11 @@ class DrawDallEService(commands.Cog, name="DrawDallEService"):
total_size = total_size / 1000000 total_size = total_size / 1000000
await ctx.respond(f"The size of the local images folder is {total_size} MB.") await ctx.respond(f"The size of the local images folder is {total_size} MB.")
@add_to_group("admin")
@discord.slash_command( @discord.slash_command(
name="clear-local", name="clear-local",
description="Clear the local dalleimages folder on system.", description="Clear the local dalleimages folder on system.",
guild_ids=ALLOWED_GUILDS, guild_ids=ALLOWED_GUILDS,
checks=[Check.check_valid_roles()],
) )
@discord.guild_only() @discord.guild_only()
async def clear_local(self, ctx): async def clear_local(self, ctx):

@ -5,7 +5,7 @@ import traceback
from pathlib import Path from pathlib import Path
import discord import discord
from discord.ext import commands from pycord.multicog import add_to_group
from models.deletion_service_model import Deletion from models.deletion_service_model import Deletion
from models.env_service_model import EnvService from models.env_service_model import EnvService
@ -20,7 +20,7 @@ ALLOWED_GUILDS = EnvService.get_allowed_guilds()
print("THE ALLOWED GUILDS ARE: ", ALLOWED_GUILDS) print("THE ALLOWED GUILDS ARE: ", ALLOWED_GUILDS)
class GPT3ComCon(commands.Cog, name="GPT3ComCon"): class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
def __init__( def __init__(
self, self,
bot, bot,
@ -94,11 +94,28 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
self.message_queue = message_queue self.message_queue = message_queue
self.conversation_threads = {} self.conversation_threads = {}
@commands.Cog.listener() # Create slash command groups
admin = discord.SlashCommandGroup(name="admin",
description="Admin settings for the bot",
guild_ids=ALLOWED_GUILDS,
checks=[Check.check_admin_roles()]
)
dalle = discord.SlashCommandGroup(name="dalle",
description="Dalle related commands",
guild_ids=ALLOWED_GUILDS,
checks=[Check.check_dalle_roles()]
)
gpt = discord.SlashCommandGroup(name="gpt",
description="GPT related commands",
guild_ids=ALLOWED_GUILDS,
checks=[Check.check_gpt_roles()]
)
@discord.Cog.listener()
async def on_member_remove(self, member): async def on_member_remove(self, member):
pass pass
@commands.Cog.listener() @discord.Cog.listener()
async def on_ready(self): async def on_ready(self):
self.debug_channel = self.bot.get_guild(self.DEBUG_GUILD).get_channel( self.debug_channel = self.bot.get_guild(self.DEBUG_GUILD).get_channel(
self.DEBUG_CHANNEL self.DEBUG_CHANNEL
@ -113,18 +130,19 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
delete_existing=True, delete_existing=True,
) )
print(f"The debug channel was acquired and commands registered") print(f"The debug channel was acquired and commands registered")
@add_to_group("admin")
@discord.slash_command( @discord.slash_command(
name="set-usage", name="set-usage",
description="Set the current OpenAI usage (in dollars)", description="Set the current OpenAI usage (in dollars)",
checks=[Check.check_valid_roles()], guild_ids=ALLOWED_GUILDS,
) )
@discord.option( @discord.option(
name="usage_amount", name="usage_amount",
description="The current usage amount in dollars and cents (e.g 10.24)", description="The current usage amount in dollars and cents (e.g 10.24)",
type=float, type=float,
) )
async def set_usage(self, ctx, usage_amount: float): async def set_usage(self, ctx: discord.ApplicationContext, usage_amount: float):
await ctx.defer() await ctx.defer()
# Attempt to convert the input usage value into a float # Attempt to convert the input usage value into a float
@ -136,12 +154,13 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
await ctx.respond("The usage value must be a valid float.") await ctx.respond("The usage value must be a valid float.")
return return
@add_to_group("admin")
@discord.slash_command( @discord.slash_command(
name="delete-conversation-threads", name="delete-conversation-threads",
description="Delete all conversation threads across the bot servers.", description="Delete all conversation threads across the bot servers.",
checks=[Check.check_valid_roles()], guild_ids=ALLOWED_GUILDS,
) )
async def delete_all_conversation_threads(self, ctx): async def delete_all_conversation_threads(self, ctx: discord.ApplicationContext):
await ctx.defer() await ctx.defer()
for guild in self.bot.guilds: for guild in self.bot.guilds:
@ -195,12 +214,12 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
title="GPT3Bot Help", description="The current commands", color=0xC730C7 title="GPT3Bot Help", description="The current commands", color=0xC730C7
) )
embed.add_field( embed.add_field(
name="/g <prompt>", name="/chat",
value="Ask GPT3 something. Be clear, long, and concise in your prompt. Don't waste tokens.", value="Ask GPT3 something. Be clear, long, and concise in your prompt. Don't waste tokens.",
inline=False, inline=False,
) )
embed.add_field( embed.add_field(
name="/chat-gpt", value="Start a conversation with GPT3", inline=False name="/converse", value="Start a conversation with GPT3", inline=False
) )
embed.add_field( embed.add_field(
name="/end-chat", name="/end-chat",
@ -408,7 +427,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
self.conversating_users[message.author.id].history = new_conversation_history self.conversating_users[message.author.id].history = new_conversation_history
# A listener for message edits to redo prompts if they are edited # A listener for message edits to redo prompts if they are edited
@commands.Cog.listener() @discord.Cog.listener()
async def on_message_edit(self, before, after): async def on_message_edit(self, before, after):
if after.author.id in self.redo_users: if after.author.id in self.redo_users:
if after.id == original_message[after.author.id]: if after.id == original_message[after.author.id]:
@ -439,7 +458,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
self.redo_users[after.author.id].prompt = after.content self.redo_users[after.author.id].prompt = after.content
@commands.Cog.listener() @discord.Cog.listener()
async def on_message(self, message): async def on_message(self, message):
# Get the message from context # Get the message from context
@ -657,17 +676,17 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
await self.end_conversation(ctx) await self.end_conversation(ctx)
return return
@add_to_group("gpt")
@discord.slash_command( @discord.slash_command(
name="g", name="chat",
description="Ask GPT3 something!", description="Ask GPT3 something!",
guild_ids=ALLOWED_GUILDS, guild_ids=ALLOWED_GUILDS,
checks=[Check.check_valid_roles()],
) )
@discord.option( @discord.option(
name="prompt", description="The prompt to send to GPT3", required=True name="prompt", description="The prompt to send to GPT3", required=True
) )
@discord.guild_only() @discord.guild_only()
async def g(self, ctx: discord.ApplicationContext, prompt: str): async def chat(self, ctx: discord.ApplicationContext, prompt: str):
await ctx.defer() await ctx.defer()
user = ctx.user user = ctx.user
@ -679,11 +698,11 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
await self.encapsulated_send(user.id, prompt, ctx, from_g_command=True) await self.encapsulated_send(user.id, prompt, ctx, from_g_command=True)
@add_to_group("gpt")
@discord.slash_command( @discord.slash_command(
name="chat-gpt", name="converse",
description="Have a conversation with GPT3", description="Have a conversation with GPT3",
guild_ids=ALLOWED_GUILDS, guild_ids=ALLOWED_GUILDS,
checks=[Check.check_valid_roles()],
) )
@discord.option( @discord.option(
name="opener", description="Which sentence to start with", required=False name="opener", description="Which sentence to start with", required=False
@ -701,7 +720,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
choices=["yes"], choices=["yes"],
) )
@discord.guild_only() @discord.guild_only()
async def chat_gpt( async def converse(
self, ctx: discord.ApplicationContext, opener: str, private, minimal self, ctx: discord.ApplicationContext, opener: str, private, minimal
): ):
if private: if private:
@ -781,6 +800,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
self.conversation_threads[user_id_normalized] = thread.id self.conversation_threads[user_id_normalized] = thread.id
@add_to_group("gpt")
@discord.slash_command( @discord.slash_command(
name="end-chat", name="end-chat",
description="End a conversation with GPT3", description="End a conversation with GPT3",
@ -801,6 +821,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
await ctx.defer() await ctx.defer()
await self.send_help_text(ctx) await self.send_help_text(ctx)
@add_to_group("admin")
@discord.slash_command( @discord.slash_command(
name="usage", name="usage",
description="Get usage statistics for GPT3Discord", description="Get usage statistics for GPT3Discord",
@ -811,6 +832,7 @@ class GPT3ComCon(commands.Cog, name="GPT3ComCon"):
await ctx.defer() await ctx.defer()
await self.send_usage_text(ctx) await self.send_usage_text(ctx)
@add_to_group("admin")
@discord.slash_command( @discord.slash_command(
name="settings", name="settings",
description="Get settings for GPT3Discord", description="Get settings for GPT3Discord",

@ -2,16 +2,15 @@ import re
import traceback import traceback
import discord import discord
from discord.ext import commands
from models.env_service_model import EnvService from models.env_service_model import EnvService
from models.user_model import RedoUser from models.user_model import RedoUser
from models.check_model import Check from pycord.multicog import add_to_group
ALLOWED_GUILDS = EnvService.get_allowed_guilds() ALLOWED_GUILDS = EnvService.get_allowed_guilds()
class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"): class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"):
_OPTIMIZER_PRETEXT = "Optimize the following text for DALL-E image generation to have the most detailed and realistic image possible. Prompt:" _OPTIMIZER_PRETEXT = "Optimize the following text for DALL-E image generation to have the most detailed and realistic image possible. Prompt:"
def __init__( def __init__(
@ -46,12 +45,12 @@ class ImgPromptOptimizer(commands.Cog, name="ImgPromptOptimizer"):
except: except:
traceback.print_exc() traceback.print_exc()
self.OPTIMIZER_PRETEXT = self._OPTIMIZER_PRETEXT self.OPTIMIZER_PRETEXT = self._OPTIMIZER_PRETEXT
@add_to_group("dalle")
@discord.slash_command( @discord.slash_command(
name="imgoptimize", name="imgoptimize",
description="Optimize a text prompt for DALL-E/MJ/SD image generation.", description="Optimize a text prompt for DALL-E/MJ/SD image generation.",
guild_ids=ALLOWED_GUILDS, guild_ids=ALLOWED_GUILDS,
checks=[Check.check_valid_roles()],
) )
@discord.option( @discord.option(
name="prompt", description="The text prompt to optimize.", required=True name="prompt", description="The text prompt to optimize.", required=True

@ -4,8 +4,8 @@ import traceback
from pathlib import Path from pathlib import Path
import discord import discord
from discord.ext import commands
from dotenv import load_dotenv from dotenv import load_dotenv
from pycord.multicog import apply_multicog
from cogs.draw_image_generation import DrawDallEService from cogs.draw_image_generation import DrawDallEService
from cogs.gpt_3_commands_and_converser import GPT3ComCon from cogs.gpt_3_commands_and_converser import GPT3ComCon
@ -34,7 +34,7 @@ Settings for the bot
activity = discord.Activity( activity = discord.Activity(
type=discord.ActivityType.watching, name="for /help /g, and more!" type=discord.ActivityType.watching, name="for /help /g, and more!"
) )
bot = commands.Bot(intents=discord.Intents.all(), command_prefix="!", activity=activity) bot = discord.Bot(intents=discord.Intents.all(), command_prefix="!", activity=activity)
usage_service = UsageService(Path(os.environ.get("DATA_DIR", os.getcwd()))) usage_service = UsageService(Path(os.environ.get("DATA_DIR", os.getcwd())))
model = Model(usage_service) model = Model(usage_service)
@ -104,6 +104,8 @@ async def main():
) )
) )
apply_multicog(bot)
await bot.start(os.getenv("DISCORD_TOKEN")) await bot.start(os.getenv("DISCORD_TOKEN"))

@ -1,22 +1,50 @@
import discord import discord
from models.env_service_model import EnvService from models.env_service_model import EnvService
from typing import Callable from typing import Callable
ADMIN_ROLES = EnvService.get_admin_roles()
ALLOWED_ROLES = EnvService.get_allowed_roles() DALLE_ROLES = EnvService.get_dalle_roles()
GPT_ROLES = EnvService.get_gpt_roles()
ALLOWED_GUILDS = EnvService.get_allowed_guilds()
class Check: class Check:
def check_valid_roles() -> Callable: def check_admin_roles() -> Callable:
async def inner(ctx: discord.ApplicationContext): async def inner(ctx: discord.ApplicationContext):
if not any(role.name in ALLOWED_ROLES for role in ctx.user.roles): if not any(role.name.lower() in ADMIN_ROLES for role in ctx.user.roles):
await ctx.defer(ephemeral=True) await ctx.defer(ephemeral=True)
await ctx.respond( await ctx.respond(
"You don't have permission to use this.", "You don't have admin permission to use this.",
ephemeral=True, ephemeral=True,
delete_after=10, delete_after=10,
) )
return False return False
return True return True
return inner
def check_dalle_roles() -> Callable:
async def inner(ctx: discord.ApplicationContext):
if not any(role.name.lower() in DALLE_ROLES for role in ctx.user.roles):
await ctx.defer(ephemeral=True)
await ctx.respond(
"You don't have dalle permission to use this.",
ephemeral=True,
delete_after=10,
)
return False
return True
return inner return inner
def check_gpt_roles() -> Callable:
async def inner(ctx: discord.ApplicationContext):
if not any(role.name.lower() in GPT_ROLES for role in ctx.user.roles):
await ctx.defer(ephemeral=True)
await ctx.respond(
"You don't have gpt permission to use this.",
ephemeral=True,
delete_after=10,
)
return False
return True
return inner

@ -33,23 +33,67 @@ class EnvService:
return allowed_guilds return allowed_guilds
@staticmethod @staticmethod
def get_allowed_roles(): def get_admin_roles():
# ALLOWED_ROLES is a comma separated list of string roles # ADMIN_ROLES is a comma separated list of string roles
# It can also just be one role # It can also just be one role
# Read these allowed roles and return as a list of strings # Read these allowed roles and return as a list of strings
try: try:
allowed_roles = os.getenv("ALLOWED_ROLES") admin_roles = os.getenv("ADMIN_ROLES")
except: except:
allowed_roles = None admin_roles = None
if allowed_roles is None: if admin_roles is None:
raise ValueError( raise ValueError(
"ALLOWED_ROLES is not defined properly in the environment file!" "ADMIN_ROLES is not defined properly in the environment file!"
"Please copy your server's role and put it into ALLOWED_ROLES in the .env file." "Please copy your server's role and put it into ADMIN_ROLES in the .env file."
'For example a line should look like: `ALLOWED_ROLES="Admin"`' 'For example a line should look like: `ADMIN_ROLES="Admin"`'
) )
allowed_roles = ( admin_roles = (
allowed_roles.split(",") if "," in allowed_roles else [allowed_roles] admin_roles.lower().split(",") if "," in admin_roles else [admin_roles.lower()]
) )
return allowed_roles return admin_roles
@staticmethod
def get_dalle_roles():
# DALLE_ROLES is a comma separated list of string roles
# It can also just be one role
# Read these allowed roles and return as a list of strings
try:
dalle_roles = os.getenv("DALLE_ROLES")
except:
dalle_roles = None
if dalle_roles is None:
raise ValueError(
"DALLE_ROLES is not defined properly in the environment file!"
"Please copy your server's role and put it into DALLE_ROLES in the .env file."
'For example a line should look like: `DALLE_ROLES="Dalle"`'
)
dalle_roles = (
dalle_roles.lower().split(",") if "," in dalle_roles else [dalle_roles.lower()]
)
return dalle_roles
@staticmethod
def get_gpt_roles():
# GPT_ROLES is a comma separated list of string roles
# It can also just be one role
# Read these allowed roles and return as a list of strings
try:
gpt_roles = os.getenv("GPT_ROLES")
except:
gpt_roles = None
if gpt_roles is None:
raise ValueError(
"GPT_ROLES is not defined properly in the environment file!"
"Please copy your server's role and put it into GPT_ROLES in the .env file."
'For example a line should look like: `GPT_ROLES="Gpt"`'
)
gpt_roles = (
gpt_roles.lower().strip().split(",") if "," in gpt_roles else [gpt_roles.lower()]
)
return gpt_roles

@ -25,6 +25,7 @@ dependencies = [
"python-dotenv", "python-dotenv",
"requests", "requests",
"transformers", "transformers",
"pycord-multicog"
] ]
dynamic = ["version"] dynamic = ["version"]
[project.scripts] [project.scripts]

@ -5,3 +5,4 @@ py-cord==2.3.2
python-dotenv==0.21.0 python-dotenv==0.21.0
requests==2.28.1 requests==2.28.1
transformers==4.25.1 transformers==4.25.1
pycord-multicog==1.0.2

@ -1,6 +1,10 @@
OPENAI_TOKEN="<openai_api_token>" OPENAI_TOKEN="<openai_api_token>"
DISCORD_TOKEN="<discord_bot_token>" DISCORD_TOKEN="<discord_bot_token>"
DEBUG_GUILD="755420092027633774" DEBUG_GUILD="<debug_guild_id>"
DEBUG_CHANNEL="907974109084942396" DEBUG_CHANNEL="<debug_channel_id>"
ALLOWED_GUILDS="971268468148166697,971268468148166697" # make sure not to include an id the bot doesn't have access to
ALLOWED_ROLES="Admin,gpt" ALLOWED_GUILDS="<guild_id>,<guild_id>"
# no spaces, case insensitive
ADMIN_ROLES="admin,owner"
DALLE_ROLES="admin,openai,dalle"
GPT_ROLES="admin,openai,gpt"
Loading…
Cancel
Save