From c6a6245e00bbfca18bdd0f69f1c1330612f2ec8a Mon Sep 17 00:00:00 2001 From: Rene Teigen Date: Tue, 31 Jan 2023 20:32:23 +0000 Subject: [PATCH] Added server index backup and loading --- .gitignore | 3 +- cogs/commands.py | 23 +++++ cogs/index_service_cog.py | 29 +++++- models/autocomplete_model.py | 13 +++ models/index_model.py | 185 +++++++++++++++++++---------------- 5 files changed, 161 insertions(+), 92 deletions(-) diff --git a/.gitignore b/.gitignore index a95b25f..3411cda 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,5 @@ __pycache__ *.sqlite bot.pid usage.txt -/dalleimages \ No newline at end of file +/dalleimages +/indexes \ No newline at end of file diff --git a/cogs/commands.py b/cogs/commands.py index 1075560..93af420 100644 --- a/cogs/commands.py +++ b/cogs/commands.py @@ -501,6 +501,18 @@ class Commands(discord.Cog, name="Commands"): # Index commands # + @add_to_group("index") + @discord.slash_command( + name="load_file", + description="Set an index to query from", + guild_ids=ALLOWED_GUILDS + ) + @discord.guild_only() + @discord.option(name="index", description="Which file to load the index from", required=True, autocomplete=File_autocompleter.get_indexes) + async def load_index(self, ctx:discord.ApplicationContext, index: str): + await self.index_cog.load_index_command(ctx, index) + + @add_to_group("index") @discord.slash_command( name="set_file", @@ -524,6 +536,17 @@ class Commands(discord.Cog, name="Commands"): await self.index_cog.set_discord_command(ctx, channel) + @add_to_group("index") + @discord.slash_command( + name="discord_backup", + description="Save an index made from the whole server", + guild_ids=ALLOWED_GUILDS + ) + @discord.guild_only() + async def discord_backup(self, ctx:discord.ApplicationContext): + await self.index_cog.discord_backup_command(ctx) + + @add_to_group("index") @discord.slash_command( name="query", diff --git a/cogs/index_service_cog.py b/cogs/index_service_cog.py index 41515c5..f57cb6a 100644 --- a/cogs/index_service_cog.py +++ b/cogs/index_service_cog.py @@ -15,7 +15,7 @@ class IndexService(discord.Cog, name="IndexService"): ): super().__init__() self.bot = bot - self.index_handler = Index_handler() + self.index_handler = Index_handler(bot) async def set_index_command(self, ctx, file: discord.Attachment): """Command handler to set a file as your personal index""" @@ -40,11 +40,32 @@ class IndexService(discord.Cog, name="IndexService"): return await ctx.defer(ephemeral=True) - if not channel: - await self.index_handler.set_discord_index(ctx, channel, user_api_key=user_api_key, no_channel=True) - return await self.index_handler.set_discord_index(ctx, channel, user_api_key=user_api_key) + async def discord_backup_command(self, ctx): + """Command handler to backup the entire server""" + + user_api_key = None + if USER_INPUT_API_KEYS: + user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx, USER_KEY_DB) + if not user_api_key: + return + + await ctx.defer(ephemeral=True) + await self.index_handler.backup_discord(ctx, user_api_key=user_api_key) + + + async def load_index_command(self, ctx, index): + """Command handler to backup the entire server""" + user_api_key = None + if USER_INPUT_API_KEYS: + user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx, USER_KEY_DB) + if not user_api_key: + return + + await ctx.defer(ephemeral=True) + await self.index_handler.load_index(ctx, index, user_api_key) + async def query_command(self, ctx, query, response_mode): """Command handler to query your index""" diff --git a/models/autocomplete_model.py b/models/autocomplete_model.py index 8525cb8..682a13b 100644 --- a/models/autocomplete_model.py +++ b/models/autocomplete_model.py @@ -141,3 +141,16 @@ class File_autocompleter: ] # returns the 25 first files from your current input except Exception: return ["No 'openers' folder"] + + async def get_indexes(ctx: discord.AutocompleteContext): + """get all files in the openers folder""" + try: + return [ + file + for file in os.listdir(EnvService.find_shared_file("indexes")) + if file.startswith(ctx.value.lower()) + ][ + :25 + ] # returns the 25 first files from your current input + except Exception: + return ["No 'indexes' folder"] diff --git a/models/index_model.py b/models/index_model.py index bcd0fec..86afe44 100644 --- a/models/index_model.py +++ b/models/index_model.py @@ -1,30 +1,44 @@ import os import traceback import asyncio -import tempfile import discord +import aiofiles from functools import partial from typing import List, Optional +from pathlib import Path +from datetime import date, datetime - -from gpt_index.readers.base import BaseReader from gpt_index.readers.schema.base import Document from gpt_index.response.schema import Response +from gpt_index import GPTSimpleVectorIndex, SimpleDirectoryReader, QuestionAnswerPrompt, GPTPineconeIndex -from gpt_index import GPTSimpleVectorIndex, SimpleDirectoryReader +from services.environment_service import EnvService, app_root_path class Index_handler: - def __init__(self): + def __init__(self, bot): + self.bot = bot self.openai_key = os.getenv("OPENAI_TOKEN") self.index_storage = {} self.loop = asyncio.get_running_loop() + self.qaprompt = QuestionAnswerPrompt( + "Context information is below. The text '<|endofstatement|>' is used to separate chat entries and make it easier for you to understand the context\n" + "---------------------\n" + "{context_str}" + "\n---------------------\n" + "Never say '<|endofstatement|>'\n" + "Given the context information and not prior knowledge, " + "answer the question: {query_str}\n" + ) def index_file(self, file_path): document = SimpleDirectoryReader(file_path).load_data() index = GPTSimpleVectorIndex(document) return index + def index_load_file(self, file_path): + index = GPTSimpleVectorIndex.load_from_disk(file_path) + return index def index_discord(self, document): index = GPTSimpleVectorIndex(document) return index @@ -37,7 +51,6 @@ class Index_handler: os.environ["OPENAI_API_KEY"] = user_api_key try: - temp_path = tempfile.TemporaryDirectory() if file.content_type.startswith("text/plain"): suffix = ".txt" elif file.content_type.startswith("application/pdf"): @@ -45,8 +58,9 @@ class Index_handler: else: await ctx.respond("Only accepts txt or pdf files") return - temp_file = tempfile.NamedTemporaryFile(suffix=suffix, dir=temp_path.name, delete=False) - await file.save(temp_file.name) + async with aiofiles.tempfile.TemporaryDirectory() as temp_path: + async with aiofiles.tempfile.NamedTemporaryFile(suffix=suffix, dir=temp_path.name, delete=False) as temp_file: + await file.save(temp_file.name) index = await self.loop.run_in_executor(None, partial(self.index_file, temp_path.name)) self.index_storage[ctx.user.id] = index temp_path.cleanup() @@ -56,21 +70,14 @@ class Index_handler: traceback.print_exc() - async def set_discord_index(self, ctx: discord.ApplicationContext, channel: discord.TextChannel, user_api_key, no_channel=False): + async def set_discord_index(self, ctx: discord.ApplicationContext, channel: discord.TextChannel, user_api_key): if not user_api_key: os.environ["OPENAI_API_KEY"] = self.openai_key else: os.environ["OPENAI_API_KEY"] = user_api_key try: - reader = DiscordReader() - if no_channel: - channel_ids:List[int] = [] - for c in ctx.guild.text_channels: - channel_ids.append(c.id) - document = await reader.load_data(channel_ids=channel_ids, limit=300, oldest_first=False) - else: - document = await reader.load_data(channel_ids=[channel.id], limit=1000, oldest_first=False) + document = await self.load_data(channel_ids=[channel.id], limit=1000, oldest_first=False) index = await self.loop.run_in_executor(None, partial(self.index_discord, document)) self.index_storage[ctx.user.id] = index await ctx.respond("Index set") @@ -78,6 +85,42 @@ class Index_handler: await ctx.respond("Failed to set index") traceback.print_exc() + + async def load_index(self, ctx:discord.ApplicationContext, index, user_api_key): + if not user_api_key: + os.environ["OPENAI_API_KEY"] = self.openai_key + else: + os.environ["OPENAI_API_KEY"] = user_api_key + + try: + index_file = EnvService.find_shared_file(f"indexes/{index}") + index = await self.loop.run_in_executor(None, partial(self.index_load_file, index_file)) + self.index_storage[ctx.user.id] = index + await ctx.respond("Loaded index") + except Exception as e: + await ctx.respond(e) + + + async def backup_discord(self, ctx: discord.ApplicationContext, user_api_key): + if not user_api_key: + os.environ["OPENAI_API_KEY"] = self.openai_key + else: + os.environ["OPENAI_API_KEY"] = user_api_key + + try: + channel_ids:List[int] = [] + for c in ctx.guild.text_channels: + channel_ids.append(c.id) + document = await self.load_data(channel_ids=channel_ids, limit=1000, oldest_first=False) + index = await self.loop.run_in_executor(None, partial(self.index_discord, document)) + Path(app_root_path() / "indexes").mkdir(parents = True, exist_ok=True) + index.save_to_disk(app_root_path() / "indexes" / f"{ctx.guild.name.replace(' ', '-')}_{date.today()}-H{datetime.now().hour}.json") + + await ctx.respond("Backup saved") + except Exception: + await ctx.respond("Failed to save backup") + traceback.print_exc() + async def query(self, ctx: discord.ApplicationContext, query:str, response_mode, user_api_key): @@ -85,88 +128,56 @@ class Index_handler: os.environ["OPENAI_API_KEY"] = self.openai_key else: os.environ["OPENAI_API_KEY"] = user_api_key - + try: index: GPTSimpleVectorIndex = self.index_storage[ctx.user.id] - response: Response = await self.loop.run_in_executor(None, partial(index.query, query, verbose=True, response_mode=response_mode)) + response: Response = await self.loop.run_in_executor(None, partial(index.query, query, verbose=True, response_mode=response_mode, text_qa_template=self.qaprompt)) await ctx.respond(f"**Query:**\n\n{query.strip()}\n\n**Query response:**\n\n{response.response.strip()}") except Exception: - await ctx.respond("You haven't set and index", delete_after=10) - - -#Set our own version of the DiscordReader class that's async - -class DiscordReader(BaseReader): - """Discord reader. - - Reads conversations from channels. + await ctx.respond("Failed to send query", delete_after=10) - Args: - discord_token (Optional[str]): Discord token. If not provided, we - assume the environment variable `DISCORD_TOKEN` is set. - - """ - - def __init__(self, discord_token: Optional[str] = None) -> None: - """Initialize with parameters.""" - if discord_token is None: - discord_token = os.environ["DISCORD_TOKEN"] - if discord_token is None: - raise ValueError( - "Must specify `discord_token` or set environment " - "variable `DISCORD_TOKEN`." - ) - - self.discord_token = discord_token + # Extracted functions from DiscordReader async def read_channel(self, channel_id: int, limit: Optional[int], oldest_first: bool) -> str: """Async read channel.""" messages: List[discord.Message] = [] - class CustomClient(discord.Client): - async def on_ready(self) -> None: - try: - channel = client.get_channel(channel_id) - print(f"Added {channel.name} from {channel.guild.name}") - # only work for text channels for now - if not isinstance(channel, discord.TextChannel): - raise ValueError( - f"Channel {channel_id} is not a text channel. " - "Only text channels are supported for now." - ) - # thread_dict maps thread_id to thread - thread_dict = {} - for thread in channel.threads: - thread_dict[thread.id] = thread - - async for msg in channel.history( - limit=limit, oldest_first=oldest_first - ): - if msg.author.bot: - pass - else: - messages.append(msg) - if msg.id in thread_dict: - thread = thread_dict[msg.id] - async for thread_msg in thread.history( - limit=limit, oldest_first=oldest_first - ): - messages.append(thread_msg) - except Exception as e: - print("Encountered error: " + str(e)) - finally: - await self.close() - - intents = discord.Intents.default() - intents.message_content = True - client = CustomClient(intents=intents) - await client.start(self.discord_token) - - channel = client.get_channel(channel_id) + + try: + channel = self.bot.get_channel(channel_id) + print(f"Added {channel.name} from {channel.guild.name}") + # only work for text channels for now + if not isinstance(channel, discord.TextChannel): + raise ValueError( + f"Channel {channel_id} is not a text channel. " + "Only text channels are supported for now." + ) + # thread_dict maps thread_id to thread + thread_dict = {} + for thread in channel.threads: + thread_dict[thread.id] = thread + + async for msg in channel.history( + limit=limit, oldest_first=oldest_first + ): + if msg.author.bot: + pass + else: + messages.append(msg) + if msg.id in thread_dict: + thread = thread_dict[msg.id] + async for thread_msg in thread.history( + limit=limit, oldest_first=oldest_first + ): + messages.append(thread_msg) + except Exception as e: + print("Encountered error: " + str(e)) + + channel = self.bot.get_channel(channel_id) msg_txt_list = [f"user:{m.author.display_name}, content:{m.content}" for m in messages] - return ("\n\n".join(msg_txt_list), channel.name) + return ("<|endofstatement|>\n\n".join(msg_txt_list), channel.name) async def load_data( self, @@ -195,6 +206,6 @@ class DiscordReader(BaseReader): ) (channel_content, channel_name) = await self.read_channel(channel_id, limit=limit, oldest_first=oldest_first) results.append( - Document(channel_content, extra_info={"channel_id": channel_id, "channel_name": channel_name}) + Document(channel_content, extra_info={"channel_name": channel_name}) ) return results \ No newline at end of file