From 7e5dace02b501d218952e3da4b7ebe999a1f8ae9 Mon Sep 17 00:00:00 2001 From: Rene Teigen Date: Sun, 15 Jan 2023 14:35:02 +0000 Subject: [PATCH 1/8] Docs, fixes and linting --- cogs/commands.py | 68 +++++++------ cogs/image_service_cog.py | 13 +-- cogs/moderations_service_cog.py | 34 +++++-- cogs/prompt_optimizer_cog.py | 12 +-- cogs/text_service_cog.py | 171 ++++++++++++++++++++------------ models/autocomplete_model.py | 23 +++-- models/check_model.py | 12 +-- models/openai_model.py | 31 +++--- services/image_service.py | 35 +++++-- services/moderations_service.py | 12 +++ services/text_service.py | 2 +- 11 files changed, 253 insertions(+), 160 deletions(-) diff --git a/cogs/commands.py b/cogs/commands.py index 17eaf08..aa0a292 100644 --- a/cogs/commands.py +++ b/cogs/commands.py @@ -9,6 +9,7 @@ ALLOWED_GUILDS = EnvService.get_allowed_guilds() class Commands(discord.Cog, name="Commands"): + '''Cog containing all slash and context commands as one-liners''' def __init__( self, bot, @@ -58,9 +59,9 @@ class Commands(discord.Cog, name="Commands"): checks=[Check.check_admin_roles()], ) - """ - System commands - """ + # + #System commands + # @add_to_group("system") @discord.slash_command( @@ -139,9 +140,9 @@ class Commands(discord.Cog, name="Commands"): async def delete_all_conversation_threads(self, ctx: discord.ApplicationContext): await self.converser_cog.delete_all_conversation_threads_command(ctx) - """ - (system) Moderation commands - """ + #""" + #Moderation commands + #""" @add_to_group("mod") @discord.slash_command( @@ -188,14 +189,15 @@ class Commands(discord.Cog, name="Commands"): ) @discord.option( name="type", - description="The type of moderation to configure ('warn' or 'delete')", + description="The type of moderation to configure", + choices=["warn", "delete"], required=True, ) @discord.option( name="hate", description="The threshold for hate speech", required=False, - input_type=float, + input_type=discord.SlashCommandOptionType.number, min_value=0, max_value=1, ) @@ -203,7 +205,7 @@ class Commands(discord.Cog, name="Commands"): name="hate_threatening", description="The threshold for hate/threatening speech", required=False, - input_type=float, + input_type=discord.SlashCommandOptionType.number, min_value=0, max_value=1, ) @@ -211,7 +213,7 @@ class Commands(discord.Cog, name="Commands"): name="self_harm", description="The threshold for self_harm speech", required=False, - input_type=float, + input_type=discord.SlashCommandOptionType.number, min_value=0, max_value=1, ) @@ -219,7 +221,7 @@ class Commands(discord.Cog, name="Commands"): name="sexual", description="The threshold for sexual speech", required=False, - input_type=float, + input_type=discord.SlashCommandOptionType.number, min_value=0, max_value=1, ) @@ -227,7 +229,7 @@ class Commands(discord.Cog, name="Commands"): name="sexual_minors", description="The threshold for sexual speech with minors in context", required=False, - input_type=float, + input_type=discord.SlashCommandOptionType.number, min_value=0, max_value=1, ) @@ -235,7 +237,7 @@ class Commands(discord.Cog, name="Commands"): name="violence", description="The threshold for violent speech", required=False, - input_type=float, + input_type=discord.SlashCommandOptionType.number, min_value=0, max_value=1, ) @@ -243,7 +245,7 @@ class Commands(discord.Cog, name="Commands"): name="violence_graphic", description="The threshold for violent and graphic speech", required=False, - input_type=float, + input_type=discord.SlashCommandOptionType.number, min_value=0, max_value=1, ) @@ -272,9 +274,9 @@ class Commands(discord.Cog, name="Commands"): violence_graphic, ) - """ - GPT commands - """ + # + #GPT commands + # @add_to_group("gpt") @discord.slash_command( @@ -289,7 +291,7 @@ class Commands(discord.Cog, name="Commands"): name="temperature", description="Higher values means the model will take more risks", required=False, - input_type=float, + input_type=discord.SlashCommandOptionType.number, min_value=0, max_value=1, ) @@ -297,7 +299,7 @@ class Commands(discord.Cog, name="Commands"): name="top_p", description="1 is greedy sampling, 0.1 means only considering the top 10% of probability distribution", required=False, - input_type=float, + input_type=discord.SlashCommandOptionType.number, min_value=0, max_value=1, ) @@ -305,7 +307,7 @@ class Commands(discord.Cog, name="Commands"): name="frequency_penalty", description="Decreasing the model's likelihood to repeat the same line verbatim", required=False, - input_type=float, + input_type=discord.SlashCommandOptionType.number, min_value=-2, max_value=2, ) @@ -313,7 +315,7 @@ class Commands(discord.Cog, name="Commands"): name="presence_penalty", description="Increasing the model's likelihood to talk about new topics", required=False, - input_type=float, + input_type=discord.SlashCommandOptionType.number, min_value=-2, max_value=2, ) @@ -433,9 +435,9 @@ class Commands(discord.Cog, name="Commands"): async def end(self, ctx: discord.ApplicationContext): await self.converser_cog.end_command(ctx) - """ - DALLE commands - """ + # + #DALLE commands + # @add_to_group("dalle") @discord.slash_command( @@ -460,9 +462,9 @@ class Commands(discord.Cog, name="Commands"): async def optimize(self, ctx: discord.ApplicationContext, prompt: str): await self.image_service_cog.optimize_command(ctx, prompt) - """ - Other commands - """ + # + #Other commands + # @discord.slash_command( name="private-test", @@ -489,9 +491,9 @@ class Commands(discord.Cog, name="Commands"): async def setup(self, ctx: discord.ApplicationContext): await self.converser_cog.setup_command(ctx) - """ - Text-based context menu commands from here - """ + # + #Text-based context menu commands from here + # @discord.message_command( name="Ask GPT", guild_ids=ALLOWED_GUILDS, checks=[Check.check_gpt_roles()] @@ -499,9 +501,9 @@ class Commands(discord.Cog, name="Commands"): async def ask_gpt_action(self, ctx, message: discord.Message): await self.converser_cog.ask_gpt_action(ctx, message) - """ - Image-based context menu commands from here - """ + # + #Image-based context menu commands from here + # @discord.message_command( name="Draw", guild_ids=ALLOWED_GUILDS, checks=[Check.check_dalle_roles()] diff --git a/cogs/image_service_cog.py b/cogs/image_service_cog.py index 38d79bb..e030f55 100644 --- a/cogs/image_service_cog.py +++ b/cogs/image_service_cog.py @@ -1,19 +1,13 @@ import asyncio import os -import tempfile import traceback -from io import BytesIO -import aiohttp import discord -from PIL import Image # We don't use the converser cog here because we want to be able to redo for the last images and text prompts at the same time from sqlitedict import SqliteDict -from cogs.text_service_cog import GPT3ComCon from services.environment_service import EnvService -from models.user_model import RedoUser from services.image_service import ImageService from services.text_service import TextService @@ -27,6 +21,7 @@ if USER_INPUT_API_KEYS: class DrawDallEService(discord.Cog, name="DrawDallEService"): + '''Cog containing a draw commands and file management for saved images''' def __init__( self, bot, usage_service, model, message_queue, deletion_queue, converser_cog ): @@ -43,6 +38,7 @@ class DrawDallEService(discord.Cog, name="DrawDallEService"): async def draw_command( self, ctx: discord.ApplicationContext, prompt: str, from_action=False ): + '''With an ApplicationContext and prompt, send a dalle image to the invoked channel. Ephemeral if from an action''' user_api_key = None if USER_INPUT_API_KEYS: user_api_key = await TextService.get_user_api_key( @@ -74,11 +70,12 @@ class DrawDallEService(discord.Cog, name="DrawDallEService"): await ctx.send_followup(e, ephemeral=from_action) async def draw_action(self, ctx, message): + '''decoupler to handle context actions for the draw command''' await self.draw_command(ctx, message.content, from_action=True) async def local_size_command(self, ctx: discord.ApplicationContext): + '''Get the folder size of the image folder''' await ctx.defer() - # Get the size of the dall-e images folder that we have on the current system. image_path = self.model.IMAGE_SAVE_PATH total_size = 0 @@ -92,9 +89,9 @@ class DrawDallEService(discord.Cog, name="DrawDallEService"): await ctx.respond(f"The size of the local images folder is {total_size} MB.") async def clear_local_command(self, ctx): + '''Delete all local images''' await ctx.defer() - # Delete all the local images in the images folder. image_path = self.model.IMAGE_SAVE_PATH for dirpath, dirnames, filenames in os.walk(image_path): for f in filenames: diff --git a/cogs/moderations_service_cog.py b/cogs/moderations_service_cog.py index a0cf2ea..a7c9ed4 100644 --- a/cogs/moderations_service_cog.py +++ b/cogs/moderations_service_cog.py @@ -16,6 +16,7 @@ except Exception as e: class ModerationsService(discord.Cog, name="ModerationsService"): + '''Cog containing moderation tools and features''' def __init__( self, bot, @@ -40,7 +41,7 @@ class ModerationsService(discord.Cog, name="ModerationsService"): @discord.Cog.listener() async def on_ready(self): - # Check moderation service for each guild + '''Check moderation service for each guild''' for guild in self.bot.guilds: self.get_or_set_warn_set(guild.id) self.get_or_set_delete_set(guild.id) @@ -48,16 +49,20 @@ class ModerationsService(discord.Cog, name="ModerationsService"): print("The moderation service is ready.") def check_guild_moderated(self, guild_id): + '''Given guild id, return bool of moderation status''' return guild_id in MOD_DB and MOD_DB[guild_id]["moderated"] def get_moderated_alert_channel(self, guild_id): + '''Given guild id, return alert channel''' return MOD_DB[guild_id]["alert_channel"] def set_moderated_alert_channel(self, guild_id, channel_id): + '''Given guild id and channel id, set channel to recieve alerts''' MOD_DB[guild_id] = {"moderated": True, "alert_channel": channel_id} MOD_DB.commit() def get_or_set_warn_set(self, guild_id): + '''Get warn_set set for the guild, if not set them from default values''' guild_id = str(guild_id) key = guild_id + "_warn_set" if key not in MOD_DB: @@ -68,6 +73,7 @@ class ModerationsService(discord.Cog, name="ModerationsService"): return dict(MOD_DB[key]) def get_or_set_delete_set(self, guild_id): + '''Get delete_set set for the guild, if not set them from default values''' guild_id = str(guild_id) key = guild_id + "_delete_set" if key not in MOD_DB: @@ -78,18 +84,21 @@ class ModerationsService(discord.Cog, name="ModerationsService"): return dict(MOD_DB[key]) def set_warn_set(self, guild_id, threshold_set): + '''Set threshold for warning a message''' guild_id = str(guild_id) key = guild_id + "_warn_set" MOD_DB[key] = zip(threshold_set.keys, threshold_set.thresholds) MOD_DB.commit() def set_delete_set(self, guild_id, threshold_set): + '''Set threshold for deleting a message''' guild_id = str(guild_id) key = guild_id + "_delete_set" MOD_DB[key] = zip(threshold_set.keys, threshold_set.thresholds) MOD_DB.commit() def set_guild_moderated(self, guild_id, status=True): + '''Set the guild to moderated or not''' if guild_id not in MOD_DB: MOD_DB[guild_id] = {"moderated": status, "alert_channel": 0} MOD_DB.commit() @@ -101,7 +110,7 @@ class ModerationsService(discord.Cog, name="ModerationsService"): MOD_DB.commit() async def check_and_launch_moderations(self, guild_id, alert_channel_override=None): - # Create the moderations service. + '''Create the moderation service''' print("Checking and attempting to launch moderations service...") if self.check_guild_moderated(guild_id): Moderation.moderation_queues[guild_id] = asyncio.Queue() @@ -135,6 +144,7 @@ class ModerationsService(discord.Cog, name="ModerationsService"): async def moderations_command( self, ctx: discord.ApplicationContext, status: str, alert_channel_id: str ): + '''command handler for toggling moderation and setting an alert channel''' await ctx.defer() status = status.lower().strip() @@ -162,6 +172,7 @@ class ModerationsService(discord.Cog, name="ModerationsService"): ) async def stop_moderations_service(self, guild_id): + '''Remove guild moderation status and stop the service''' self.set_guild_moderated(guild_id, False) Moderation.moderation_tasks[guild_id].cancel() Moderation.moderation_tasks[guild_id] = None @@ -169,6 +180,7 @@ class ModerationsService(discord.Cog, name="ModerationsService"): Moderation.moderations_launched.remove(guild_id) async def start_moderations_service(self, guild_id, alert_channel_id=None): + '''Set guild moderation and start the service''' self.set_guild_moderated(guild_id) moderations_channel = await self.check_and_launch_moderations( guild_id, @@ -179,8 +191,13 @@ class ModerationsService(discord.Cog, name="ModerationsService"): self.set_moderated_alert_channel(guild_id, moderations_channel.id) async def restart_moderations_service(self, ctx): + '''restarts the moderation of the guild it's run in''' + if not self.check_guild_moderated(ctx.guild_id): + await ctx.respond("Moderations are not enabled, can't restart") + return + await ctx.respond( - f"The moderations service is being restarted...", + "The moderations service is being restarted...", ephemeral=True, delete_after=30, ) @@ -197,11 +214,11 @@ class ModerationsService(discord.Cog, name="ModerationsService"): delete_after=30, ) - async def build_moderation_settings_embed(self, type, mod_set): + async def build_moderation_settings_embed(self, category, mod_set): embed = discord.Embed( title="Moderation Settings", - description="The moderation settings for this guild for the type: " + type, + description="The moderation settings for this guild for the type: " + category, color=discord.Color.yellow() if type == "warn" else discord.Color.red(), ) @@ -223,11 +240,7 @@ class ModerationsService(discord.Cog, name="ModerationsService"): violence, violence_graphic, ): - config_type = config_type.lower().strip() - if config_type not in ["warn", "delete"]: - await ctx.respond("Invalid config type, please use `warn` or `delete`") - return - + '''command handler for assigning threshold values for warn or delete''' all_args = [ hate, hate_threatening, @@ -290,6 +303,7 @@ class ModerationsService(discord.Cog, name="ModerationsService"): async def moderations_test_command( self, ctx: discord.ApplicationContext, prompt: str ): + '''command handler for checking moderation values of a given input''' await ctx.defer() response = await self.model.send_moderations_request(prompt) await ctx.respond(response["results"][0]["category_scores"]) diff --git a/cogs/prompt_optimizer_cog.py b/cogs/prompt_optimizer_cog.py index 687a039..e85203a 100644 --- a/cogs/prompt_optimizer_cog.py +++ b/cogs/prompt_optimizer_cog.py @@ -49,14 +49,14 @@ class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"): with image_pretext_path.open("r") as file: self.OPTIMIZER_PRETEXT = file.read() print(f"Loaded image optimizer pretext from {image_pretext_path}") - except: + except Exception: traceback.print_exc() self.OPTIMIZER_PRETEXT = self._OPTIMIZER_PRETEXT async def optimize_command(self, ctx: discord.ApplicationContext, prompt: str): user_api_key = None if USER_INPUT_API_KEYS: - user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx) + user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx, USER_KEY_DB) if not user_api_key: return @@ -73,7 +73,7 @@ class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"): final_prompt += "." # Get the token amount for the prompt - tokens = self.usage_service.count_tokens(final_prompt) + #tokens = self.usage_service.count_tokens(final_prompt) try: response = await self.model.send_request( @@ -100,7 +100,7 @@ class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"): try: if len(response_text.split()) > 75: response_text = " ".join(response_text.split()[-70:]) - except: + except Exception: pass response_message = await ctx.respond( @@ -251,10 +251,10 @@ class RedoButton(discord.ui.Button["OptimizeView"]): ].in_interaction(interaction_id): # Get the message and the prompt and call encapsulated_send ctx = self.converser_cog.redo_users[user_id].ctx - message = self.converser_cog.redo_users[user_id].message + #message = self.converser_cog.redo_users[user_id].message prompt = self.converser_cog.redo_users[user_id].prompt response_message = self.converser_cog.redo_users[user_id].response - msg = await interaction.response.send_message( + await interaction.response.send_message( "Redoing your original request...", ephemeral=True, delete_after=20 ) await TextService.encapsulated_send( diff --git a/cogs/text_service_cog.py b/cogs/text_service_cog.py index 4b44f07..f0219c6 100644 --- a/cogs/text_service_cog.py +++ b/cogs/text_service_cog.py @@ -1,4 +1,3 @@ -import asyncio import datetime import re import traceback @@ -11,7 +10,6 @@ import json import discord -from models.check_model import Check from services.environment_service import EnvService from services.message_queue_service import Message from services.moderations_service import Moderation @@ -28,9 +26,9 @@ if sys.platform == "win32": else: separator = "/" -""" -Get the user key service if it is enabled. -""" +# +#Get the user key service if it is enabled. +# USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys() USER_KEY_DB = None if USER_INPUT_API_KEYS: @@ -54,10 +52,10 @@ if USER_INPUT_API_KEYS: print("Retrieved/created the user key database") -""" -Obtain the Moderation table and the General table, these are two SQLite tables that contain -information about the server that are used for persistence and to auto-restart the moderation service. -""" +# +#Obtain the Moderation table and the General table, these are two SQLite tables that contain +#information about the server that are used for persistence and to auto-restart the moderation service. +# MOD_DB = None GENERAL_DB = None try: @@ -158,6 +156,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): self.conversation_thread_owners = {} async def load_file(self, file, ctx): + '''Take filepath, return content or respond if not found''' try: async with aiofiles.open(file, "r") as f: return await f.read() @@ -170,6 +169,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): @discord.Cog.listener() async def on_member_join(self, member): + '''When members join send welcome message if enabled''' if self.model.welcome_message_enabled: query = f"Please generate a welcome message for {member.name} who has just joined the server." @@ -178,7 +178,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): query, tokens=self.usage_service.count_tokens(query) ) welcome_message = str(welcome_message_response["choices"][0]["text"]) - except: + except Exception: welcome_message = None if not welcome_message: @@ -195,6 +195,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): @discord.Cog.listener() async def on_ready(self): + '''When ready to recieve data set debug channel and sync commands''' self.debug_channel = self.bot.get_guild(self.DEBUG_GUILD).get_channel( self.DEBUG_CHANNEL ) @@ -209,10 +210,10 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): check_guilds=[], delete_existing=True, ) - print(f"Commands synced") + print("Commands synced") - # TODO: add extra condition to check if multi is enabled for the thread, stated in conversation_threads - def check_conversing(self, user_id, channel_id, message_content, multi=None): + def check_conversing(self, channel_id, message_content): + '''given channel id and a message, return true if it's a conversation thread, false if not, or if the message starts with "~"''' cond1 = channel_id in self.conversation_threads # If the trimmed message starts with a Tilde, then we want to not contribute this to the conversation try: @@ -226,6 +227,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): async def end_conversation( self, ctx, opener_user_id=None, conversation_limit=False ): + '''end the thread of the user interacting with the bot, if the conversation has reached the limit close it for the owner''' normalized_user_id = opener_user_id if opener_user_id else ctx.author.id if ( conversation_limit @@ -234,7 +236,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): else: try: channel_id = self.conversation_thread_owners[normalized_user_id] - except: + except Exception: await ctx.delete(delay=5) await ctx.reply( "Only the conversation starter can end this.", delete_after=5 @@ -276,12 +278,10 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): thread = await self.bot.fetch_channel(channel_id) await thread.edit(locked=True) await thread.edit(name="Closed-GPT") - except: + except Exception: traceback.print_exc() - pass - except: + except Exception: traceback.print_exc() - pass else: if normalized_user_id in self.conversation_thread_owners: thread_id = self.conversation_thread_owners[normalized_user_id] @@ -292,11 +292,11 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): thread = await self.bot.fetch_channel(thread_id) await thread.edit(locked=True) await thread.edit(name="Closed-GPT") - except: + except Exception: traceback.print_exc() - pass async def send_settings_text(self, ctx): + '''compose and return the settings menu to the interacting user''' embed = discord.Embed( title="GPT3Bot Settings", description="The current settings of the model", @@ -325,9 +325,10 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): ), inline=True, ) - await ctx.respond(embed=embed) + await ctx.respond(embed=embed, ephemeral=True) async def process_settings(self, ctx, parameter, value): + '''Given a parameter and value set the corresponding parameter in storage to the value''' # Check if the parameter is a valid parameter if hasattr(self.model, parameter): @@ -354,12 +355,14 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): await ctx.respond("The parameter is not a valid parameter") def generate_debug_message(self, prompt, response): + '''create a debug message with a prompt and a response field''' debug_message = "----------------------------------------------------------------------------------\n" debug_message += "Prompt:\n```\n" + prompt + "\n```\n" debug_message += "Response:\n```\n" + json.dumps(response, indent=4) + "\n```\n" return debug_message async def paginate_and_send(self, response_text, ctx): + '''paginate a response to a text cutoff length and send it in chunks''' from_context = isinstance(ctx, discord.ApplicationContext) response_text = [ @@ -382,7 +385,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): await ctx.channel.send(chunk) async def paginate_embed(self, response_text, codex, prompt=None, instruction=None): - + '''Given a response text make embed pages and return a list of the pages. Codex makes it a codeblock in the embed''' if codex: # clean codex input response_text = response_text.replace("```", "") response_text = response_text.replace(f"***Prompt: {prompt}***\n", "") @@ -416,9 +419,11 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): return pages async def queue_debug_message(self, debug_message, debug_channel): + '''Put a message into the debug queue''' await self.message_queue.put(Message(debug_message, debug_channel)) async def queue_debug_chunks(self, debug_message, debug_channel): + '''Put a message as chunks into the debug queue''' debug_message_chunks = [ debug_message[i : i + self.TEXT_CUTOFF] for i in range(0, len(debug_message), self.TEXT_CUTOFF) @@ -445,6 +450,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): await self.message_queue.put(Message(chunk, debug_channel)) async def send_debug_message(self, debug_message, debug_channel): + '''process a debug message and put directly into queue or chunk it''' # Send the debug message try: if len(debug_message) > self.TEXT_CUTOFF: @@ -458,6 +464,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): ) async def check_conversation_limit(self, message): + '''Check if a conversation has reached the set limit and end it if it has''' # After each response, check if the user has reached the conversation limit in terms of messages or time. if message.channel.id in self.conversation_threads: # If the user has reached the max conversation length, end the conversation @@ -471,6 +478,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): await self.end_conversation(message, conversation_limit=True) async def summarize_conversation(self, message, prompt): + '''Takes a conversation history filled prompt and summarizes it to then start a new history with it as the base''' response = await self.model.send_summary_request(prompt) summarized_text = response["choices"][0]["text"] @@ -502,6 +510,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): # A listener for message edits to redo prompts if they are edited @discord.Cog.listener() async def on_message_edit(self, before, after): + '''When a message is edited run moderation if enabled, and process if it a prompt that should be redone''' if after.author.id == self.bot.user.id: return @@ -524,11 +533,10 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): @discord.Cog.listener() async def on_message(self, message): + '''On a new message check if it should be moderated then process it for conversation''' if message.author == self.bot.user: return - content = message.content.strip() - # Moderations service is done here. if ( hasattr(message, "guild") @@ -550,6 +558,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): original_message[message.author.id] = message.id def cleanse_response(self, response_text): + '''Cleans history tokens from response''' response_text = response_text.replace("GPTie:\n", "") response_text = response_text.replace("GPTie:", "") response_text = response_text.replace("GPTie: ", "") @@ -559,6 +568,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): def remove_awaiting( self, author_id, channel_id, from_ask_command, from_edit_command ): + '''Remove user from ask/edit command response wait, if not any of those then process the id to remove user from thread response wait''' if author_id in self.awaiting_responses: self.awaiting_responses.remove(author_id) if not from_ask_command and not from_edit_command: @@ -566,22 +576,23 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): self.awaiting_thread_responses.remove(channel_id) async def mention_to_username(self, ctx, message): + '''replaces discord mentions with their server nickname in text, if the user is not found keep the mention as is''' if not discord.utils.raw_mentions(message): return message - else: - for mention in discord.utils.raw_mentions(message): - try: - user = await discord.utils.get_or_fetch( - ctx.guild, "member", mention - ) - message = message.replace(f"<@{str(mention)}>", user.display_name) - except: - pass - return message + for mention in discord.utils.raw_mentions(message): + try: + user = await discord.utils.get_or_fetch( + ctx.guild, "member", mention + ) + message = message.replace(f"<@{str(mention)}>", user.display_name) + except Exception: + pass + return message # COMMANDS async def help_command(self, ctx): + '''Command handler. Generates a help message and sends it to the user''' await ctx.defer() embed = discord.Embed( title="GPT3Bot Help", description="The current commands", color=0xC730C7 @@ -631,11 +642,12 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): ) embed.add_field(name="/help", value="See this help text", inline=False) - await ctx.respond(embed=embed) + await ctx.respond(embed=embed, ephemeral=True) async def set_usage_command( self, ctx: discord.ApplicationContext, usage_amount: float ): + '''Command handler. Sets the usage file to the given value''' await ctx.defer() # Attempt to convert the input usage value into a float @@ -643,26 +655,27 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): usage = float(usage_amount) await self.usage_service.set_usage(usage) await ctx.respond(f"Set the usage to {usage}") - except: + except Exception: await ctx.respond("The usage value must be a valid float.") return async def delete_all_conversation_threads_command( self, ctx: discord.ApplicationContext ): + '''Command handler. Deletes all threads made by the bot in the current guild''' await ctx.defer() - for guild in self.bot.guilds: - for thread in guild.threads: - thread_name = thread.name.lower() - if "with gpt" in thread_name or "closed-gpt" in thread_name: - try: - await thread.delete() - except: - pass - await ctx.respond("All conversation threads have been deleted.") + for thread in ctx.guild.threads: + thread_name = thread.name.lower() + if "with gpt" in thread_name or "closed-gpt" in thread_name: + try: + await thread.delete() + except Exception: + pass + await ctx.respond("All conversation threads in this server have been deleted.") async def usage_command(self, ctx): + '''Command handler. Responds with the current usage of the bot''' await ctx.defer() embed = discord.Embed( title="GPT3Bot Usage", description="The current usage", color=0x00FF00 @@ -690,6 +703,17 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): presence_penalty: float, from_action=None, ): + """Command handler. Requests and returns a generation with no extras to the completion endpoint + + Args: + ctx (discord.ApplicationContext): Command interaction + prompt (str): A prompt to use for generation + temperature (float): Sets the temperature override + top_p (float): Sets the top p override + frequency_penalty (float): Sets the frequency penalty override + presence_penalty (float): Sets the presence penalty override + from_action (bool, optional): Enables ephemeral. Defaults to None. + """ user = ctx.user prompt = await self.mention_to_username(ctx, prompt.strip()) @@ -712,26 +736,36 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): presence_penalty_override=presence_penalty, from_ask_command=True, custom_api_key=user_api_key, - from_action=prompt, + from_action=from_action, ) async def edit_command( self, ctx: discord.ApplicationContext, instruction: str, - input: str, + text: str, temperature: float, top_p: float, codex: bool, ): + """Command handler. Requests and returns a generation with no extras to the edit endpoint + + Args: + ctx (discord.ApplicationContext): Command interaction + instruction (str): The modification instructions + text (str): The text that should be modified + temperature (float): Sets the temperature override + top_p (float): Sets the top p override + codex (bool): Enables the codex edit model + """ user = ctx.user - input = await self.mention_to_username(ctx, input.strip()) + text = await self.mention_to_username(ctx, text.strip()) instruction = await self.mention_to_username(ctx, instruction.strip()) user_api_key = None if USER_INPUT_API_KEYS: - user_api_key = await GPT3ComCon.get_user_api_key(user.id, ctx) + user_api_key = await TextService.get_user_api_key(user.id, ctx, USER_KEY_DB) if not user_api_key: return @@ -740,7 +774,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): await TextService.encapsulated_send( self, user.id, - prompt=input, + prompt=text, ctx=ctx, temp_override=temperature, top_p_override=top_p, @@ -751,6 +785,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): ) async def private_test_command(self, ctx: discord.ApplicationContext): + '''Command handler. Creates a private thread in the current channel''' await ctx.defer(ephemeral=True) await ctx.respond("Your private test thread") thread = await ctx.channel.create_thread( @@ -769,12 +804,22 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): private: bool, minimal: bool, ): + """Command handler. Starts a conversation with the bot + + Args: + ctx (discord.ApplicationContext): Command interaction + opener (str): The first prompt to send in the conversation + opener_file (str): A .txt or .json file which is appended before the opener + private (bool): If the thread should be private + minimal (bool): If a minimal starter should be used + """ + user = ctx.user # If we are in user input api keys mode, check if the user has entered their api key before letting them continue user_api_key = None if USER_INPUT_API_KEYS: - user_api_key = await GPT3ComCon.get_user_api_key(user.id, ctx) + user_api_key = await TextService.get_user_api_key(user.id, ctx, USER_KEY_DB) if not user_api_key: return @@ -784,7 +829,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): await ctx.defer() if user.id in self.conversation_thread_owners: - message = await ctx.respond( + await ctx.respond( "You've already created a thread, end it before creating a new one", delete_after=5, ) @@ -852,12 +897,12 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): opener_file.get("text", "error getting text") + opener ) - except: # Parse as just regular text + except Exception: # Parse as just regular text if not opener: opener = opener_file else: opener = opener_file + opener - except: + except Exception: opener_file = None # Just start a regular thread if the file fails to load # Append the starter text for gpt3 to the user's history so it gets concatenated with the prompt later @@ -876,7 +921,8 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): await thread.send( f"<@{str(user_id_normalized)}> You are now conversing with GPT3. *Say hi to start!*\n" - f"Overrides for this thread is **temp={overrides['temperature']}**, **top_p={overrides['top_p']}**, **frequency penalty={overrides['frequency_penalty']}**, **presence penalty={overrides['presence_penalty']}**\n" + f"Overrides for this thread is **temp={overrides['temperature']}**, **top_p={overrides['top_p']}**" + f", **frequency penalty={overrides['frequency_penalty']}**, **presence penalty={overrides['presence_penalty']}**\n" f"The model used is **{self.conversation_threads[thread.id].model}**\n" f"End the conversation by saying `end`.\n\n" f"If you want GPT3 to ignore your messages, start your messages with `~`\n\n" @@ -921,11 +967,12 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): self.awaiting_thread_responses.remove(thread.id) async def end_command(self, ctx: discord.ApplicationContext): + '''Command handler. Gets the user's thread and ends it''' await ctx.defer(ephemeral=True) user_id = ctx.user.id try: thread_id = self.conversation_thread_owners[user_id] - except: + except Exception: await ctx.respond( "You haven't started any conversations", ephemeral=True, delete_after=10 ) @@ -936,13 +983,13 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): except Exception as e: print(e) traceback.print_exc() - pass else: await ctx.respond( "You're not in any conversations", ephemeral=True, delete_after=10 ) async def setup_command(self, ctx: discord.ApplicationContext): + '''Command handler. Opens the setup modal''' if not USER_INPUT_API_KEYS: await ctx.respond( "This server doesn't support user input API keys.", @@ -956,6 +1003,7 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): async def settings_command( self, ctx: discord.ApplicationContext, parameter: str = None, value: str = None ): + '''Command handler. Returns current settings or sets new values''' await ctx.defer() if parameter is None and value is None: await self.send_settings_text(ctx) @@ -976,13 +1024,14 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): # Otherwise, process the settings change await self.process_settings(ctx, parameter, value) - """ - Text-based context menu commands from here - """ + # + #Text-based context menu commands from here + # async def ask_gpt_action( self, ctx, message: discord.Message - ): # message commands return the message + ): + '''Message command. Return the message''' await self.ask_command( ctx, message.content, None, None, None, None, from_action=message.content ) diff --git a/models/autocomplete_model.py b/models/autocomplete_model.py index 62371c1..f390263 100644 --- a/models/autocomplete_model.py +++ b/models/autocomplete_model.py @@ -12,7 +12,9 @@ model = Model(usage_service) class Settings_autocompleter: + '''autocompleter for the settings command''' async def get_settings(ctx: discord.AutocompleteContext): + '''get settings for the settings option''' SETTINGS = [ re.sub("^_", "", key) for key in model.__dict__.keys() @@ -27,6 +29,7 @@ class Settings_autocompleter: async def get_value( ctx: discord.AutocompleteContext, ): # Behaves a bit weird if you go back and edit the parameter without typing in a new command + '''gets valid values for the value option''' values = { "max_conversation_length": [str(num) for num in range(1, 500, 2)], "num_images": [str(num) for num in range(1, 4 + 1)], @@ -40,19 +43,21 @@ class Settings_autocompleter: "num_conversation_lookback": [str(num) for num in range(5, 15 + 1)], "summarize_threshold": [str(num) for num in range(800, 3500, 50)], } - if ctx.options["parameter"] in values.keys(): - return [ - value - for value in values[ctx.options["parameter"]] - if value.startswith(ctx.value.lower()) - ] - else: - await ctx.interaction.response.defer() # defer so the autocomplete in int values doesn't error but rather just says not found - return [] + for parameter in values: + if parameter == ctx.options["parameter"]: + return [ + value + for value in values[ctx.options["parameter"]] + if value.startswith(ctx.value.lower()) + ] + await ctx.interaction.response.defer() # defer so the autocomplete in int values doesn't error but rather just says not found + return [] class File_autocompleter: + '''Autocompleter for the opener command''' async def get_openers(ctx: discord.AutocompleteContext): + '''get all files in the openers folder''' try: return [ file diff --git a/models/check_model.py b/models/check_model.py index 43d30d6..caf0777 100644 --- a/models/check_model.py +++ b/models/check_model.py @@ -10,7 +10,7 @@ ALLOWED_GUILDS = EnvService.get_allowed_guilds() class Check: - def check_admin_roles() -> Callable: + def check_admin_roles(self) -> Callable: async def inner(ctx: discord.ApplicationContext): if ADMIN_ROLES == [None]: return True @@ -18,7 +18,7 @@ class Check: if not any(role.name.lower() in ADMIN_ROLES for role in ctx.user.roles): await ctx.defer(ephemeral=True) await ctx.respond( - f"You don't have permission to use this.", + f"You don't have permission, list of roles is {ADMIN_ROLES}", ephemeral=True, delete_after=10, ) @@ -27,14 +27,14 @@ class Check: return inner - def check_dalle_roles() -> Callable: + def check_dalle_roles(self) -> Callable: async def inner(ctx: discord.ApplicationContext): if DALLE_ROLES == [None]: return True if not any(role.name.lower() in DALLE_ROLES for role in ctx.user.roles): await ctx.defer(ephemeral=True) await ctx.respond( - "You don't have permission to use this.", + "You don't have permission, list of roles is {DALLE_ROLES}", ephemeral=True, delete_after=10, ) @@ -43,14 +43,14 @@ class Check: return inner - def check_gpt_roles() -> Callable: + def check_gpt_roles(self) -> Callable: async def inner(ctx: discord.ApplicationContext): if GPT_ROLES == [None]: return True if not any(role.name.lower() in GPT_ROLES for role in ctx.user.roles): await ctx.defer(ephemeral=True) await ctx.respond( - "You don't have permission to use this.", + "You don't have permission, list of roles is {GPT_ROLES}", ephemeral=True, delete_after=10, ) diff --git a/models/openai_model.py b/models/openai_model.py index 59b1d9d..f83825b 100644 --- a/models/openai_model.py +++ b/models/openai_model.py @@ -5,7 +5,7 @@ import os import tempfile import traceback import uuid -from typing import Tuple, List, Any +from typing import Any, Tuple import aiohttp import backoff @@ -14,7 +14,6 @@ import discord # An enum of two modes, TOP_P or TEMPERATURE import requests from PIL import Image -from aiohttp import RequestInfo from discord import File @@ -43,10 +42,9 @@ class Model: self._temp = 0.6 # Higher value means more random, lower value means more likely to be a coherent sentence self._top_p = 0.9 # 1 is equivalent to greedy sampling, 0.1 means that the model will only consider the top 10% of the probability distribution self._max_tokens = 4000 # The maximum number of tokens the model can generate - self._presence_penalty = ( - 0 # Penalize new tokens based on whether they appear in the text so far - ) - self._frequency_penalty = 0 # Penalize new tokens based on their existing frequency in the text so far. (Higher frequency = lower probability of being chosen.) + self._presence_penalty = 0 # Penalize new tokens based on whether they appear in the text so far + # Penalize new tokens based on their existing frequency in the text so far. (Higher frequency = lower probability of being chosen.) + self._frequency_penalty = 0 self._best_of = 1 # Number of responses to compare the loglikelihoods of self._prompt_min_length = 8 self._max_conversation_length = 100 @@ -66,7 +64,7 @@ class Model: try: self.IMAGE_SAVE_PATH = os.environ["IMAGE_SAVE_PATH"] self.custom_image_path = True - except: + except Exception: self.IMAGE_SAVE_PATH = "dalleimages" # Try to make this folder called images/ in the local directory if it doesnt exist if not os.path.exists(self.IMAGE_SAVE_PATH): @@ -355,11 +353,11 @@ class Model: try: tokens_used = int(response["usage"]["total_tokens"]) await self.usage_service.update_usage(tokens_used) - except: + except Exception as e: raise ValueError( "The API returned an invalid response: " + str(response["error"]["message"]) - ) + ) from e @backoff.on_exception( backoff.expo, @@ -402,7 +400,7 @@ class Model: async def send_edit_request( self, instruction, - input=None, + text=None, temp_override=None, top_p_override=None, codex=False, @@ -417,14 +415,14 @@ class Model: ) print( - f"The text about to be edited is [{input}] with instructions [{instruction}] codex [{codex}]" + f"The text about to be edited is [{text}] with instructions [{instruction}] codex [{codex}]" ) print(f"Overrides -> temp:{temp_override}, top_p:{top_p_override}") async with aiohttp.ClientSession(raise_for_status=True) as session: payload = { "model": Models.EDIT if codex is False else Models.CODE_EDIT, - "input": "" if input is None else input, + "input": "" if text is None else text, "instruction": instruction, "temperature": self.temp if temp_override is None else temp_override, "top_p": self.top_p if top_p_override is None else top_p_override, @@ -477,8 +475,10 @@ class Model: """ summary_request_text = [] summary_request_text.append( - "The following is a conversation instruction set and a conversation" - " between two people, a , and GPTie. Firstly, determine the 's name from the conversation history, then summarize the conversation. Do not summarize the instructions for GPTie, only the conversation. Summarize the conversation in a detailed fashion. If mentioned their name, be sure to mention it in the summary. Pay close attention to things the has told you, such as personal details." + "The following is a conversation instruction set and a conversation between two people, a , and GPTie." + " Firstly, determine the 's name from the conversation history, then summarize the conversation." + " Do not summarize the instructions for GPTie, only the conversation. Summarize the conversation in a detailed fashion. If mentioned" + " their name, be sure to mention it in the summary. Pay close attention to things the has told you, such as personal details." ) summary_request_text.append(prompt + "\nDetailed summary of conversation: \n") @@ -533,8 +533,7 @@ class Model: model=None, custom_api_key=None, ) -> ( - dict, - bool, + Tuple[dict, bool] ): # The response, and a boolean indicating whether or not the context limit was reached. # Validate that all the parameters are in a good state before we send the request diff --git a/services/image_service.py b/services/image_service.py index ff98845..9e66ce8 100644 --- a/services/image_service.py +++ b/services/image_service.py @@ -25,6 +25,18 @@ class ImageService: draw_from_optimizer=None, custom_api_key=None, ): + """service function that takes input and returns an image generation + + Args: + image_service_cog (Cog): The cog which contains draw related commands + user_id (int): A discord user id + prompt (string): Prompt for the model + ctx (ApplicationContext): A discord ApplicationContext, from an interaction + response_message (Message, optional): A discord message. Defaults to None. + vary (bool, optional): If the image is a variation of another one. Defaults to None. + draw_from_optimizer (bool, optional): If the prompt is passed from the optimizer command. Defaults to None. + custom_api_key (str, optional): User defined OpenAI API key. Defaults to None. + """ await asyncio.sleep(0) # send the prompt to the model from_context = isinstance(ctx, discord.ApplicationContext) @@ -42,16 +54,18 @@ class ImageService: message = ( f"The API returned an invalid response: **{e.status}: {e.message}**" ) - await ctx.channel.send(message) if not from_context else await ctx.respond( - message - ) + if not from_context: + await ctx.channel.send(message) + else: + await ctx.respond(message, ephemeral=True) return except ValueError as e: message = f"Error: {e}. Please try again with a different prompt." - await ctx.channel.send(message) if not from_context else await ctx.respond( - message, ephemeral=True - ) + if not from_context: + await ctx.channel.send(message) + else: + await ctx.respond(message, ephemeral=True) return @@ -70,7 +84,8 @@ class ImageService: embed.set_image(url=f"attachment://{file.filename}") if not response_message: # Original generation case - # Start an interaction with the user, we also want to send data embed=embed, file=file, view=SaveView(image_urls, image_service_cog, image_service_cog.converser_cog) + # Start an interaction with the user, we also want to send data embed=embed, file=file, + # view=SaveView(image_urls, image_service_cog, image_service_cog.converser_cog) result_message = ( await ctx.channel.send( embed=embed, @@ -104,7 +119,7 @@ class ImageService: prompt=prompt, message=ctx, ctx=ctx, - response=response_message, + response=result_message, instruction=None, codex=False, paginator=None, @@ -166,7 +181,7 @@ class ImageService: prompt=prompt, message=ctx, ctx=ctx, - response=response_message, + response=result_message, instruction=None, codex=False, paginator=None, @@ -357,7 +372,7 @@ class RedoButton(discord.ui.Button["SaveView"]): prompt = self.cog.redo_users[user_id].prompt response_message = self.cog.redo_users[user_id].response message = await interaction.response.send_message( - f"Regenerating the image for your original prompt, check the original message.", + "Regenerating the image for your original prompt, check the original message.", ephemeral=True, ) self.converser_cog.users_to_interactions[user_id].append(message.id) diff --git a/services/moderations_service.py b/services/moderations_service.py index 11dff74..5698a38 100644 --- a/services/moderations_service.py +++ b/services/moderations_service.py @@ -22,6 +22,17 @@ class ModerationResult: class ThresholdSet: def __init__(self, h_t, hv_t, sh_t, s_t, sm_t, v_t, vg_t): + """A set of thresholds for the OpenAI moderation endpoint + + Args: + h_t (float): hate + hv_t (float): hate/violence + sh_t (float): self-harm + s_t (float): sexual + sm_t (float): sexual/minors + v_t (float): violence + vg_t (float): violence/graphic + """ self.keys = [ "hate", "hate/threatening", @@ -44,6 +55,7 @@ class ThresholdSet: # The string representation is just the keys alongside the threshold values def __str__(self): + '''"key": value format''' # "key": value format return ", ".join([f"{k}: {v}" for k, v in zip(self.keys, self.thresholds)]) diff --git a/services/text_service.py b/services/text_service.py index 47d19cb..efca958 100644 --- a/services/text_service.py +++ b/services/text_service.py @@ -490,7 +490,7 @@ class TextService: ): content = message.content.strip() conversing = converser_cog.check_conversing( - message.author.id, message.channel.id, content + message.channel.id, content ) # If the user is conversing and they want to end it, end it immediately before we continue any further. From 207df295f01b4ff08803fae848dc486dcb766601 Mon Sep 17 00:00:00 2001 From: Rene Teigen Date: Sun, 15 Jan 2023 17:39:13 +0000 Subject: [PATCH 2/8] Even more --- cogs/commands.py | 4 +-- cogs/moderations_service_cog.py | 2 +- cogs/prompt_optimizer_cog.py | 2 ++ gpt3discord.py | 50 ++++++++++++++++--------------- models/autocomplete_model.py | 2 +- models/check_model.py | 13 ++++---- models/openai_model.py | 2 +- models/user_model.py | 20 ++++++------- services/deletion_service.py | 3 +- services/environment_service.py | 28 ++++++++--------- services/message_queue_service.py | 2 +- services/moderations_service.py | 15 ++++------ services/pinecone_service.py | 27 ++++++++--------- services/text_service.py | 29 +++++++++--------- services/usage_service.py | 5 ++-- 15 files changed, 101 insertions(+), 103 deletions(-) diff --git a/cogs/commands.py b/cogs/commands.py index aa0a292..e43d66d 100644 --- a/cogs/commands.py +++ b/cogs/commands.py @@ -374,13 +374,13 @@ class Commands(discord.Cog, name="Commands"): self, ctx: discord.ApplicationContext, instruction: str, - input: str, + text: str, temperature: float, top_p: float, codex: bool, ): await self.converser_cog.edit_command( - ctx, instruction, input, temperature, top_p, codex + ctx, instruction, text, temperature, top_p, codex ) @add_to_group("gpt") diff --git a/cogs/moderations_service_cog.py b/cogs/moderations_service_cog.py index a7c9ed4..1754ec9 100644 --- a/cogs/moderations_service_cog.py +++ b/cogs/moderations_service_cog.py @@ -193,7 +193,7 @@ class ModerationsService(discord.Cog, name="ModerationsService"): async def restart_moderations_service(self, ctx): '''restarts the moderation of the guild it's run in''' if not self.check_guild_moderated(ctx.guild_id): - await ctx.respond("Moderations are not enabled, can't restart") + await ctx.respond("Moderations are not enabled, can't restart", ephemeral=True, delete_after=30) return await ctx.respond( diff --git a/cogs/prompt_optimizer_cog.py b/cogs/prompt_optimizer_cog.py index e85203a..6adf6ae 100644 --- a/cogs/prompt_optimizer_cog.py +++ b/cogs/prompt_optimizer_cog.py @@ -18,6 +18,7 @@ if USER_INPUT_API_KEYS: class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"): + '''cog containing the optimizer command''' _OPTIMIZER_PRETEXT = "Optimize the following text for DALL-E image generation to have the most detailed and realistic image possible. Prompt:" def __init__( @@ -54,6 +55,7 @@ class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"): self.OPTIMIZER_PRETEXT = self._OPTIMIZER_PRETEXT async def optimize_command(self, ctx: discord.ApplicationContext, prompt: str): + '''Command handler. Given a string it generates an output that's fitting for image generation''' user_api_key = None if USER_INPUT_API_KEYS: user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx, USER_KEY_DB) diff --git a/gpt3discord.py b/gpt3discord.py index 7a4c7ef..0c0b33f 100644 --- a/gpt3discord.py +++ b/gpt3discord.py @@ -1,3 +1,4 @@ +import os import asyncio import sys import traceback @@ -6,36 +7,37 @@ from pathlib import Path import discord import pinecone from pycord.multicog import apply_multicog -import os - -from cogs.moderations_service_cog import ModerationsService -from services.pinecone_service import PineconeService - -if sys.platform == "win32": - separator = "\\" -else: - separator = "/" -from cogs.image_service_cog import DrawDallEService from cogs.text_service_cog import GPT3ComCon +from cogs.image_service_cog import DrawDallEService from cogs.prompt_optimizer_cog import ImgPromptOptimizer +from cogs.moderations_service_cog import ModerationsService from cogs.commands import Commands + +from services.pinecone_service import PineconeService from services.deletion_service import Deletion from services.message_queue_service import Message -from models.openai_model import Model from services.usage_service import UsageService from services.environment_service import EnvService +from models.openai_model import Model + __version__ = "7.0" -""" -The pinecone service is used to store and retrieve conversation embeddings. -""" +if sys.platform == "win32": + separator = "\\" +else: + separator = "/" + +# +# The pinecone service is used to store and retrieve conversation embeddings. +# + try: PINECONE_TOKEN = os.getenv("PINECONE_TOKEN") -except: +except Exception: PINECONE_TOKEN = None pinecone_service = None @@ -55,18 +57,18 @@ if PINECONE_TOKEN: print("Got the pinecone service") -""" -Message queueing for the debug service, defer debug messages to be sent later so we don't hit rate limits. -""" +# +# Message queueing for the debug service, defer debug messages to be sent later so we don't hit rate limits. +# message_queue = asyncio.Queue() deletion_queue = asyncio.Queue() asyncio.ensure_future(Message.process_message_queue(message_queue, 1.5, 5)) asyncio.ensure_future(Deletion.process_deletion_queue(deletion_queue, 1, 1)) -""" -Settings for the bot -""" +# +#Settings for the bot +# activity = discord.Activity( type=discord.ActivityType.watching, name="for /help /gpt, and more!" ) @@ -75,9 +77,9 @@ usage_service = UsageService(Path(os.environ.get("DATA_DIR", os.getcwd()))) model = Model(usage_service) -""" -An encapsulating wrapper for the discord.py client. This uses the old re-write without cogs, but it gets the job done! -""" +# +# An encapsulating wrapper for the discord.py client. This uses the old re-write without cogs, but it gets the job done! +# @bot.event # Using self gives u diff --git a/models/autocomplete_model.py b/models/autocomplete_model.py index f390263..51f06b9 100644 --- a/models/autocomplete_model.py +++ b/models/autocomplete_model.py @@ -66,5 +66,5 @@ class File_autocompleter: ][ :25 ] # returns the 25 first files from your current input - except: + except Exception: return ["No 'openers' folder"] diff --git a/models/check_model.py b/models/check_model.py index caf0777..c717ebf 100644 --- a/models/check_model.py +++ b/models/check_model.py @@ -10,7 +10,8 @@ ALLOWED_GUILDS = EnvService.get_allowed_guilds() class Check: - def check_admin_roles(self) -> Callable: + @staticmethod + def check_admin_roles() -> Callable: async def inner(ctx: discord.ApplicationContext): if ADMIN_ROLES == [None]: return True @@ -27,14 +28,15 @@ class Check: return inner - def check_dalle_roles(self) -> Callable: + @staticmethod + def check_dalle_roles() -> Callable: async def inner(ctx: discord.ApplicationContext): if DALLE_ROLES == [None]: return True if not any(role.name.lower() in DALLE_ROLES for role in ctx.user.roles): await ctx.defer(ephemeral=True) await ctx.respond( - "You don't have permission, list of roles is {DALLE_ROLES}", + f"You don't have permission, list of roles is {DALLE_ROLES}", ephemeral=True, delete_after=10, ) @@ -43,14 +45,15 @@ class Check: return inner - def check_gpt_roles(self) -> Callable: + @staticmethod + def check_gpt_roles() -> Callable: async def inner(ctx: discord.ApplicationContext): if GPT_ROLES == [None]: return True if not any(role.name.lower() in GPT_ROLES for role in ctx.user.roles): await ctx.defer(ephemeral=True) await ctx.respond( - "You don't have permission, list of roles is {GPT_ROLES}", + f"You don't have permission, list of roles is {GPT_ROLES}", ephemeral=True, delete_after=10, ) diff --git a/models/openai_model.py b/models/openai_model.py index f83825b..ff03035 100644 --- a/models/openai_model.py +++ b/models/openai_model.py @@ -669,7 +669,7 @@ class Model: images = await asyncio.get_running_loop().run_in_executor( None, lambda: [ - Image.open(requests.get(url, stream=True).raw) for url in image_urls + Image.open(requests.get(url, stream=True, timeout=10).raw) for url in image_urls ], ) diff --git a/models/user_model.py b/models/user_model.py index 761531f..f8221a5 100644 --- a/models/user_model.py +++ b/models/user_model.py @@ -34,8 +34,8 @@ class RedoUser: class User: - def __init__(self, id): - self.id = id + def __init__(self, user_id): + self.user_id = user_id self.history = [] self.count = 0 @@ -43,21 +43,21 @@ class User: # objects in a list, and we did `if 1203910293001 in user_list`, it would return True # if the user with that ID was in the list def __eq__(self, other): - return self.id == other.id + return self.user_id == other.id def __hash__(self): - return hash(self.id) + return hash(self.user_id) def __repr__(self): - return f"User(id={self.id}, history={self.history})" + return f"User(id={self.user_id}, history={self.history})" def __str__(self): return self.__repr__() class Thread: - def __init__(self, id): - self.id = id + def __init__(self, thread_id): + self.thread_id = thread_id self.history = [] self.count = 0 self.model = None @@ -90,13 +90,13 @@ class Thread: # objects in a list, and we did `if 1203910293001 in user_list`, it would return True # if the user with that ID was in the list def __eq__(self, other): - return self.id == other.id + return self.thread_id == other.id def __hash__(self): - return hash(self.id) + return hash(self.thread_id) def __repr__(self): - return f"Thread(id={self.id}, history={self.history})" + return f"Thread(id={self.thread_id}, history={self.history})" def __str__(self): return self.__repr__() diff --git a/services/deletion_service.py b/services/deletion_service.py index d1a5710..6ddcc44 100644 --- a/services/deletion_service.py +++ b/services/deletion_service.py @@ -39,6 +39,5 @@ class Deletion: # Sleep for a short time before processing the next message # This will prevent the bot from spamming messages too quickly await asyncio.sleep(PROCESS_WAIT_TIME) - except: + except Exception: traceback.print_exc() - pass diff --git a/services/environment_service.py b/services/environment_service.py index c6f10e7..701b587 100644 --- a/services/environment_service.py +++ b/services/environment_service.py @@ -31,9 +31,9 @@ class EnvService: @staticmethod def environment_path_with_fallback(env_name, relative_fallback=None): - dir = os.getenv(env_name) - if dir != None: - return Path(dir).resolve() + directory = os.getenv(env_name) + if directory is not None: + return Path(directory).resolve() if relative_fallback: app_relative = (app_root_path() / relative_fallback).resolve() @@ -70,7 +70,7 @@ class EnvService: # Read these allowed guilds and return as a list of ints try: allowed_guilds = os.getenv("ALLOWED_GUILDS") - except: + except Exception: allowed_guilds = None if allowed_guilds is None: @@ -93,7 +93,7 @@ class EnvService: # Read these allowed roles and return as a list of strings try: admin_roles = os.getenv("ADMIN_ROLES") - except: + except Exception: admin_roles = None if admin_roles is None: @@ -119,7 +119,7 @@ class EnvService: # Read these allowed roles and return as a list of strings try: dalle_roles = os.getenv("DALLE_ROLES") - except: + except Exception: dalle_roles = None if dalle_roles is None: @@ -145,7 +145,7 @@ class EnvService: # Read these allowed roles and return as a list of strings try: gpt_roles = os.getenv("GPT_ROLES") - except: + except Exception: gpt_roles = None if gpt_roles is None: @@ -171,7 +171,7 @@ class EnvService: # The string is DMd to the new server member as part of an embed. try: welcome_message = os.getenv("WELCOME_MESSAGE") - except: + except Exception: welcome_message = "Hi there! Welcome to our Discord server!" return welcome_message @@ -181,7 +181,7 @@ class EnvService: # The string can be blank but this is not advised. If a string cannot be found in the .env file, the below string is used. try: moderations_alert_channel = os.getenv("MODERATIONS_ALERT_CHANNEL") - except: + except Exception: moderations_alert_channel = None return moderations_alert_channel @@ -191,9 +191,8 @@ class EnvService: user_input_api_keys = os.getenv("USER_INPUT_API_KEYS") if user_input_api_keys.lower().strip() == "true": return True - else: - return False - except: + return False + except Exception: return False @staticmethod @@ -202,7 +201,6 @@ class EnvService: user_key_db_path = os.getenv("USER_KEY_DB_PATH") if user_key_db_path is None: return None - else: - return Path(user_key_db_path) - except: + return Path(user_key_db_path) + except Exception: return None diff --git a/services/message_queue_service.py b/services/message_queue_service.py index 3221661..c382528 100644 --- a/services/message_queue_service.py +++ b/services/message_queue_service.py @@ -22,7 +22,7 @@ class Message: # Send the message try: await message.channel.send(message.content) - except: + except Exception: pass # Sleep for a short time before processing the next message diff --git a/services/moderations_service.py b/services/moderations_service.py index 5698a38..83061da 100644 --- a/services/moderations_service.py +++ b/services/moderations_service.py @@ -155,10 +155,9 @@ class Moderation: if delete_result: return ModerationResult.DELETE - elif warn_result: + if warn_result: return ModerationResult.WARN - else: - return ModerationResult.NONE + return ModerationResult.NONE # This function will be called by the bot to process the message queue @staticmethod @@ -238,9 +237,8 @@ class Moderation: # Sleep for a short time before processing the next message # This will prevent the bot from spamming messages too quickly await asyncio.sleep(PROCESS_WAIT_TIME) - except: + except Exception: traceback.print_exc() - pass class ModerationAdminView(discord.ui.View): @@ -352,7 +350,7 @@ class KickUserButton(discord.ui.Button["ModerationAdminView"]): await self.message.author.kick( reason="You broke the server rules. Please rejoin and review the rules." ) - except: + except Exception: pass await interaction.response.send_message( "This user was attempted to be kicked", ephemeral=True, delete_after=10 @@ -392,7 +390,7 @@ class TimeoutUserButton(discord.ui.Button["ModerationAdminView"]): # Get the user id try: await self.message.delete() - except: + except Exception: pass try: @@ -400,9 +398,8 @@ class TimeoutUserButton(discord.ui.Button["ModerationAdminView"]): until=discord.utils.utcnow() + timedelta(hours=self.hours), reason="Breaking the server chat rules", ) - except Exception as e: + except Exception: traceback.print_exc() - pass await interaction.response.send_message( f"This user was timed out for {self.hours} hour(s)", diff --git a/services/pinecone_service.py b/services/pinecone_service.py index d57c029..2334f21 100644 --- a/services/pinecone_service.py +++ b/services/pinecone_service.py @@ -37,20 +37,19 @@ class PineconeService: }, ) return first_embedding - else: - embedding = await model.send_embedding_request( - text, custom_api_key=custom_api_key - ) - self.index.upsert( - [ - ( - text, - embedding, - {"conversation_id": conversation_id, "timestamp": timestamp}, - ) - ] - ) - return embedding + embedding = await model.send_embedding_request( + text, custom_api_key=custom_api_key + ) + self.index.upsert( + [ + ( + text, + embedding, + {"conversation_id": conversation_id, "timestamp": timestamp}, + ) + ] + ) + return embedding def get_n_similar(self, conversation_id: int, embedding, n=10): response = self.index.query( diff --git a/services/text_service.py b/services/text_service.py index efca958..cd116cc 100644 --- a/services/text_service.py +++ b/services/text_service.py @@ -470,9 +470,11 @@ class TextService: except Exception: message = "Something went wrong, please try again later. This may be due to upstream issues on the API, or rate limiting." - await ctx.send_followup(message) if from_context else await ctx.reply( - message - ) + if not from_context: + await ctx.send_followup(message) + else: + await ctx.reply(message) + converser_cog.remove_awaiting( ctx.author.id, ctx.channel.id, from_ask_command, from_edit_command ) @@ -480,7 +482,7 @@ class TextService: try: await converser_cog.end_conversation(ctx) - except: + except Exception: pass return @@ -675,9 +677,9 @@ class TextService: converser_cog.redo_users[after.author.id].prompt = edited_content -""" -Conversation interaction buttons -""" +# +#Conversation interaction buttons +# class ConversationView(discord.ui.View): @@ -753,7 +755,6 @@ class EndConvoButton(discord.ui.Button["ConversationView"]): await interaction.response.send_message( e, ephemeral=True, delete_after=30 ) - pass else: await interaction.response.send_message( "This is not your conversation to end!", ephemeral=True, delete_after=10 @@ -789,7 +790,7 @@ class RedoButton(discord.ui.Button["ConversationView"]): response_message = self.converser_cog.redo_users[user_id].response codex = self.converser_cog.redo_users[user_id].codex - msg = await interaction.response.send_message( + await interaction.response.send_message( "Retrying your original request...", ephemeral=True, delete_after=15 ) @@ -815,9 +816,9 @@ class RedoButton(discord.ui.Button["ConversationView"]): ) -""" -The setup modal when using user input API keys -""" +# +#The setup modal when using user input API keys +# class SetupModal(discord.ui.Modal): @@ -880,7 +881,7 @@ class SetupModal(discord.ui.Modal): ephemeral=True, delete_after=10, ) - except Exception as e: + except Exception: traceback.print_exc() await interaction.followup.send( "There was an error saving your API key.", @@ -888,5 +889,3 @@ class SetupModal(discord.ui.Modal): delete_after=30, ) return - - pass diff --git a/services/usage_service.py b/services/usage_service.py index 4bf7e4e..dd13927 100644 --- a/services/usage_service.py +++ b/services/usage_service.py @@ -1,4 +1,3 @@ -import os from pathlib import Path import aiofiles @@ -36,8 +35,8 @@ class UsageService: await f.close() return usage - def count_tokens(self, input): - res = self.tokenizer(input)["input_ids"] + def count_tokens(self, text): + res = self.tokenizer(text)["input_ids"] return len(res) async def update_usage_image(self, image_size): From ba3615f878553e1b9dc1e5121722d26fbdc2dee5 Mon Sep 17 00:00:00 2001 From: Rene Teigen Date: Sun, 15 Jan 2023 17:40:22 +0000 Subject: [PATCH 3/8] Bump version --- gpt3discord.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpt3discord.py b/gpt3discord.py index 0c0b33f..f2caf04 100644 --- a/gpt3discord.py +++ b/gpt3discord.py @@ -23,7 +23,7 @@ from services.environment_service import EnvService from models.openai_model import Model -__version__ = "7.0" +__version__ = "7.1" if sys.platform == "win32": From 3028244fe65421630164f44912d80b5b12b6a304 Mon Sep 17 00:00:00 2001 From: Rene Teigen Date: Sun, 15 Jan 2023 17:59:52 +0000 Subject: [PATCH 4/8] More cleaning and added choice to mod set --- cogs/commands.py | 14 ++------------ cogs/moderations_service_cog.py | 5 ----- services/text_service.py | 2 +- 3 files changed, 3 insertions(+), 18 deletions(-) diff --git a/cogs/commands.py b/cogs/commands.py index e43d66d..e1b6861 100644 --- a/cogs/commands.py +++ b/cogs/commands.py @@ -169,6 +169,7 @@ class Commands(discord.Cog, name="Commands"): name="status", description="Enable or disable the moderations service for the current guild (on/off)", required=True, + choices=["on", "off"] ) @discord.option( name="alert_channel_id", @@ -197,7 +198,6 @@ class Commands(discord.Cog, name="Commands"): name="hate", description="The threshold for hate speech", required=False, - input_type=discord.SlashCommandOptionType.number, min_value=0, max_value=1, ) @@ -205,7 +205,6 @@ class Commands(discord.Cog, name="Commands"): name="hate_threatening", description="The threshold for hate/threatening speech", required=False, - input_type=discord.SlashCommandOptionType.number, min_value=0, max_value=1, ) @@ -213,7 +212,6 @@ class Commands(discord.Cog, name="Commands"): name="self_harm", description="The threshold for self_harm speech", required=False, - input_type=discord.SlashCommandOptionType.number, min_value=0, max_value=1, ) @@ -221,7 +219,6 @@ class Commands(discord.Cog, name="Commands"): name="sexual", description="The threshold for sexual speech", required=False, - input_type=discord.SlashCommandOptionType.number, min_value=0, max_value=1, ) @@ -229,7 +226,6 @@ class Commands(discord.Cog, name="Commands"): name="sexual_minors", description="The threshold for sexual speech with minors in context", required=False, - input_type=discord.SlashCommandOptionType.number, min_value=0, max_value=1, ) @@ -237,7 +233,6 @@ class Commands(discord.Cog, name="Commands"): name="violence", description="The threshold for violent speech", required=False, - input_type=discord.SlashCommandOptionType.number, min_value=0, max_value=1, ) @@ -245,7 +240,6 @@ class Commands(discord.Cog, name="Commands"): name="violence_graphic", description="The threshold for violent and graphic speech", required=False, - input_type=discord.SlashCommandOptionType.number, min_value=0, max_value=1, ) @@ -291,7 +285,6 @@ class Commands(discord.Cog, name="Commands"): name="temperature", description="Higher values means the model will take more risks", required=False, - input_type=discord.SlashCommandOptionType.number, min_value=0, max_value=1, ) @@ -299,7 +292,6 @@ class Commands(discord.Cog, name="Commands"): name="top_p", description="1 is greedy sampling, 0.1 means only considering the top 10% of probability distribution", required=False, - input_type=discord.SlashCommandOptionType.number, min_value=0, max_value=1, ) @@ -307,7 +299,6 @@ class Commands(discord.Cog, name="Commands"): name="frequency_penalty", description="Decreasing the model's likelihood to repeat the same line verbatim", required=False, - input_type=discord.SlashCommandOptionType.number, min_value=-2, max_value=2, ) @@ -315,7 +306,6 @@ class Commands(discord.Cog, name="Commands"): name="presence_penalty", description="Increasing the model's likelihood to talk about new topics", required=False, - input_type=discord.SlashCommandOptionType.number, min_value=-2, max_value=2, ) @@ -345,7 +335,7 @@ class Commands(discord.Cog, name="Commands"): required=True, ) @discord.option( - name="input", + name="text", description="The text you want to edit, can be empty", required=False, default="", diff --git a/cogs/moderations_service_cog.py b/cogs/moderations_service_cog.py index 1754ec9..9fee0e9 100644 --- a/cogs/moderations_service_cog.py +++ b/cogs/moderations_service_cog.py @@ -147,11 +147,6 @@ class ModerationsService(discord.Cog, name="ModerationsService"): '''command handler for toggling moderation and setting an alert channel''' await ctx.defer() - status = status.lower().strip() - if status not in ["on", "off"]: - await ctx.respond("Invalid status, please use on or off") - return - if status == "on": # Check if the current guild is already in the database and if so, if the moderations is on if self.check_guild_moderated(ctx.guild_id): diff --git a/services/text_service.py b/services/text_service.py index cd116cc..d3fef23 100644 --- a/services/text_service.py +++ b/services/text_service.py @@ -232,7 +232,7 @@ class TextService: # Send the request to the model if from_edit_command: response = await converser_cog.model.send_edit_request( - input=new_prompt, + text=new_prompt, instruction=instruction, temp_override=temp_override, top_p_override=top_p_override, From 27d5f6e982afbf215b64399ce28a04f60d848c64 Mon Sep 17 00:00:00 2001 From: Rene Teigen Date: Sun, 15 Jan 2023 18:17:44 +0000 Subject: [PATCH 5/8] Add basic encapsulate send docstring --- services/text_service.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/services/text_service.py b/services/text_service.py index d3fef23..05c6029 100644 --- a/services/text_service.py +++ b/services/text_service.py @@ -26,16 +26,38 @@ class TextService: top_p_override=None, frequency_penalty_override=None, presence_penalty_override=None, - from_ask_command=False, instruction=None, + from_ask_command=False, from_edit_command=False, codex=False, model=None, custom_api_key=None, edited_request=False, redo_request=False, - from_action=None, + from_action=False, ): + """General service function for sending and recieving gpt generations + + Args: + converser_cog (Cog): The conversation cog with our gpt commands + id (user or thread id): A user or thread id for keeping track of conversations + prompt (str): The prompt to use for generation + ctx (ApplicationContext): The interaction which called this + response_message (discord.Message, optional): For when we're doing redos. Defaults to None. + temp_override (float, optional): Sets the temperature for the generation. Defaults to None. + top_p_override (float, optional): Sets the top p for the generation. Defaults to None. + frequency_penalty_override (float, optional): Sets the frequency penalty for the generation. Defaults to None. + presence_penalty_override (float, optional): Sets the presence penalty for the generation. Defaults to None. + instruction (str, optional): Instruction for use with the edit endpoint. Defaults to None. + from_ask_command (bool, optional): Called from the ask command. Defaults to False. + from_edit_command (bool, optional): Called from the edit command. Defaults to False. + codex (bool, optional): Pass along that we want to use a codex model. Defaults to False. + model (str, optional): Which model to genereate output with. Defaults to None. + custom_api_key (str, optional): per-user api key. Defaults to None. + edited_request (bool, optional): If we're doing an edited message. Defaults to False. + redo_request (bool, optional): If we're redoing a previous prompt. Defaults to False. + from_action (bool, optional): If the function is being called from a message action. Defaults to False. + """ new_prompt = ( prompt + "\nGPTie: " if not from_ask_command and not from_edit_command From 3acbca9d7fb2b125cc61344300e3fa2859363e7c Mon Sep 17 00:00:00 2001 From: Kaveen Kumarasinghe Date: Sun, 15 Jan 2023 17:50:59 -0500 Subject: [PATCH 6/8] remove author check --- services/text_service.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/services/text_service.py b/services/text_service.py index 05c6029..ef0a8fb 100644 --- a/services/text_service.py +++ b/services/text_service.py @@ -367,7 +367,8 @@ class TextService: custom_api_key=custom_api_key, ) paginator = pages.Paginator( - pages=embed_pages, timeout=None, custom_view=view + pages=embed_pages, timeout=None, custom_view=view, + author_check=True, ) response_message = await paginator.respond(ctx.interaction) else: From ca0fa419f9b38569f10f52067f5e3f96b49cf565 Mon Sep 17 00:00:00 2001 From: Kaveen Kumarasinghe Date: Sun, 15 Jan 2023 18:05:20 -0500 Subject: [PATCH 7/8] change some defaults --- models/openai_model.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/models/openai_model.py b/models/openai_model.py index ff03035..31c4628 100644 --- a/models/openai_model.py +++ b/models/openai_model.py @@ -39,8 +39,8 @@ class ImageSize: class Model: def __init__(self, usage_service): self._mode = Mode.TEMPERATURE - self._temp = 0.6 # Higher value means more random, lower value means more likely to be a coherent sentence - self._top_p = 0.9 # 1 is equivalent to greedy sampling, 0.1 means that the model will only consider the top 10% of the probability distribution + self._temp = 0.8 # Higher value means more random, lower value means more likely to be a coherent sentence + self._top_p = 0.95 # 1 is equivalent to greedy sampling, 0.1 means that the model will only consider the top 10% of the probability distribution self._max_tokens = 4000 # The maximum number of tokens the model can generate self._presence_penalty = 0 # Penalize new tokens based on whether they appear in the text so far # Penalize new tokens based on their existing frequency in the text so far. (Higher frequency = lower probability of being chosen.) @@ -55,11 +55,11 @@ class Model: self._image_size = ImageSize.MEDIUM self._num_images = 2 self._summarize_conversations = True - self._summarize_threshold = 2500 + self._summarize_threshold = 3000 self.model_max_tokens = 4024 self._welcome_message_enabled = True - self._num_static_conversation_items = 8 - self._num_conversation_lookback = 6 + self._num_static_conversation_items = 10 + self._num_conversation_lookback = 5 try: self.IMAGE_SAVE_PATH = os.environ["IMAGE_SAVE_PATH"] From e947e030fb652e527b856cc0bb73008c72adddef Mon Sep 17 00:00:00 2001 From: Kaveen Kumarasinghe Date: Sun, 15 Jan 2023 18:32:33 -0500 Subject: [PATCH 8/8] more autocomplete --- cogs/commands.py | 3 ++- cogs/moderations_service_cog.py | 8 ++++++++ models/autocomplete_model.py | 21 +++++++++++++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/cogs/commands.py b/cogs/commands.py index e1b6861..fd6ddde 100644 --- a/cogs/commands.py +++ b/cogs/commands.py @@ -175,6 +175,7 @@ class Commands(discord.Cog, name="Commands"): name="alert_channel_id", description="The channel ID to send moderation alerts to", required=False, + autocomplete=Settings_autocompleter.get_value_alert_id_channel ) @discord.guild_only() async def moderations( @@ -191,8 +192,8 @@ class Commands(discord.Cog, name="Commands"): @discord.option( name="type", description="The type of moderation to configure", - choices=["warn", "delete"], required=True, + autocomplete=Settings_autocompleter.get_value_moderations, ) @discord.option( name="hate", diff --git a/cogs/moderations_service_cog.py b/cogs/moderations_service_cog.py index 9fee0e9..ad2750e 100644 --- a/cogs/moderations_service_cog.py +++ b/cogs/moderations_service_cog.py @@ -147,6 +147,14 @@ class ModerationsService(discord.Cog, name="ModerationsService"): '''command handler for toggling moderation and setting an alert channel''' await ctx.defer() + try: + if alert_channel_id: + int(alert_channel_id) + except ValueError: + # the alert_channel_id was passed in as a channel NAME instead of an ID, fetch the ID. + alert_channel = discord.utils.get(ctx.guild.channels, name=alert_channel_id) + alert_channel_id = alert_channel.id + if status == "on": # Check if the current guild is already in the database and if so, if the moderations is on if self.check_guild_moderated(ctx.guild_id): diff --git a/models/autocomplete_model.py b/models/autocomplete_model.py index 51f06b9..7898f9c 100644 --- a/models/autocomplete_model.py +++ b/models/autocomplete_model.py @@ -42,6 +42,7 @@ class Settings_autocompleter: "num_static_conversation_items": [str(num) for num in range(5, 20 + 1)], "num_conversation_lookback": [str(num) for num in range(5, 15 + 1)], "summarize_threshold": [str(num) for num in range(800, 3500, 50)], + "type": ["warn", "delete"], } for parameter in values: if parameter == ctx.options["parameter"]: @@ -53,6 +54,26 @@ class Settings_autocompleter: await ctx.interaction.response.defer() # defer so the autocomplete in int values doesn't error but rather just says not found return [] + async def get_value_moderations( + ctx: discord.AutocompleteContext, + ): # Behaves a bit weird if you go back and edit the parameter without typing in a new command + '''gets valid values for the type option''' + print(f"The value is {ctx.value}") + return [ + value + for value in ["warn", "delete"] + if value.startswith(ctx.value.lower()) + ] + + async def get_value_alert_id_channel(self, ctx: discord.AutocompleteContext): + '''gets valid values for the channel option''' + return [ + channel.name + for channel in ctx.interaction.guild.channels + if channel.name.startswith(ctx.value.lower()) + ] + + class File_autocompleter: '''Autocompleter for the opener command'''