Added server index backup and loading

Rene Teigen 1 year ago
parent dd9cb0ce4c
commit c6a6245e00

3
.gitignore vendored

@ -8,4 +8,5 @@ __pycache__
*.sqlite
bot.pid
usage.txt
/dalleimages
/dalleimages
/indexes

@ -501,6 +501,18 @@ class Commands(discord.Cog, name="Commands"):
# Index commands
#
@add_to_group("index")
@discord.slash_command(
name="load_file",
description="Set an index to query from",
guild_ids=ALLOWED_GUILDS
)
@discord.guild_only()
@discord.option(name="index", description="Which file to load the index from", required=True, autocomplete=File_autocompleter.get_indexes)
async def load_index(self, ctx:discord.ApplicationContext, index: str):
await self.index_cog.load_index_command(ctx, index)
@add_to_group("index")
@discord.slash_command(
name="set_file",
@ -524,6 +536,17 @@ class Commands(discord.Cog, name="Commands"):
await self.index_cog.set_discord_command(ctx, channel)
@add_to_group("index")
@discord.slash_command(
name="discord_backup",
description="Save an index made from the whole server",
guild_ids=ALLOWED_GUILDS
)
@discord.guild_only()
async def discord_backup(self, ctx:discord.ApplicationContext):
await self.index_cog.discord_backup_command(ctx)
@add_to_group("index")
@discord.slash_command(
name="query",

@ -15,7 +15,7 @@ class IndexService(discord.Cog, name="IndexService"):
):
super().__init__()
self.bot = bot
self.index_handler = Index_handler()
self.index_handler = Index_handler(bot)
async def set_index_command(self, ctx, file: discord.Attachment):
"""Command handler to set a file as your personal index"""
@ -40,11 +40,32 @@ class IndexService(discord.Cog, name="IndexService"):
return
await ctx.defer(ephemeral=True)
if not channel:
await self.index_handler.set_discord_index(ctx, channel, user_api_key=user_api_key, no_channel=True)
return
await self.index_handler.set_discord_index(ctx, channel, user_api_key=user_api_key)
async def discord_backup_command(self, ctx):
"""Command handler to backup the entire server"""
user_api_key = None
if USER_INPUT_API_KEYS:
user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx, USER_KEY_DB)
if not user_api_key:
return
await ctx.defer(ephemeral=True)
await self.index_handler.backup_discord(ctx, user_api_key=user_api_key)
async def load_index_command(self, ctx, index):
"""Command handler to backup the entire server"""
user_api_key = None
if USER_INPUT_API_KEYS:
user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx, USER_KEY_DB)
if not user_api_key:
return
await ctx.defer(ephemeral=True)
await self.index_handler.load_index(ctx, index, user_api_key)
async def query_command(self, ctx, query, response_mode):
"""Command handler to query your index"""

@ -141,3 +141,16 @@ class File_autocompleter:
] # returns the 25 first files from your current input
except Exception:
return ["No 'openers' folder"]
async def get_indexes(ctx: discord.AutocompleteContext):
"""get all files in the openers folder"""
try:
return [
file
for file in os.listdir(EnvService.find_shared_file("indexes"))
if file.startswith(ctx.value.lower())
][
:25
] # returns the 25 first files from your current input
except Exception:
return ["No 'indexes' folder"]

@ -1,30 +1,44 @@
import os
import traceback
import asyncio
import tempfile
import discord
import aiofiles
from functools import partial
from typing import List, Optional
from pathlib import Path
from datetime import date, datetime
from gpt_index.readers.base import BaseReader
from gpt_index.readers.schema.base import Document
from gpt_index.response.schema import Response
from gpt_index import GPTSimpleVectorIndex, SimpleDirectoryReader, QuestionAnswerPrompt, GPTPineconeIndex
from gpt_index import GPTSimpleVectorIndex, SimpleDirectoryReader
from services.environment_service import EnvService, app_root_path
class Index_handler:
def __init__(self):
def __init__(self, bot):
self.bot = bot
self.openai_key = os.getenv("OPENAI_TOKEN")
self.index_storage = {}
self.loop = asyncio.get_running_loop()
self.qaprompt = QuestionAnswerPrompt(
"Context information is below. The text '<|endofstatement|>' is used to separate chat entries and make it easier for you to understand the context\n"
"---------------------\n"
"{context_str}"
"\n---------------------\n"
"Never say '<|endofstatement|>'\n"
"Given the context information and not prior knowledge, "
"answer the question: {query_str}\n"
)
def index_file(self, file_path):
document = SimpleDirectoryReader(file_path).load_data()
index = GPTSimpleVectorIndex(document)
return index
def index_load_file(self, file_path):
index = GPTSimpleVectorIndex.load_from_disk(file_path)
return index
def index_discord(self, document):
index = GPTSimpleVectorIndex(document)
return index
@ -37,7 +51,6 @@ class Index_handler:
os.environ["OPENAI_API_KEY"] = user_api_key
try:
temp_path = tempfile.TemporaryDirectory()
if file.content_type.startswith("text/plain"):
suffix = ".txt"
elif file.content_type.startswith("application/pdf"):
@ -45,8 +58,9 @@ class Index_handler:
else:
await ctx.respond("Only accepts txt or pdf files")
return
temp_file = tempfile.NamedTemporaryFile(suffix=suffix, dir=temp_path.name, delete=False)
await file.save(temp_file.name)
async with aiofiles.tempfile.TemporaryDirectory() as temp_path:
async with aiofiles.tempfile.NamedTemporaryFile(suffix=suffix, dir=temp_path.name, delete=False) as temp_file:
await file.save(temp_file.name)
index = await self.loop.run_in_executor(None, partial(self.index_file, temp_path.name))
self.index_storage[ctx.user.id] = index
temp_path.cleanup()
@ -56,21 +70,14 @@ class Index_handler:
traceback.print_exc()
async def set_discord_index(self, ctx: discord.ApplicationContext, channel: discord.TextChannel, user_api_key, no_channel=False):
async def set_discord_index(self, ctx: discord.ApplicationContext, channel: discord.TextChannel, user_api_key):
if not user_api_key:
os.environ["OPENAI_API_KEY"] = self.openai_key
else:
os.environ["OPENAI_API_KEY"] = user_api_key
try:
reader = DiscordReader()
if no_channel:
channel_ids:List[int] = []
for c in ctx.guild.text_channels:
channel_ids.append(c.id)
document = await reader.load_data(channel_ids=channel_ids, limit=300, oldest_first=False)
else:
document = await reader.load_data(channel_ids=[channel.id], limit=1000, oldest_first=False)
document = await self.load_data(channel_ids=[channel.id], limit=1000, oldest_first=False)
index = await self.loop.run_in_executor(None, partial(self.index_discord, document))
self.index_storage[ctx.user.id] = index
await ctx.respond("Index set")
@ -78,6 +85,42 @@ class Index_handler:
await ctx.respond("Failed to set index")
traceback.print_exc()
async def load_index(self, ctx:discord.ApplicationContext, index, user_api_key):
if not user_api_key:
os.environ["OPENAI_API_KEY"] = self.openai_key
else:
os.environ["OPENAI_API_KEY"] = user_api_key
try:
index_file = EnvService.find_shared_file(f"indexes/{index}")
index = await self.loop.run_in_executor(None, partial(self.index_load_file, index_file))
self.index_storage[ctx.user.id] = index
await ctx.respond("Loaded index")
except Exception as e:
await ctx.respond(e)
async def backup_discord(self, ctx: discord.ApplicationContext, user_api_key):
if not user_api_key:
os.environ["OPENAI_API_KEY"] = self.openai_key
else:
os.environ["OPENAI_API_KEY"] = user_api_key
try:
channel_ids:List[int] = []
for c in ctx.guild.text_channels:
channel_ids.append(c.id)
document = await self.load_data(channel_ids=channel_ids, limit=1000, oldest_first=False)
index = await self.loop.run_in_executor(None, partial(self.index_discord, document))
Path(app_root_path() / "indexes").mkdir(parents = True, exist_ok=True)
index.save_to_disk(app_root_path() / "indexes" / f"{ctx.guild.name.replace(' ', '-')}_{date.today()}-H{datetime.now().hour}.json")
await ctx.respond("Backup saved")
except Exception:
await ctx.respond("Failed to save backup")
traceback.print_exc()
async def query(self, ctx: discord.ApplicationContext, query:str, response_mode, user_api_key):
@ -85,88 +128,56 @@ class Index_handler:
os.environ["OPENAI_API_KEY"] = self.openai_key
else:
os.environ["OPENAI_API_KEY"] = user_api_key
try:
index: GPTSimpleVectorIndex = self.index_storage[ctx.user.id]
response: Response = await self.loop.run_in_executor(None, partial(index.query, query, verbose=True, response_mode=response_mode))
response: Response = await self.loop.run_in_executor(None, partial(index.query, query, verbose=True, response_mode=response_mode, text_qa_template=self.qaprompt))
await ctx.respond(f"**Query:**\n\n{query.strip()}\n\n**Query response:**\n\n{response.response.strip()}")
except Exception:
await ctx.respond("You haven't set and index", delete_after=10)
#Set our own version of the DiscordReader class that's async
class DiscordReader(BaseReader):
"""Discord reader.
Reads conversations from channels.
await ctx.respond("Failed to send query", delete_after=10)
Args:
discord_token (Optional[str]): Discord token. If not provided, we
assume the environment variable `DISCORD_TOKEN` is set.
"""
def __init__(self, discord_token: Optional[str] = None) -> None:
"""Initialize with parameters."""
if discord_token is None:
discord_token = os.environ["DISCORD_TOKEN"]
if discord_token is None:
raise ValueError(
"Must specify `discord_token` or set environment "
"variable `DISCORD_TOKEN`."
)
self.discord_token = discord_token
# Extracted functions from DiscordReader
async def read_channel(self, channel_id: int, limit: Optional[int], oldest_first: bool) -> str:
"""Async read channel."""
messages: List[discord.Message] = []
class CustomClient(discord.Client):
async def on_ready(self) -> None:
try:
channel = client.get_channel(channel_id)
print(f"Added {channel.name} from {channel.guild.name}")
# only work for text channels for now
if not isinstance(channel, discord.TextChannel):
raise ValueError(
f"Channel {channel_id} is not a text channel. "
"Only text channels are supported for now."
)
# thread_dict maps thread_id to thread
thread_dict = {}
for thread in channel.threads:
thread_dict[thread.id] = thread
async for msg in channel.history(
limit=limit, oldest_first=oldest_first
):
if msg.author.bot:
pass
else:
messages.append(msg)
if msg.id in thread_dict:
thread = thread_dict[msg.id]
async for thread_msg in thread.history(
limit=limit, oldest_first=oldest_first
):
messages.append(thread_msg)
except Exception as e:
print("Encountered error: " + str(e))
finally:
await self.close()
intents = discord.Intents.default()
intents.message_content = True
client = CustomClient(intents=intents)
await client.start(self.discord_token)
channel = client.get_channel(channel_id)
try:
channel = self.bot.get_channel(channel_id)
print(f"Added {channel.name} from {channel.guild.name}")
# only work for text channels for now
if not isinstance(channel, discord.TextChannel):
raise ValueError(
f"Channel {channel_id} is not a text channel. "
"Only text channels are supported for now."
)
# thread_dict maps thread_id to thread
thread_dict = {}
for thread in channel.threads:
thread_dict[thread.id] = thread
async for msg in channel.history(
limit=limit, oldest_first=oldest_first
):
if msg.author.bot:
pass
else:
messages.append(msg)
if msg.id in thread_dict:
thread = thread_dict[msg.id]
async for thread_msg in thread.history(
limit=limit, oldest_first=oldest_first
):
messages.append(thread_msg)
except Exception as e:
print("Encountered error: " + str(e))
channel = self.bot.get_channel(channel_id)
msg_txt_list = [f"user:{m.author.display_name}, content:{m.content}" for m in messages]
return ("\n\n".join(msg_txt_list), channel.name)
return ("<|endofstatement|>\n\n".join(msg_txt_list), channel.name)
async def load_data(
self,
@ -195,6 +206,6 @@ class DiscordReader(BaseReader):
)
(channel_content, channel_name) = await self.read_channel(channel_id, limit=limit, oldest_first=oldest_first)
results.append(
Document(channel_content, extra_info={"channel_id": channel_id, "channel_name": channel_name})
Document(channel_content, extra_info={"channel_name": channel_name})
)
return results
Loading…
Cancel
Save