You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

235 lines
9.2 KiB

import asyncio
import os
import traceback
from datetime import date, datetime
from functools import partial
from pathlib import Path
from typing import List, Optional, Tuple

import aiofiles
import discord
from gpt_index import (
    BeautifulSoupWebReader,
    GPTFaissIndex,
    GPTSimpleVectorIndex,
    QuestionAnswerPrompt,
    SimpleDirectoryReader,
)
from gpt_index.readers.schema.base import Document
from gpt_index.readers.web import DEFAULT_WEBSITE_EXTRACTOR

from services.environment_service import EnvService, app_root_path
class Index_handler:
    """Build, load, persist and query per-user GPT indexes for the Discord bot.

    Indexes live in memory in ``self.index_storage``, keyed by the invoking
    Discord user's id. Each ``set_*`` / ``load_index`` call replaces that
    user's current index, and ``query`` runs against it.
    """

    def __init__(self, bot):
        self.bot = bot
        # Bot-wide fallback key, used when a command supplies no per-user key.
        self.openai_key = os.getenv("OPENAI_TOKEN")
        # Discord user id -> index currently selected by that user.
        self.index_storage = {}
        # get_running_loop() raises RuntimeError unless called while an event
        # loop is running, so this object must be built from async code.
        self.loop = asyncio.get_running_loop()
        self.qaprompt = QuestionAnswerPrompt(
            "Context information is below. The text '<|endofstatement|>' is used to separate chat entries and make it easier for you to understand the context\n"
            "---------------------\n"
            "{context_str}"
            "\n---------------------\n"
            "Never say '<|endofstatement|>'\n"
            "Given the context information and not prior knowledge, "
            "answer the question: {query_str}\n"
        )

    def _set_openai_api_key(self, user_api_key):
        """Set ``OPENAI_API_KEY`` for the OpenAI-backed calls that follow.

        Uses ``user_api_key`` when it is truthy, otherwise the bot-wide key.
        Raises TypeError if neither key is available (os.environ only
        accepts strings); callers invoke this inside their ``try`` so the
        user still gets a failure reply.
        """
        # NOTE(review): os.environ is process-wide and the indexing/query work
        # runs in executor threads, so concurrent commands from different users
        # may observe each other's key -- confirm whether that is acceptable.
        os.environ["OPENAI_API_KEY"] = user_api_key or self.openai_key

    def index_file(self, file_path):
        """Build a vector index from every document SimpleDirectoryReader finds in directory ``file_path``."""
        documents = SimpleDirectoryReader(file_path).load_data()
        return GPTSimpleVectorIndex(documents)

    def index_load_file(self, file_path):
        """Load a vector index previously written with ``save_to_disk``."""
        return GPTSimpleVectorIndex.load_from_disk(file_path)

    def index_discord(self, document):
        """Build a vector index from already-loaded documents (e.g. the list returned by ``load_data``)."""
        return GPTSimpleVectorIndex(document)

    def index_webpage(self, url):
        """Fetch ``url`` with BeautifulSoupWebReader and build a vector index from it."""
        documents = BeautifulSoupWebReader(website_extractor=DEFAULT_WEBSITE_EXTRACTOR).load_data(urls=[url])
        return GPTSimpleVectorIndex(documents)

    async def set_file_index(self, ctx: discord.ApplicationContext, file: discord.Attachment, user_api_key):
        """Index an uploaded .txt or .pdf attachment and make it the user's current index."""
        try:
            self._set_openai_api_key(user_api_key)
            # Attachment.content_type is Optional; treat a missing type as unsupported.
            content_type = file.content_type or ""
            if content_type.startswith("text/plain"):
                suffix = ".txt"
            elif content_type.startswith("application/pdf"):
                suffix = ".pdf"
            else:
                await ctx.respond("Only accepts txt or pdf files")
                return
            async with aiofiles.tempfile.TemporaryDirectory() as temp_path:
                async with aiofiles.tempfile.NamedTemporaryFile(suffix=suffix, dir=temp_path, delete=False) as temp_file:
                    await file.save(temp_file.name)
                # Index once the temp-file handle is closed. The reader scans the
                # whole temp directory, which holds only the saved attachment.
                index = await self.loop.run_in_executor(None, partial(self.index_file, temp_path))
            self.index_storage[ctx.user.id] = index
            await ctx.respond("Index set")
        except Exception:
            await ctx.respond("Failed to set index")
            traceback.print_exc()

    async def set_link_index(self, ctx: discord.ApplicationContext, link: str, user_api_key):
        """Index the web page at ``link`` and make it the user's current index."""
        # TODO Link validation
        try:
            self._set_openai_api_key(user_api_key)
            index = await self.loop.run_in_executor(None, partial(self.index_webpage, link))
            self.index_storage[ctx.user.id] = index
            # Reply with success only when indexing actually succeeded.
            await ctx.respond("Index set")
        except Exception:
            await ctx.respond("Failed to set index")
            traceback.print_exc()

    async def set_discord_index(self, ctx: discord.ApplicationContext, channel: discord.TextChannel, user_api_key):
        """Index up to 1000 recent messages of ``channel`` (newest first) as the user's current index."""
        try:
            self._set_openai_api_key(user_api_key)
            document = await self.load_data(channel_ids=[channel.id], limit=1000, oldest_first=False)
            index = await self.loop.run_in_executor(None, partial(self.index_discord, document))
            self.index_storage[ctx.user.id] = index
            await ctx.respond("Index set")
        except Exception:
            await ctx.respond("Failed to set index")
            traceback.print_exc()

    async def load_index(self, ctx: discord.ApplicationContext, index, user_api_key):
        """Load a saved index by file name from the shared ``indexes`` folder as the user's current index."""
        try:
            self._set_openai_api_key(user_api_key)
            # ``index`` is user input. Accept only a bare file name so the
            # lookup cannot escape the indexes folder (e.g. "../secrets").
            if not index or index in (".", "..") or Path(index).name != index:
                await ctx.respond("Invalid index name")
                return
            index_file = EnvService.find_shared_file(f"indexes/{index}")
            loaded_index = await self.loop.run_in_executor(None, partial(self.index_load_file, index_file))
            self.index_storage[ctx.user.id] = loaded_index
            await ctx.respond("Loaded index")
        except Exception as e:
            # Prefix the error so the reply is never an empty message.
            await ctx.respond(f"Failed to load index: {e}")
            traceback.print_exc()

    async def backup_discord(self, ctx: discord.ApplicationContext, user_api_key):
        """Index every text channel of the current guild and save the result under ``indexes/``."""
        try:
            self._set_openai_api_key(user_api_key)
            channel_ids: List[int] = [c.id for c in ctx.guild.text_channels]
            document = await self.load_data(channel_ids=channel_ids, limit=1000, oldest_first=False)
            index = await self.loop.run_in_executor(None, partial(self.index_discord, document))
            indexes_dir = Path(app_root_path()) / "indexes"
            indexes_dir.mkdir(parents=True, exist_ok=True)
            # Read the clock once so the date and hour in the file name agree.
            now = datetime.now()
            # Spaces become dashes (as before). Path separators would otherwise
            # point the save into a nonexistent subdirectory.
            guild_name = ctx.guild.name.replace(" ", "-").replace("/", "-").replace("\\", "-")
            save_path = indexes_dir / f"{guild_name}_{now.date()}-H{now.hour}.json"
            # save_to_disk does blocking file I/O; keep it off the event loop.
            await self.loop.run_in_executor(None, partial(index.save_to_disk, save_path))
            await ctx.respond("Backup saved")
        except Exception:
            await ctx.respond("Failed to save backup")
            traceback.print_exc()

    async def query(self, ctx: discord.ApplicationContext, query: str, response_mode, user_api_key):
        """Run ``query`` against the invoking user's current index and post the answer."""
        try:
            self._set_openai_api_key(user_api_key)
            # Raises KeyError when the user has not set or loaded an index yet.
            index: GPTSimpleVectorIndex = self.index_storage[ctx.user.id]
            response = await self.loop.run_in_executor(
                None,
                partial(
                    index.query,
                    query,
                    verbose=True,
                    response_mode=response_mode,
                    text_qa_template=self.qaprompt,
                ),
            )
            await ctx.respond(f"**Query:**\n\n{query.strip()}\n\n**Query response:**\n\n{response.response.strip()}")
        except Exception:
            await ctx.respond("Failed to send query", delete_after=10)
            traceback.print_exc()

    # Extracted functions from DiscordReader
    async def read_channel(self, channel_id: int, limit: Optional[int], oldest_first: bool) -> Tuple[str, str]:
        """Read a text channel (and replies in threads started from its messages) as one transcript.

        Args:
            channel_id: Id of the channel to read.
            limit: Maximum number of messages per history read.
            oldest_first: Whether to read oldest messages first.

        Returns:
            ``(transcript, channel_name)``. Each message becomes
            ``"user:<display name>, content:<text>"``, and messages are joined
            with ``"<|endofstatement|>\\n\\n"``. Reading is best effort: an
            error is printed, and whatever was collected so far is returned.
        """
        messages: List[discord.Message] = []
        channel = self.bot.get_channel(channel_id)
        try:
            print(f"Added {channel.name} from {channel.guild.name}")
            # only work for text channels for now
            if not isinstance(channel, discord.TextChannel):
                raise ValueError(
                    f"Channel {channel_id} is not a text channel. "
                    "Only text channels are supported for now."
                )
            # thread_dict maps thread_id to thread
            thread_dict = {thread.id: thread for thread in channel.threads}
            async for msg in channel.history(limit=limit, oldest_first=oldest_first):
                # Bot-authored top-level messages are excluded from the transcript.
                if msg.author.bot:
                    continue
                messages.append(msg)
                # Relies on a thread started from a message sharing that message's id.
                thread = thread_dict.get(msg.id)
                if thread is not None:
                    # NOTE(review): thread replies are not filtered for bot
                    # authors, unlike top-level messages -- confirm intended.
                    async for thread_msg in thread.history(limit=limit, oldest_first=oldest_first):
                        messages.append(thread_msg)
        except Exception as e:
            print("Encountered error: " + str(e))
        # get_channel returns None for unknown ids; fall back to the id instead of crashing.
        channel_name = channel.name if channel is not None else str(channel_id)
        msg_txt_list = [f"user:{m.author.display_name}, content:{m.content}" for m in messages]
        return ("<|endofstatement|>\n\n".join(msg_txt_list), channel_name)

    async def load_data(
        self,
        channel_ids: List[int],
        limit: Optional[int] = None,
        oldest_first: bool = True,
    ) -> List[Document]:
        """Read each Discord channel into a Document tagged with its channel name.

        Args:
            channel_ids (List[int]): List of channel ids to read.
            limit (Optional[int]): Maximum number of messages to read.
            oldest_first (bool): Whether to read oldest messages first.
                Defaults to `True`.

        Returns:
            List[Document]: One document per channel, in ``channel_ids`` order.

        Raises:
            ValueError: If any channel id is not an ``int``.
        """
        results: List[Document] = []
        for channel_id in channel_ids:
            if not isinstance(channel_id, int):
                raise ValueError(
                    f"Channel id {channel_id} must be an integer, "
                    f"not {type(channel_id)}."
                )
            (channel_content, channel_name) = await self.read_channel(channel_id, limit=limit, oldest_first=oldest_first)
            results.append(
                Document(channel_content, extra_info={"channel_name": channel_name})
            )
        return results