Added server index backup and loading

Rene Teigen 2 years ago
parent dd9cb0ce4c
commit c6a6245e00

1
.gitignore vendored

@ -9,3 +9,4 @@ __pycache__
bot.pid bot.pid
usage.txt usage.txt
/dalleimages /dalleimages
/indexes

@ -501,6 +501,18 @@ class Commands(discord.Cog, name="Commands"):
# Index commands # Index commands
# #
@add_to_group("index")
@discord.slash_command(
    name="load_file",
    description="Set an index to query from",
    guild_ids=ALLOWED_GUILDS,
)
@discord.guild_only()
@discord.option(
    name="index",
    description="Which file to load the index from",
    required=True,
    autocomplete=File_autocompleter.get_indexes,
)
async def load_index(self, ctx: discord.ApplicationContext, index: str):
    """Slash command ``load_file``: hand the chosen index file name to the index cog.

    ``index`` is the value picked from (or typed into) the autocomplete
    field; all real work happens in ``index_cog.load_index_command``.
    """
    await self.index_cog.load_index_command(ctx, index)
@add_to_group("index") @add_to_group("index")
@discord.slash_command( @discord.slash_command(
name="set_file", name="set_file",
@ -524,6 +536,17 @@ class Commands(discord.Cog, name="Commands"):
await self.index_cog.set_discord_command(ctx, channel) await self.index_cog.set_discord_command(ctx, channel)
@add_to_group("index")
@discord.slash_command(
    name="discord_backup",
    description="Save an index made from the whole server",
    guild_ids=ALLOWED_GUILDS,
)
@discord.guild_only()
async def discord_backup(self, ctx: discord.ApplicationContext):
    """Slash command ``discord_backup``: delegate to ``index_cog.discord_backup_command``."""
    await self.index_cog.discord_backup_command(ctx)
@add_to_group("index") @add_to_group("index")
@discord.slash_command( @discord.slash_command(
name="query", name="query",

@ -15,7 +15,7 @@ class IndexService(discord.Cog, name="IndexService"):
): ):
super().__init__() super().__init__()
self.bot = bot self.bot = bot
self.index_handler = Index_handler() self.index_handler = Index_handler(bot)
async def set_index_command(self, ctx, file: discord.Attachment): async def set_index_command(self, ctx, file: discord.Attachment):
"""Command handler to set a file as your personal index""" """Command handler to set a file as your personal index"""
@ -40,11 +40,32 @@ class IndexService(discord.Cog, name="IndexService"):
return return
await ctx.defer(ephemeral=True) await ctx.defer(ephemeral=True)
if not channel:
await self.index_handler.set_discord_index(ctx, channel, user_api_key=user_api_key, no_channel=True)
return
await self.index_handler.set_discord_index(ctx, channel, user_api_key=user_api_key) await self.index_handler.set_discord_index(ctx, channel, user_api_key=user_api_key)
async def discord_backup_command(self, ctx):
    """Command handler that backs up the entire server as an index.

    When ``USER_INPUT_API_KEYS`` is truthy the invoking user's API key is
    looked up via ``TextService.get_user_api_key``; if that lookup yields
    nothing the handler returns without doing anything else. Otherwise the
    key stays ``None``. The interaction is then deferred ephemerally and the
    work is passed on to ``index_handler.backup_discord``.
    """
    api_key = (
        await TextService.get_user_api_key(ctx.user.id, ctx, USER_KEY_DB)
        if USER_INPUT_API_KEYS
        else None
    )
    # Only a failed per-user lookup aborts; a disabled feature proceeds with None.
    if USER_INPUT_API_KEYS and not api_key:
        return

    await ctx.defer(ephemeral=True)
    await self.index_handler.backup_discord(ctx, user_api_key=api_key)
async def load_index_command(self, ctx, index):
    """Command handler to load a saved index file as the caller's index.

    ``index`` is the file name forwarded to ``index_handler.load_index``.
    When ``USER_INPUT_API_KEYS`` is truthy the user's API key is fetched
    first, and the handler returns early if no key comes back.
    """
    user_api_key = None
    if USER_INPUT_API_KEYS:
        user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx, USER_KEY_DB)
        if not user_api_key:
            return
    # Defer before the load so the reply stays ephemeral to the caller.
    await ctx.defer(ephemeral=True)
    await self.index_handler.load_index(ctx, index, user_api_key)
async def query_command(self, ctx, query, response_mode): async def query_command(self, ctx, query, response_mode):
"""Command handler to query your index""" """Command handler to query your index"""

@ -141,3 +141,16 @@ class File_autocompleter:
] # returns the 25 first files from your current input ] # returns the 25 first files from your current input
except Exception: except Exception:
return ["No 'openers' folder"] return ["No 'openers' folder"]
async def get_indexes(ctx: discord.AutocompleteContext):
    """Autocomplete callback: list saved index files in the shared "indexes" folder.

    Returns up to the first 25 file names whose name starts with the text
    typed so far (``ctx.value``), compared case-insensitively. If the folder
    cannot be located or listed, a single placeholder entry is returned
    instead of raising.
    """
    try:
        typed = (ctx.value or "").lower()
        return [
            file
            for file in os.listdir(EnvService.find_shared_file("indexes"))
            # Lower-case both sides: backup files are named after the guild
            # and may contain uppercase letters, which a one-sided lower()
            # would never match.
            if file.lower().startswith(typed)
        ][
            :25
        ]  # returns the 25 first files from your current input
    except Exception:
        return ["No 'indexes' folder"]

@ -1,30 +1,44 @@
import os import os
import traceback import traceback
import asyncio import asyncio
import tempfile
import discord import discord
import aiofiles
from functools import partial from functools import partial
from typing import List, Optional from typing import List, Optional
from pathlib import Path
from datetime import date, datetime
from gpt_index.readers.base import BaseReader
from gpt_index.readers.schema.base import Document from gpt_index.readers.schema.base import Document
from gpt_index.response.schema import Response from gpt_index.response.schema import Response
from gpt_index import GPTSimpleVectorIndex, SimpleDirectoryReader, QuestionAnswerPrompt, GPTPineconeIndex
from gpt_index import GPTSimpleVectorIndex, SimpleDirectoryReader from services.environment_service import EnvService, app_root_path
class Index_handler: class Index_handler:
def __init__(self): def __init__(self, bot):
self.bot = bot
self.openai_key = os.getenv("OPENAI_TOKEN") self.openai_key = os.getenv("OPENAI_TOKEN")
self.index_storage = {} self.index_storage = {}
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
self.qaprompt = QuestionAnswerPrompt(
"Context information is below. The text '<|endofstatement|>' is used to separate chat entries and make it easier for you to understand the context\n"
"---------------------\n"
"{context_str}"
"\n---------------------\n"
"Never say '<|endofstatement|>'\n"
"Given the context information and not prior knowledge, "
"answer the question: {query_str}\n"
)
def index_file(self, file_path): def index_file(self, file_path):
document = SimpleDirectoryReader(file_path).load_data() document = SimpleDirectoryReader(file_path).load_data()
index = GPTSimpleVectorIndex(document) index = GPTSimpleVectorIndex(document)
return index return index
def index_load_file(self, file_path):
index = GPTSimpleVectorIndex.load_from_disk(file_path)
return index
def index_discord(self, document): def index_discord(self, document):
index = GPTSimpleVectorIndex(document) index = GPTSimpleVectorIndex(document)
return index return index
@ -37,7 +51,6 @@ class Index_handler:
os.environ["OPENAI_API_KEY"] = user_api_key os.environ["OPENAI_API_KEY"] = user_api_key
try: try:
temp_path = tempfile.TemporaryDirectory()
if file.content_type.startswith("text/plain"): if file.content_type.startswith("text/plain"):
suffix = ".txt" suffix = ".txt"
elif file.content_type.startswith("application/pdf"): elif file.content_type.startswith("application/pdf"):
@ -45,7 +58,8 @@ class Index_handler:
else: else:
await ctx.respond("Only accepts txt or pdf files") await ctx.respond("Only accepts txt or pdf files")
return return
temp_file = tempfile.NamedTemporaryFile(suffix=suffix, dir=temp_path.name, delete=False) async with aiofiles.tempfile.TemporaryDirectory() as temp_path:
async with aiofiles.tempfile.NamedTemporaryFile(suffix=suffix, dir=temp_path.name, delete=False) as temp_file:
await file.save(temp_file.name) await file.save(temp_file.name)
index = await self.loop.run_in_executor(None, partial(self.index_file, temp_path.name)) index = await self.loop.run_in_executor(None, partial(self.index_file, temp_path.name))
self.index_storage[ctx.user.id] = index self.index_storage[ctx.user.id] = index
@ -56,21 +70,14 @@ class Index_handler:
traceback.print_exc() traceback.print_exc()
async def set_discord_index(self, ctx: discord.ApplicationContext, channel: discord.TextChannel, user_api_key, no_channel=False): async def set_discord_index(self, ctx: discord.ApplicationContext, channel: discord.TextChannel, user_api_key):
if not user_api_key: if not user_api_key:
os.environ["OPENAI_API_KEY"] = self.openai_key os.environ["OPENAI_API_KEY"] = self.openai_key
else: else:
os.environ["OPENAI_API_KEY"] = user_api_key os.environ["OPENAI_API_KEY"] = user_api_key
try: try:
reader = DiscordReader() document = await self.load_data(channel_ids=[channel.id], limit=1000, oldest_first=False)
if no_channel:
channel_ids:List[int] = []
for c in ctx.guild.text_channels:
channel_ids.append(c.id)
document = await reader.load_data(channel_ids=channel_ids, limit=300, oldest_first=False)
else:
document = await reader.load_data(channel_ids=[channel.id], limit=1000, oldest_first=False)
index = await self.loop.run_in_executor(None, partial(self.index_discord, document)) index = await self.loop.run_in_executor(None, partial(self.index_discord, document))
self.index_storage[ctx.user.id] = index self.index_storage[ctx.user.id] = index
await ctx.respond("Index set") await ctx.respond("Index set")
@ -79,55 +86,66 @@ class Index_handler:
traceback.print_exc() traceback.print_exc()
async def load_index(self, ctx:discord.ApplicationContext, index, user_api_key):
async def query(self, ctx: discord.ApplicationContext, query:str, response_mode, user_api_key):
if not user_api_key: if not user_api_key:
os.environ["OPENAI_API_KEY"] = self.openai_key os.environ["OPENAI_API_KEY"] = self.openai_key
else: else:
os.environ["OPENAI_API_KEY"] = user_api_key os.environ["OPENAI_API_KEY"] = user_api_key
try: try:
index: GPTSimpleVectorIndex = self.index_storage[ctx.user.id] index_file = EnvService.find_shared_file(f"indexes/{index}")
response: Response = await self.loop.run_in_executor(None, partial(index.query, query, verbose=True, response_mode=response_mode)) index = await self.loop.run_in_executor(None, partial(self.index_load_file, index_file))
await ctx.respond(f"**Query:**\n\n{query.strip()}\n\n**Query response:**\n\n{response.response.strip()}") self.index_storage[ctx.user.id] = index
except Exception: await ctx.respond("Loaded index")
await ctx.respond("You haven't set and index", delete_after=10) except Exception as e:
await ctx.respond(e)
#Set our own version of the DiscordReader class that's async async def backup_discord(self, ctx: discord.ApplicationContext, user_api_key):
if not user_api_key:
os.environ["OPENAI_API_KEY"] = self.openai_key
else:
os.environ["OPENAI_API_KEY"] = user_api_key
class DiscordReader(BaseReader): try:
"""Discord reader. channel_ids:List[int] = []
for c in ctx.guild.text_channels:
channel_ids.append(c.id)
document = await self.load_data(channel_ids=channel_ids, limit=1000, oldest_first=False)
index = await self.loop.run_in_executor(None, partial(self.index_discord, document))
Path(app_root_path() / "indexes").mkdir(parents = True, exist_ok=True)
index.save_to_disk(app_root_path() / "indexes" / f"{ctx.guild.name.replace(' ', '-')}_{date.today()}-H{datetime.now().hour}.json")
Reads conversations from channels. await ctx.respond("Backup saved")
except Exception:
await ctx.respond("Failed to save backup")
traceback.print_exc()
Args:
discord_token (Optional[str]): Discord token. If not provided, we
assume the environment variable `DISCORD_TOKEN` is set.
"""
def __init__(self, discord_token: Optional[str] = None) -> None: async def query(self, ctx: discord.ApplicationContext, query:str, response_mode, user_api_key):
"""Initialize with parameters.""" if not user_api_key:
if discord_token is None: os.environ["OPENAI_API_KEY"] = self.openai_key
discord_token = os.environ["DISCORD_TOKEN"] else:
if discord_token is None: os.environ["OPENAI_API_KEY"] = user_api_key
raise ValueError(
"Must specify `discord_token` or set environment " try:
"variable `DISCORD_TOKEN`." index: GPTSimpleVectorIndex = self.index_storage[ctx.user.id]
) response: Response = await self.loop.run_in_executor(None, partial(index.query, query, verbose=True, response_mode=response_mode, text_qa_template=self.qaprompt))
await ctx.respond(f"**Query:**\n\n{query.strip()}\n\n**Query response:**\n\n{response.response.strip()}")
except Exception:
await ctx.respond("Failed to send query", delete_after=10)
self.discord_token = discord_token # Extracted functions from DiscordReader
async def read_channel(self, channel_id: int, limit: Optional[int], oldest_first: bool) -> str: async def read_channel(self, channel_id: int, limit: Optional[int], oldest_first: bool) -> str:
"""Async read channel.""" """Async read channel."""
messages: List[discord.Message] = [] messages: List[discord.Message] = []
class CustomClient(discord.Client):
async def on_ready(self) -> None:
try: try:
channel = client.get_channel(channel_id) channel = self.bot.get_channel(channel_id)
print(f"Added {channel.name} from {channel.guild.name}") print(f"Added {channel.name} from {channel.guild.name}")
# only work for text channels for now # only work for text channels for now
if not isinstance(channel, discord.TextChannel): if not isinstance(channel, discord.TextChannel):
@ -155,18 +173,11 @@ class DiscordReader(BaseReader):
messages.append(thread_msg) messages.append(thread_msg)
except Exception as e: except Exception as e:
print("Encountered error: " + str(e)) print("Encountered error: " + str(e))
finally:
await self.close()
intents = discord.Intents.default()
intents.message_content = True
client = CustomClient(intents=intents)
await client.start(self.discord_token)
channel = client.get_channel(channel_id) channel = self.bot.get_channel(channel_id)
msg_txt_list = [f"user:{m.author.display_name}, content:{m.content}" for m in messages] msg_txt_list = [f"user:{m.author.display_name}, content:{m.content}" for m in messages]
return ("\n\n".join(msg_txt_list), channel.name) return ("<|endofstatement|>\n\n".join(msg_txt_list), channel.name)
async def load_data( async def load_data(
self, self,
@ -195,6 +206,6 @@ class DiscordReader(BaseReader):
) )
(channel_content, channel_name) = await self.read_channel(channel_id, limit=limit, oldest_first=oldest_first) (channel_content, channel_name) = await self.read_channel(channel_id, limit=limit, oldest_first=oldest_first)
results.append( results.append(
Document(channel_content, extra_info={"channel_id": channel_id, "channel_name": channel_name}) Document(channel_content, extra_info={"channel_name": channel_name})
) )
return results return results
Loading…
Cancel
Save