diff --git a/cogs/image_service_cog.py b/cogs/image_service_cog.py index 6e432a1..8ab4b1e 100644 --- a/cogs/image_service_cog.py +++ b/cogs/image_service_cog.py @@ -15,9 +15,7 @@ users_to_interactions = {} ALLOWED_GUILDS = EnvService.get_allowed_guilds() USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys() -USER_KEY_DB = None -if USER_INPUT_API_KEYS: - USER_KEY_DB = SqliteDict("user_key_db.sqlite") +USER_KEY_DB = EnvService.get_api_db() class DrawDallEService(discord.Cog, name="DrawDallEService"): diff --git a/cogs/prompt_optimizer_cog.py b/cogs/prompt_optimizer_cog.py index 6392a48..e117ac9 100644 --- a/cogs/prompt_optimizer_cog.py +++ b/cogs/prompt_optimizer_cog.py @@ -12,9 +12,7 @@ from services.text_service import TextService ALLOWED_GUILDS = EnvService.get_allowed_guilds() USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys() -USER_KEY_DB = None -if USER_INPUT_API_KEYS: - USER_KEY_DB = SqliteDict("user_key_db.sqlite") +USER_KEY_DB = EnvService.get_api_db() class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"): diff --git a/cogs/text_service_cog.py b/cogs/text_service_cog.py index 4a77585..78a0cc5 100644 --- a/cogs/text_service_cog.py +++ b/cogs/text_service_cog.py @@ -30,27 +30,7 @@ else: # Get the user key service if it is enabled. # USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys() -USER_KEY_DB = None -if USER_INPUT_API_KEYS: - print( - "This server was configured to enforce user input API keys. Doing the required database setup now" - ) - # Get USER_KEY_DB from the environment variable - USER_KEY_DB_PATH = EnvService.get_user_key_db_path() - # Check if USER_KEY_DB_PATH is valid - if not USER_KEY_DB_PATH: - print( - "No user key database path was provided. Defaulting to user_key_db.sqlite" - ) - USER_KEY_DB_PATH = "user_key_db.sqlite" - else: - # append "user_key_db.sqlite" to USER_KEY_DB_PATH if it doesn't already end with .sqlite - if not USER_KEY_DB_PATH.match("*.sqlite"): - # append "user_key_db.sqlite" to USER_KEY_DB_PATH - USER_KEY_DB_PATH = USER_KEY_DB_PATH / "user_key_db.sqlite" - USER_KEY_DB = SqliteDict(USER_KEY_DB_PATH) - print("Retrieved/created the user key database") - +USER_KEY_DB = EnvService.get_api_db() # # Obtain the Moderation table and the General table, these are two SQLite tables that contain @@ -994,8 +974,9 @@ class GPT3ComCon(discord.Cog, name="GPT3ComCon"): ephemeral=True, delete_after=30, ) + return - modal = SetupModal(title="API Key Setup") + modal = SetupModal(user_key_db=USER_KEY_DB) await ctx.send_modal(modal) async def settings_command( diff --git a/gpt3discord.py b/gpt3discord.py index cc6f489..afc691c 100644 --- a/gpt3discord.py +++ b/gpt3discord.py @@ -204,8 +204,9 @@ def init(): os.remove(PID_FILE) finally: # Kill all threads - print("Killing all subprocesses") - process.terminate() + if process: + print("Killing all subprocesses") + process.terminate() print("Killed all subprocesses") sys.exit(0) diff --git a/services/environment_service.py b/services/environment_service.py index 280fccd..648a88e 100644 --- a/services/environment_service.py +++ b/services/environment_service.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Union from dotenv import load_dotenv +from sqlitedict import SqliteDict def app_root_path(): @@ -241,6 +242,32 @@ class EnvService: except Exception: return None + + @staticmethod + def get_api_db(): + user_input_api_keys = EnvService.get_user_input_api_keys() + user_key_db = None + if user_input_api_keys: + print("This server was configured to enforce user input API keys. Doing the required database setup now") + # Get user_key_db from environment variable + user_key_db_path = EnvService.get_user_key_db_path() + # Check if user_key_db_path is valid + if not user_key_db_path: + print("No user key database path was provided. Defaulting to user_key_db.sqlite") + user_key_db_path = "user_key_db.sqlite" + else: + # append "user_key_db.sqlite" to USER_KEY_DB_PATH if it doesn't already end with .sqlite + if not user_key_db_path.match("*.sqlite"): + # append "user_key_db.sqlite" to USER_KEY_DB_PATH + user_key_db_path = user_key_db_path / "user_key_db.sqlite" + user_key_db = SqliteDict(user_key_db_path) + print("Retrieved/created the user key database") + return user_key_db + return user_key_db + + + + @staticmethod def get_deepl_token(): try: diff --git a/services/text_service.py b/services/text_service.py index 43829d5..d774ddc 100644 --- a/services/text_service.py +++ b/services/text_service.py @@ -631,7 +631,7 @@ class TextService: async def get_user_api_key(user_id, ctx, USER_KEY_DB): user_api_key = None if user_id not in USER_KEY_DB else USER_KEY_DB[user_id] if user_api_key is None or user_api_key == "": - modal = SetupModal(title="API Key Setup", user_key_db=USER_KEY_DB) + modal = SetupModal(user_key_db=USER_KEY_DB) if isinstance(ctx, discord.ApplicationContext): await ctx.send_modal(modal) await ctx.send_followup( @@ -845,10 +845,10 @@ class RedoButton(discord.ui.Button["ConversationView"]): class SetupModal(discord.ui.Modal): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self, user_key_db) -> None: + super().__init__(title="API Key Setup") # Get the argument named "user_key_db" and save it as USER_KEY_DB - self.USER_KEY_DB = kwargs.pop("user_key_db") + self.USER_KEY_DB = user_key_db self.add_item( discord.ui.InputText(