Added server index backup and loading

Rene Teigen 1 year ago
parent dd9cb0ce4c
commit c6a6245e00

3
.gitignore vendored

@ -8,4 +8,5 @@ __pycache__
*.sqlite *.sqlite
bot.pid bot.pid
usage.txt usage.txt
/dalleimages /dalleimages
/indexes

@ -501,6 +501,18 @@ class Commands(discord.Cog, name="Commands"):
# Index commands # Index commands
# #
@add_to_group("index")
@discord.slash_command(
name="load_file",
description="Set an index to query from",
guild_ids=ALLOWED_GUILDS
)
@discord.guild_only()
@discord.option(name="index", description="Which file to load the index from", required=True, autocomplete=File_autocompleter.get_indexes)
async def load_index(self, ctx:discord.ApplicationContext, index: str):
await self.index_cog.load_index_command(ctx, index)
@add_to_group("index") @add_to_group("index")
@discord.slash_command( @discord.slash_command(
name="set_file", name="set_file",
@ -524,6 +536,17 @@ class Commands(discord.Cog, name="Commands"):
await self.index_cog.set_discord_command(ctx, channel) await self.index_cog.set_discord_command(ctx, channel)
@add_to_group("index")
@discord.slash_command(
name="discord_backup",
description="Save an index made from the whole server",
guild_ids=ALLOWED_GUILDS
)
@discord.guild_only()
async def discord_backup(self, ctx:discord.ApplicationContext):
await self.index_cog.discord_backup_command(ctx)
@add_to_group("index") @add_to_group("index")
@discord.slash_command( @discord.slash_command(
name="query", name="query",

@ -15,7 +15,7 @@ class IndexService(discord.Cog, name="IndexService"):
): ):
super().__init__() super().__init__()
self.bot = bot self.bot = bot
self.index_handler = Index_handler() self.index_handler = Index_handler(bot)
async def set_index_command(self, ctx, file: discord.Attachment): async def set_index_command(self, ctx, file: discord.Attachment):
"""Command handler to set a file as your personal index""" """Command handler to set a file as your personal index"""
@ -40,11 +40,32 @@ class IndexService(discord.Cog, name="IndexService"):
return return
await ctx.defer(ephemeral=True) await ctx.defer(ephemeral=True)
if not channel:
await self.index_handler.set_discord_index(ctx, channel, user_api_key=user_api_key, no_channel=True)
return
await self.index_handler.set_discord_index(ctx, channel, user_api_key=user_api_key) await self.index_handler.set_discord_index(ctx, channel, user_api_key=user_api_key)
async def discord_backup_command(self, ctx):
"""Command handler to backup the entire server"""
user_api_key = None
if USER_INPUT_API_KEYS:
user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx, USER_KEY_DB)
if not user_api_key:
return
await ctx.defer(ephemeral=True)
await self.index_handler.backup_discord(ctx, user_api_key=user_api_key)
async def load_index_command(self, ctx, index):
"""Command handler to backup the entire server"""
user_api_key = None
if USER_INPUT_API_KEYS:
user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx, USER_KEY_DB)
if not user_api_key:
return
await ctx.defer(ephemeral=True)
await self.index_handler.load_index(ctx, index, user_api_key)
async def query_command(self, ctx, query, response_mode): async def query_command(self, ctx, query, response_mode):
"""Command handler to query your index""" """Command handler to query your index"""

@ -141,3 +141,16 @@ class File_autocompleter:
] # returns the 25 first files from your current input ] # returns the 25 first files from your current input
except Exception: except Exception:
return ["No 'openers' folder"] return ["No 'openers' folder"]
async def get_indexes(ctx: discord.AutocompleteContext):
"""get all files in the openers folder"""
try:
return [
file
for file in os.listdir(EnvService.find_shared_file("indexes"))
if file.startswith(ctx.value.lower())
][
:25
] # returns the 25 first files from your current input
except Exception:
return ["No 'indexes' folder"]

@ -1,30 +1,44 @@
import os import os
import traceback import traceback
import asyncio import asyncio
import tempfile
import discord import discord
import aiofiles
from functools import partial from functools import partial
from typing import List, Optional from typing import List, Optional
from pathlib import Path
from datetime import date, datetime
from gpt_index.readers.base import BaseReader
from gpt_index.readers.schema.base import Document from gpt_index.readers.schema.base import Document
from gpt_index.response.schema import Response from gpt_index.response.schema import Response
from gpt_index import GPTSimpleVectorIndex, SimpleDirectoryReader, QuestionAnswerPrompt, GPTPineconeIndex
from gpt_index import GPTSimpleVectorIndex, SimpleDirectoryReader from services.environment_service import EnvService, app_root_path
class Index_handler: class Index_handler:
def __init__(self): def __init__(self, bot):
self.bot = bot
self.openai_key = os.getenv("OPENAI_TOKEN") self.openai_key = os.getenv("OPENAI_TOKEN")
self.index_storage = {} self.index_storage = {}
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
self.qaprompt = QuestionAnswerPrompt(
"Context information is below. The text '<|endofstatement|>' is used to separate chat entries and make it easier for you to understand the context\n"
"---------------------\n"
"{context_str}"
"\n---------------------\n"
"Never say '<|endofstatement|>'\n"
"Given the context information and not prior knowledge, "
"answer the question: {query_str}\n"
)
def index_file(self, file_path): def index_file(self, file_path):
document = SimpleDirectoryReader(file_path).load_data() document = SimpleDirectoryReader(file_path).load_data()
index = GPTSimpleVectorIndex(document) index = GPTSimpleVectorIndex(document)
return index return index
def index_load_file(self, file_path):
index = GPTSimpleVectorIndex.load_from_disk(file_path)
return index
def index_discord(self, document): def index_discord(self, document):
index = GPTSimpleVectorIndex(document) index = GPTSimpleVectorIndex(document)
return index return index
@ -37,7 +51,6 @@ class Index_handler:
os.environ["OPENAI_API_KEY"] = user_api_key os.environ["OPENAI_API_KEY"] = user_api_key
try: try:
temp_path = tempfile.TemporaryDirectory()
if file.content_type.startswith("text/plain"): if file.content_type.startswith("text/plain"):
suffix = ".txt" suffix = ".txt"
elif file.content_type.startswith("application/pdf"): elif file.content_type.startswith("application/pdf"):
@ -45,8 +58,9 @@ class Index_handler:
else: else:
await ctx.respond("Only accepts txt or pdf files") await ctx.respond("Only accepts txt or pdf files")
return return
temp_file = tempfile.NamedTemporaryFile(suffix=suffix, dir=temp_path.name, delete=False) async with aiofiles.tempfile.TemporaryDirectory() as temp_path:
await file.save(temp_file.name) async with aiofiles.tempfile.NamedTemporaryFile(suffix=suffix, dir=temp_path.name, delete=False) as temp_file:
await file.save(temp_file.name)
index = await self.loop.run_in_executor(None, partial(self.index_file, temp_path.name)) index = await self.loop.run_in_executor(None, partial(self.index_file, temp_path.name))
self.index_storage[ctx.user.id] = index self.index_storage[ctx.user.id] = index
temp_path.cleanup() temp_path.cleanup()
@ -56,21 +70,14 @@ class Index_handler:
traceback.print_exc() traceback.print_exc()
async def set_discord_index(self, ctx: discord.ApplicationContext, channel: discord.TextChannel, user_api_key, no_channel=False): async def set_discord_index(self, ctx: discord.ApplicationContext, channel: discord.TextChannel, user_api_key):
if not user_api_key: if not user_api_key:
os.environ["OPENAI_API_KEY"] = self.openai_key os.environ["OPENAI_API_KEY"] = self.openai_key
else: else:
os.environ["OPENAI_API_KEY"] = user_api_key os.environ["OPENAI_API_KEY"] = user_api_key
try: try:
reader = DiscordReader() document = await self.load_data(channel_ids=[channel.id], limit=1000, oldest_first=False)
if no_channel:
channel_ids:List[int] = []
for c in ctx.guild.text_channels:
channel_ids.append(c.id)
document = await reader.load_data(channel_ids=channel_ids, limit=300, oldest_first=False)
else:
document = await reader.load_data(channel_ids=[channel.id], limit=1000, oldest_first=False)
index = await self.loop.run_in_executor(None, partial(self.index_discord, document)) index = await self.loop.run_in_executor(None, partial(self.index_discord, document))
self.index_storage[ctx.user.id] = index self.index_storage[ctx.user.id] = index
await ctx.respond("Index set") await ctx.respond("Index set")
@ -78,6 +85,42 @@ class Index_handler:
await ctx.respond("Failed to set index") await ctx.respond("Failed to set index")
traceback.print_exc() traceback.print_exc()
async def load_index(self, ctx:discord.ApplicationContext, index, user_api_key):
if not user_api_key:
os.environ["OPENAI_API_KEY"] = self.openai_key
else:
os.environ["OPENAI_API_KEY"] = user_api_key
try:
index_file = EnvService.find_shared_file(f"indexes/{index}")
index = await self.loop.run_in_executor(None, partial(self.index_load_file, index_file))
self.index_storage[ctx.user.id] = index
await ctx.respond("Loaded index")
except Exception as e:
await ctx.respond(e)
async def backup_discord(self, ctx: discord.ApplicationContext, user_api_key):
if not user_api_key:
os.environ["OPENAI_API_KEY"] = self.openai_key
else:
os.environ["OPENAI_API_KEY"] = user_api_key
try:
channel_ids:List[int] = []
for c in ctx.guild.text_channels:
channel_ids.append(c.id)
document = await self.load_data(channel_ids=channel_ids, limit=1000, oldest_first=False)
index = await self.loop.run_in_executor(None, partial(self.index_discord, document))
Path(app_root_path() / "indexes").mkdir(parents = True, exist_ok=True)
index.save_to_disk(app_root_path() / "indexes" / f"{ctx.guild.name.replace(' ', '-')}_{date.today()}-H{datetime.now().hour}.json")
await ctx.respond("Backup saved")
except Exception:
await ctx.respond("Failed to save backup")
traceback.print_exc()
async def query(self, ctx: discord.ApplicationContext, query:str, response_mode, user_api_key): async def query(self, ctx: discord.ApplicationContext, query:str, response_mode, user_api_key):
@ -85,88 +128,56 @@ class Index_handler:
os.environ["OPENAI_API_KEY"] = self.openai_key os.environ["OPENAI_API_KEY"] = self.openai_key
else: else:
os.environ["OPENAI_API_KEY"] = user_api_key os.environ["OPENAI_API_KEY"] = user_api_key
try: try:
index: GPTSimpleVectorIndex = self.index_storage[ctx.user.id] index: GPTSimpleVectorIndex = self.index_storage[ctx.user.id]
response: Response = await self.loop.run_in_executor(None, partial(index.query, query, verbose=True, response_mode=response_mode)) response: Response = await self.loop.run_in_executor(None, partial(index.query, query, verbose=True, response_mode=response_mode, text_qa_template=self.qaprompt))
await ctx.respond(f"**Query:**\n\n{query.strip()}\n\n**Query response:**\n\n{response.response.strip()}") await ctx.respond(f"**Query:**\n\n{query.strip()}\n\n**Query response:**\n\n{response.response.strip()}")
except Exception: except Exception:
await ctx.respond("You haven't set and index", delete_after=10) await ctx.respond("Failed to send query", delete_after=10)
#Set our own version of the DiscordReader class that's async
class DiscordReader(BaseReader):
"""Discord reader.
Reads conversations from channels.
Args: # Extracted functions from DiscordReader
discord_token (Optional[str]): Discord token. If not provided, we
assume the environment variable `DISCORD_TOKEN` is set.
"""
def __init__(self, discord_token: Optional[str] = None) -> None:
"""Initialize with parameters."""
if discord_token is None:
discord_token = os.environ["DISCORD_TOKEN"]
if discord_token is None:
raise ValueError(
"Must specify `discord_token` or set environment "
"variable `DISCORD_TOKEN`."
)
self.discord_token = discord_token
async def read_channel(self, channel_id: int, limit: Optional[int], oldest_first: bool) -> str: async def read_channel(self, channel_id: int, limit: Optional[int], oldest_first: bool) -> str:
"""Async read channel.""" """Async read channel."""
messages: List[discord.Message] = [] messages: List[discord.Message] = []
class CustomClient(discord.Client):
async def on_ready(self) -> None: try:
try: channel = self.bot.get_channel(channel_id)
channel = client.get_channel(channel_id) print(f"Added {channel.name} from {channel.guild.name}")
print(f"Added {channel.name} from {channel.guild.name}") # only work for text channels for now
# only work for text channels for now if not isinstance(channel, discord.TextChannel):
if not isinstance(channel, discord.TextChannel): raise ValueError(
raise ValueError( f"Channel {channel_id} is not a text channel. "
f"Channel {channel_id} is not a text channel. " "Only text channels are supported for now."
"Only text channels are supported for now." )
) # thread_dict maps thread_id to thread
# thread_dict maps thread_id to thread thread_dict = {}
thread_dict = {} for thread in channel.threads:
for thread in channel.threads: thread_dict[thread.id] = thread
thread_dict[thread.id] = thread
async for msg in channel.history(
async for msg in channel.history( limit=limit, oldest_first=oldest_first
limit=limit, oldest_first=oldest_first ):
): if msg.author.bot:
if msg.author.bot: pass
pass else:
else: messages.append(msg)
messages.append(msg) if msg.id in thread_dict:
if msg.id in thread_dict: thread = thread_dict[msg.id]
thread = thread_dict[msg.id] async for thread_msg in thread.history(
async for thread_msg in thread.history( limit=limit, oldest_first=oldest_first
limit=limit, oldest_first=oldest_first ):
): messages.append(thread_msg)
messages.append(thread_msg) except Exception as e:
except Exception as e: print("Encountered error: " + str(e))
print("Encountered error: " + str(e))
finally: channel = self.bot.get_channel(channel_id)
await self.close()
intents = discord.Intents.default()
intents.message_content = True
client = CustomClient(intents=intents)
await client.start(self.discord_token)
channel = client.get_channel(channel_id)
msg_txt_list = [f"user:{m.author.display_name}, content:{m.content}" for m in messages] msg_txt_list = [f"user:{m.author.display_name}, content:{m.content}" for m in messages]
return ("\n\n".join(msg_txt_list), channel.name) return ("<|endofstatement|>\n\n".join(msg_txt_list), channel.name)
async def load_data( async def load_data(
self, self,
@ -195,6 +206,6 @@ class DiscordReader(BaseReader):
) )
(channel_content, channel_name) = await self.read_channel(channel_id, limit=limit, oldest_first=oldest_first) (channel_content, channel_name) = await self.read_channel(channel_id, limit=limit, oldest_first=oldest_first)
results.append( results.append(
Document(channel_content, extra_info={"channel_id": channel_id, "channel_name": channel_name}) Document(channel_content, extra_info={"channel_name": channel_name})
) )
return results return results
Loading…
Cancel
Save