|
|
|
@ -1,6 +1,8 @@
|
|
|
|
|
import os
|
|
|
|
|
import traceback
|
|
|
|
|
import asyncio
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
|
|
|
|
|
import discord
|
|
|
|
|
import aiofiles
|
|
|
|
|
from functools import partial
|
|
|
|
@ -10,18 +12,44 @@ from datetime import date, datetime
|
|
|
|
|
|
|
|
|
|
from gpt_index.readers.schema.base import Document
|
|
|
|
|
from gpt_index import GPTSimpleVectorIndex, SimpleDirectoryReader, QuestionAnswerPrompt, BeautifulSoupWebReader, \
|
|
|
|
|
GPTFaissIndex
|
|
|
|
|
GPTFaissIndex, GPTListIndex, QueryMode, GPTTreeIndex
|
|
|
|
|
from gpt_index.readers.web import DEFAULT_WEBSITE_EXTRACTOR
|
|
|
|
|
|
|
|
|
|
from gpt_index.composability import ComposableGraph
|
|
|
|
|
|
|
|
|
|
from services.environment_service import EnvService, app_root_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IndexData:
    """Container for the indexes that have been built and the one in use.

    The handler in this file stores instances in a ``defaultdict(IndexData)``
    keyed by ``ctx.user.id``. An instance remembers every index passed to
    ``add_index`` and which index queries should currently run against.
    """

    def __init__(self):
        # Index that queries run against; None until one is added or assigned.
        self.queryable_index = None
        # Every index passed to add_index, in insertion order.
        self.individual_indexes = []

    # A safety check for the future
    def get_index_or_throw(self):
        """Return the current queryable index.

        Raises:
            Exception: if no index has been set yet (see ``queryable``).
        """
        if not self.queryable():
            raise Exception("An index access was attempted before an index was created. This is a programmer error, please report this to the maintainers.")
        return self.queryable_index

    def queryable(self):
        """Return True when an index is available to query."""
        return self.queryable_index is not None

    @staticmethod
    def _index_file_name(file_name, now):
        """Build the on-disk file name for an index saved at ``now``.

        Date and hour are derived from the same ``datetime`` so they can never
        disagree (e.g. when the save happens right at midnight).
        """
        return f"{file_name}_{now.date()}-H{now.hour}.json"

    def add_index(self, index, user_id, file_name):
        """Register ``index`` as the queryable index and persist it to disk.

        Args:
            index: object exposing ``save_to_disk(path)``.
            user_id: identifier used as the sub-folder name under ``indexes``.
            file_name: prefix of the saved file; a date/hour suffix and
                ``.json`` extension are appended.
        """
        self.individual_indexes.append(index)
        self.queryable_index = index

        # Create a folder called "indexes/{USER_ID}" if it doesn't exist already.
        # The same path object is used for mkdir and for saving, so the two
        # cannot drift apart.
        index_dir = Path(app_root_path()) / "indexes" / str(user_id)
        index_dir.mkdir(parents=True, exist_ok=True)
        print(index_dir)

        # Save the index to file under the user id
        index.save_to_disk(index_dir / self._index_file_name(file_name, datetime.now()))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Index_handler:
|
|
|
|
|
def __init__(self, bot):
|
|
|
|
|
self.bot = bot
|
|
|
|
|
self.openai_key = os.getenv("OPENAI_TOKEN")
|
|
|
|
|
self.index_storage = {}
|
|
|
|
|
self.index_storage = defaultdict(IndexData)
|
|
|
|
|
self.loop = asyncio.get_running_loop()
|
|
|
|
|
self.qaprompt = QuestionAnswerPrompt(
|
|
|
|
|
"Context information is below. The text '<|endofstatement|>' is used to separate chat entries and make it easier for you to understand the context\n"
|
|
|
|
@ -67,8 +95,11 @@ class Index_handler:
|
|
|
|
|
async with aiofiles.tempfile.NamedTemporaryFile(suffix=suffix, dir=temp_path, delete=False) as temp_file:
|
|
|
|
|
await file.save(temp_file.name)
|
|
|
|
|
index = await self.loop.run_in_executor(None, partial(self.index_file, temp_path))
|
|
|
|
|
self.index_storage[ctx.user.id] = index
|
|
|
|
|
await ctx.respond("Index set")
|
|
|
|
|
|
|
|
|
|
file_name = file.filename
|
|
|
|
|
self.index_storage[ctx.user.id].add_index(index, ctx.user.id, file_name)
|
|
|
|
|
|
|
|
|
|
await ctx.respond("Index added to your indexes")
|
|
|
|
|
except Exception:
|
|
|
|
|
await ctx.respond("Failed to set index")
|
|
|
|
|
traceback.print_exc()
|
|
|
|
@ -84,7 +115,10 @@ class Index_handler:
|
|
|
|
|
|
|
|
|
|
index = await self.loop.run_in_executor(None, partial(self.index_webpage, link))
|
|
|
|
|
|
|
|
|
|
self.index_storage[ctx.user.id] = index
|
|
|
|
|
# Make the url look nice, remove https, useless stuff, random characters
|
|
|
|
|
file_name = link.replace("https://", "").replace("http://", "").replace("www.", "").replace("/", "_").replace("?", "_").replace("&", "_").replace("=", "_").replace("-", "_").replace(".", "_")
|
|
|
|
|
|
|
|
|
|
self.index_storage[ctx.user.id].add_index(index, ctx.user.id, file_name)
|
|
|
|
|
|
|
|
|
|
except Exception:
|
|
|
|
|
await ctx.respond("Failed to set index")
|
|
|
|
@ -102,7 +136,7 @@ class Index_handler:
|
|
|
|
|
try:
|
|
|
|
|
document = await self.load_data(channel_ids=[channel.id], limit=1000, oldest_first=False)
|
|
|
|
|
index = await self.loop.run_in_executor(None, partial(self.index_discord, document))
|
|
|
|
|
self.index_storage[ctx.user.id] = index
|
|
|
|
|
self.index_storage[ctx.user.id].add_index(index, ctx.user.id, channel.name)
|
|
|
|
|
await ctx.respond("Index set")
|
|
|
|
|
except Exception:
|
|
|
|
|
await ctx.respond("Failed to set index")
|
|
|
|
@ -116,9 +150,9 @@ class Index_handler:
|
|
|
|
|
os.environ["OPENAI_API_KEY"] = user_api_key
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
index_file = EnvService.find_shared_file(f"indexes/{index}")
|
|
|
|
|
index_file = EnvService.find_shared_file(f"indexes/{ctx.user.id}/{index}")
|
|
|
|
|
index = await self.loop.run_in_executor(None, partial(self.index_load_file, index_file))
|
|
|
|
|
self.index_storage[ctx.user.id] = index
|
|
|
|
|
self.index_storage[ctx.user.id].queryable_index = index
|
|
|
|
|
await ctx.respond("Loaded index")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
await ctx.respond(e)
|
|
|
|
@ -153,10 +187,15 @@ class Index_handler:
|
|
|
|
|
os.environ["OPENAI_API_KEY"] = user_api_key
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
index: GPTSimpleVectorIndex = self.index_storage[ctx.user.id]
|
|
|
|
|
response = await self.loop.run_in_executor(None, partial(index.query, query, verbose=True, response_mode=response_mode, text_qa_template=self.qaprompt))
|
|
|
|
|
index: [GPTSimpleVectorIndex, ComposableGraph] = self.index_storage[ctx.user.id].get_index_or_throw()
|
|
|
|
|
if isinstance(index, GPTSimpleVectorIndex):
|
|
|
|
|
response = await self.loop.run_in_executor(None, partial(index.query, query, verbose=True, text_qa_template=self.qaprompt))
|
|
|
|
|
else:
|
|
|
|
|
response = await self.loop.run_in_executor(None,
|
|
|
|
|
partial(index.query, query, query_configs=[], verbose=True))
|
|
|
|
|
await ctx.respond(f"**Query:**\n\n{query.strip()}\n\n**Query response:**\n\n{response.response.strip()}")
|
|
|
|
|
except Exception:
|
|
|
|
|
traceback.print_exc()
|
|
|
|
|
await ctx.respond("Failed to send query", delete_after=10)
|
|
|
|
|
|
|
|
|
|
# Extracted functions from DiscordReader
|
|
|
|
|