You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

201 lines
7.6 KiB

import asyncio
import os
import tempfile
import traceback
from functools import partial
from typing import List, Optional, Tuple

import discord
from gpt_index import GPTSimpleVectorIndex, SimpleDirectoryReader
from gpt_index.readers.base import BaseReader
from gpt_index.readers.schema.base import Document
from gpt_index.response.schema import Response
class Index_handler:
    """Build and query per-user GPT vector indexes from Discord commands.

    The index built for a user is kept in memory in ``index_storage`` under
    ``ctx.user.id``; setting a new index replaces the previous one. Index
    construction and querying are blocking calls, so they are dispatched
    through ``loop.run_in_executor`` instead of running on the event loop.
    """

    def __init__(self):
        """Read the bot's OpenAI key and capture the running event loop.

        Must be constructed while an event loop is running, because
        ``asyncio.get_running_loop()`` is called here.
        """
        self.openai_key = os.getenv("OPENAI_TOKEN")
        # Maps ctx.user.id -> the index most recently built for that user.
        self.index_storage = {}
        self.loop = asyncio.get_running_loop()

    def _use_api_key(self, user_api_key):
        """Export the user's OpenAI key, or the bot's key when none is given."""
        # NOTE: os.environ is process-wide, so commands from different users
        # that overlap in time share whichever key was written last.
        os.environ["OPENAI_API_KEY"] = user_api_key or self.openai_key

    def index_file(self, file_path):
        """Build a vector index from the documents found under ``file_path``.

        Args:
            file_path: Path handed to ``SimpleDirectoryReader``;
                ``set_file_index`` passes the temporary directory that holds
                the uploaded file.

        Returns:
            The ``GPTSimpleVectorIndex`` built from the loaded documents.
        """
        document = SimpleDirectoryReader(file_path).load_data()
        return GPTSimpleVectorIndex(document)

    def index_discord(self, document):
        """Build a vector index from already-loaded documents.

        Args:
            document: The documents returned by ``DiscordReader.load_data``.

        Returns:
            The ``GPTSimpleVectorIndex`` built from ``document``.
        """
        return GPTSimpleVectorIndex(document)

    async def set_file_index(
        self,
        ctx: discord.ApplicationContext,
        file: discord.Attachment,
        user_api_key,
    ):
        """Index an uploaded txt/pdf attachment and store it for the caller.

        Args:
            ctx: Command context; used to reply and to key the stored index.
            file: The uploaded attachment. Only ``text/plain`` and
                ``application/pdf`` content types are accepted.
            user_api_key: The user's OpenAI key; falsy to use the bot's key.
        """
        self._use_api_key(user_api_key)
        try:
            # content_type can be missing; treat that as an unsupported type
            # instead of failing on None.startswith().
            content_type = file.content_type or ""
            if content_type.startswith("text/plain"):
                suffix = ".txt"
            elif content_type.startswith("application/pdf"):
                suffix = ".pdf"
            else:
                await ctx.respond("Only accepts txt or pdf files")
                return

            # The context manager removes the directory (and the saved file)
            # on success and on failure alike.
            with tempfile.TemporaryDirectory() as temp_dir:
                # delete=False: the file only reserves a unique name with the
                # right suffix; the handle is closed before the attachment is
                # written to that path.
                with tempfile.NamedTemporaryFile(
                    suffix=suffix, dir=temp_dir, delete=False
                ) as temp_file:
                    temp_name = temp_file.name
                await file.save(temp_name)
                index = await self.loop.run_in_executor(
                    None, partial(self.index_file, temp_dir)
                )

            self.index_storage[ctx.user.id] = index
            await ctx.respond("Index set")
        except Exception:
            await ctx.respond("Failed to set index")
            traceback.print_exc()

    async def set_discord_index(
        self,
        ctx: discord.ApplicationContext,
        channel: discord.TextChannel,
        user_api_key,
        no_channel=False,
    ):
        """Index recent messages from one channel, or from every text channel.

        Args:
            ctx: Command context; used to reply and to key the stored index.
            channel: The channel to read when ``no_channel`` is false.
            user_api_key: The user's OpenAI key; falsy to use the bot's key.
            no_channel: When true, ``channel`` is ignored and every text
                channel of ``ctx.guild`` is read (300 messages per channel
                instead of 1000).
        """
        self._use_api_key(user_api_key)
        try:
            reader = DiscordReader()
            if no_channel:
                channel_ids: List[int] = [c.id for c in ctx.guild.text_channels]
                document = await reader.load_data(
                    channel_ids=channel_ids, limit=300, oldest_first=False
                )
            else:
                document = await reader.load_data(
                    channel_ids=[channel.id], limit=1000, oldest_first=False
                )
            index = await self.loop.run_in_executor(
                None, partial(self.index_discord, document)
            )
            self.index_storage[ctx.user.id] = index
            await ctx.respond("Index set")
        except Exception:
            await ctx.respond("Failed to set index")
            traceback.print_exc()

    async def query(self, ctx: discord.ApplicationContext, query: str, user_api_key):
        """Run ``query`` against the caller's stored index and post the answer.

        Args:
            ctx: Command context; used to reply and to look up the index.
            query: The question to ask the index.
            user_api_key: The user's OpenAI key; falsy to use the bot's key.
        """
        self._use_api_key(user_api_key)

        # .get() so a user who never set an index gets the hint below rather
        # than an unhandled KeyError.
        index: Optional[GPTSimpleVectorIndex] = self.index_storage.get(ctx.user.id)
        if index is None:
            await ctx.respond("You need to set an index", ephemeral=True, delete_after=5)
            return

        try:
            response: Response = await self.loop.run_in_executor(
                None, partial(index.query, query, verbose=True)
            )
        except Exception:
            traceback.print_exc()
            await ctx.respond("Failed to query the index", delete_after=5)
            # Nothing to report; `response` was never bound.
            return

        await ctx.respond(
            f"**Query:**\n\n{query.strip()}\n\n"
            f"**Query response:**\n\n{response.response.strip()}"
        )
# Our own version of the DiscordReader class, made async.
class DiscordReader(BaseReader):
    """Discord reader.

    Reads conversations from channels.

    Args:
        discord_token (Optional[str]): Discord token. If not provided, we
            assume the environment variable `DISCORD_TOKEN` is set.
    """

    def __init__(self, discord_token: Optional[str] = None) -> None:
        """Initialize with parameters.

        Raises:
            ValueError: If no token is passed and `DISCORD_TOKEN` is unset.
        """
        if discord_token is None:
            # .get() rather than [] so a missing variable reaches the
            # ValueError below instead of raising KeyError first.
            discord_token = os.environ.get("DISCORD_TOKEN")
        if discord_token is None:
            raise ValueError(
                "Must specify `discord_token` or set environment "
                "variable `DISCORD_TOKEN`."
            )
        self.discord_token = discord_token

    async def read_channel(
        self, channel_id: int, limit: Optional[int], oldest_first: bool
    ) -> Tuple[str, str]:
        """Async read channel.

        Starts a short-lived client with the stored token, collects the
        channel's messages (and the messages of its threads) once the client
        is ready, then closes the client.

        Args:
            channel_id (int): Id of the channel to read.
            limit (Optional[int]): Passed to each ``history()`` call.
            oldest_first (bool): Passed to each ``history()`` call.

        Returns:
            Tuple[str, str]: The messages joined as ``"author: content"``
            entries separated by blank lines, and the channel name.

        Raises:
            ValueError: If the channel cannot be found by the client.
        """
        messages: List[discord.Message] = []

        class CustomClient(discord.Client):
            async def on_ready(self) -> None:
                try:
                    channel = self.get_channel(channel_id)
                    if channel is None:
                        raise ValueError(f"Channel {channel_id} could not be found.")
                    # only work for text channels for now
                    if not isinstance(channel, discord.TextChannel):
                        raise ValueError(
                            f"Channel {channel_id} is not a text channel. "
                            "Only text channels are supported for now."
                        )
                    print(f"Added {channel.name} from {channel.guild.name}")
                    # thread_dict maps thread_id to thread
                    thread_dict = {thread.id: thread for thread in channel.threads}
                    async for msg in channel.history(
                        limit=limit, oldest_first=oldest_first
                    ):
                        messages.append(msg)
                        # A message whose id is also a thread id has a thread
                        # attached; pull that thread's messages in right after.
                        if msg.id in thread_dict:
                            thread = thread_dict[msg.id]
                            async for thread_msg in thread.history(
                                limit=limit, oldest_first=oldest_first
                            ):
                                messages.append(thread_msg)
                except Exception as e:
                    # Best effort: report the problem and keep what was read.
                    print("Encountered error: " + str(e))
                finally:
                    # Closing the client is what lets `client.start` return.
                    await self.close()

        intents = discord.Intents.default()
        intents.message_content = True
        client = CustomClient(intents=intents)
        await client.start(self.discord_token)

        msg_txt_list = [f"{m.author.display_name}: {m.content}" for m in messages]
        channel = client.get_channel(channel_id)
        if channel is None:
            raise ValueError(f"Channel {channel_id} could not be found.")
        return ("\n\n".join(msg_txt_list), channel.name)

    async def load_data(
        self,
        channel_ids: List[int],
        limit: Optional[int] = None,
        oldest_first: bool = True,
    ) -> List[Document]:
        """Load data from the given Discord channels.

        Args:
            channel_ids (List[int]): List of channel ids to read.
            limit (Optional[int]): Maximum number of messages to read.
            oldest_first (bool): Whether to read oldest messages first.
                Defaults to `True`.

        Returns:
            List[Document]: One document per channel, with `channel_id` and
            `channel_name` in its extra info.

        Raises:
            ValueError: If any channel id is not an integer.
        """
        # Validate everything up front so bad input fails before any channel
        # is read.
        for channel_id in channel_ids:
            if not isinstance(channel_id, int):
                raise ValueError(
                    f"Channel id {channel_id} must be an integer, "
                    f"not {type(channel_id)}."
                )

        results: List[Document] = []
        for channel_id in channel_ids:
            channel_content, channel_name = await self.read_channel(
                channel_id, limit=limit, oldest_first=oldest_first
            )
            results.append(
                Document(
                    channel_content,
                    extra_info={"channel_id": channel_id, "channel_name": channel_name},
                )
            )
        return results