Added discord channel and whole server indexing

Indexing a single channel reads its 1000 most recent messages.
Indexing a whole server reads the 300 most recent messages from each text channel.
More restrictions may be needed, since all of the index commands are expensive.
Rene Teigen 1 year ago
parent 64ce627bb7
commit 37a20a7e37

@ -503,15 +503,26 @@ class Commands(discord.Cog, name="Commands"):
@add_to_group("index")
@discord.slash_command(
name="set",
name="set_file",
description="Set an index to query from",
guild_ids=ALLOWED_GUILDS
)
@discord.guild_only()
@discord.option(name="file", description="A file to create the index from", required=True, input_type=discord.Attachment)
async def set(self, ctx:discord.ApplicationContext, file: discord.Attachment):
@discord.option(name="file", description="A file to create the index from", required=True, input_type=discord.SlashCommandOptionType.attachment)
async def set_file(self, ctx:discord.ApplicationContext, file: discord.Attachment):
await self.index_cog.set_index_command(ctx, file)
@add_to_group("index")
@discord.slash_command(
    name="set_discord",
    description="Set a index from a discord channel",
    guild_ids=ALLOWED_GUILDS,
)
@discord.guild_only()
@discord.option(
    name="channel",
    description="A channel to create the index from",
    required=False,
    input_type=discord.SlashCommandOptionType.channel,
)
async def set_discord(self, ctx: discord.ApplicationContext, channel: discord.TextChannel):
    """Slash command `set_discord` in the `index` group.

    Forwards the invocation unchanged to the index cog's
    `set_discord_command`. The `channel` option is declared as not
    required, so the handler may be called without a channel.

    Args:
        ctx: Application context of the invoking interaction.
        channel: Channel to build the index from, when one was supplied.
    """
    await self.index_cog.set_discord_command(ctx, channel)
@add_to_group("index")
@discord.slash_command(

@ -27,7 +27,23 @@ class IndexService(discord.Cog, name="IndexService"):
return
await ctx.defer(ephemeral=True)
await self.index_handler.set_index(ctx, file, user_api_key=user_api_key)
await self.index_handler.set_file_index(ctx, file, user_api_key=user_api_key)
async def set_discord_command(self, ctx, channel: discord.TextChannel = None):
    """Handle the command that sets a Discord channel as the caller's index.

    When user-supplied API keys are enabled, the caller's key is looked up
    first and the command stops silently if none is returned. The
    interaction is then deferred ephemerally and the work is handed to the
    index handler.

    Args:
        ctx: Application context of the invoking interaction.
        channel: Channel to index. When falsy (no channel supplied), the
            handler is called with `no_channel=True`.
    """
    api_key = None
    if USER_INPUT_API_KEYS:
        api_key = await TextService.get_user_api_key(ctx.user.id, ctx, USER_KEY_DB)
        if not api_key:
            return

    await ctx.defer(ephemeral=True)

    # A missing channel is signalled to the handler through `no_channel`.
    await self.index_handler.set_discord_index(
        ctx,
        channel,
        user_api_key=api_key,
        no_channel=not channel,
    )
async def query_command(self, ctx, query):

@ -2,8 +2,14 @@ import os
import traceback
import asyncio
import tempfile
from functools import partial
import discord
from functools import partial
from typing import List, Optional, Tuple
from gpt_index.readers.base import BaseReader
from gpt_index.readers.schema.base import Document
from gpt_index.response.schema import Response
from gpt_index import GPTSimpleVectorIndex, SimpleDirectoryReader
@ -19,9 +25,12 @@ class Index_handler:
document = SimpleDirectoryReader(file_path).load_data()
index = GPTSimpleVectorIndex(document)
return index
def index_discord(self, document):
    """Build and return a `GPTSimpleVectorIndex` from `document`.

    Args:
        document: The already-loaded input passed straight to the
            `GPTSimpleVectorIndex` constructor.

    Returns:
        The newly constructed index.
    """
    return GPTSimpleVectorIndex(document)
async def set_index(self, ctx: discord.ApplicationContext, file: discord.Attachment, user_api_key):
async def set_file_index(self, ctx: discord.ApplicationContext, file: discord.Attachment, user_api_key):
if not user_api_key:
os.environ["OPENAI_API_KEY"] = self.openai_key
else:
@ -46,7 +55,32 @@ class Index_handler:
await ctx.respond("Failed to set index")
traceback.print_exc()
async def query(self, ctx: discord.ApplicationContext, query, user_api_key):
async def set_discord_index(self, ctx: discord.ApplicationContext, channel: discord.TextChannel, user_api_key, no_channel=False):
    """Build an index from Discord messages and store it for the caller.

    With `no_channel` set, every text channel of the invoking guild is
    read (300 messages per channel); otherwise only `channel` is read
    (1000 messages). Messages are requested newest-first. The resulting
    index is stored in `self.index_storage` under the caller's user id.

    Args:
        ctx: Application context; used for the guild, the user id and the
            reply.
        channel: Channel to read when `no_channel` is false.
        user_api_key: Key exported as `OPENAI_API_KEY`; when falsy,
            `self.openai_key` is exported instead.
        no_channel: Read all text channels of the guild instead of
            `channel`.
    """
    # Prefer the caller's own key; fall back to the handler's key.
    os.environ["OPENAI_API_KEY"] = user_api_key or self.openai_key

    try:
        discord_reader = DiscordReader()

        if no_channel:
            target_ids: List[int] = [text_channel.id for text_channel in ctx.guild.text_channels]
            per_channel_limit = 300
        else:
            target_ids = [channel.id]
            per_channel_limit = 1000

        documents = await discord_reader.load_data(
            channel_ids=target_ids,
            limit=per_channel_limit,
            oldest_first=False,
        )

        # Index construction is handed to the loop's default executor.
        build_index = partial(self.index_discord, documents)
        self.index_storage[ctx.user.id] = await self.loop.run_in_executor(None, build_index)
        await ctx.respond("Index set")
    except Exception:
        await ctx.respond("Failed to set index")
        traceback.print_exc()
async def query(self, ctx: discord.ApplicationContext, query:str, user_api_key):
if not user_api_key:
os.environ["OPENAI_API_KEY"] = self.openai_key
else:
@ -57,5 +91,111 @@ class Index_handler:
return
index: GPTSimpleVectorIndex = self.index_storage[ctx.user.id]
response = await self.loop.run_in_executor(None, partial(index.query, query, verbose=True))
await ctx.respond(f"Query response: {response}")
try:
response: Response = await self.loop.run_in_executor(None, partial(index.query, query, verbose=True))
except Exception:
ctx.respond("You haven't set and index", delete_after=5)
await ctx.respond(f"**Query:**\n\n{query.strip()}\n\n**Query response:**\n\n{response.response.strip()}")
#Set our own version of the DiscordReader class that's async
class DiscordReader(BaseReader):
    """Async Discord reader.

    Reads conversations from text channels. Every call to `read_channel`
    starts its own `discord.Client`, collects the messages inside
    `on_ready`, and closes the client again.

    Args:
        discord_token (Optional[str]): Discord token. If not provided, the
            environment variable `DISCORD_TOKEN` is used.

    Raises:
        ValueError: If no token is passed and `DISCORD_TOKEN` is not set.
    """

    def __init__(self, discord_token: Optional[str] = None) -> None:
        """Initialize with parameters."""
        if discord_token is None:
            # Use .get(): indexing os.environ raises KeyError for a missing
            # variable, which made the ValueError below unreachable.
            discord_token = os.environ.get("DISCORD_TOKEN")
        if discord_token is None:
            raise ValueError(
                "Must specify `discord_token` or set environment "
                "variable `DISCORD_TOKEN`."
            )

        self.discord_token = discord_token

    async def read_channel(
        self, channel_id: int, limit: Optional[int], oldest_first: bool
    ) -> Tuple[str, str]:
        """Read one text channel, including its threads.

        Errors raised while collecting messages are printed and swallowed,
        so a channel that cannot be read yields an empty text rather than
        an exception.

        Args:
            channel_id (int): Id of the channel to read.
            limit (Optional[int]): Maximum number of messages requested from
                the channel history, and again from each thread history.
            oldest_first (bool): Whether to read oldest messages first.

        Returns:
            Tuple[str, str]: The messages formatted as
            ``"<display name>: <content>"`` and joined by blank lines,
            followed by the channel name.

        Raises:
            ValueError: If the channel cannot be resolved by the client.
        """
        messages: List[discord.Message] = []

        class CustomClient(discord.Client):
            async def on_ready(self) -> None:
                try:
                    channel = self.get_channel(channel_id)
                    # only work for text channels for now; checked before the
                    # channel is dereferenced so an unknown id (None) gets
                    # this message instead of an AttributeError.
                    if not isinstance(channel, discord.TextChannel):
                        raise ValueError(
                            f"Channel {channel_id} is not a text channel. "
                            "Only text channels are supported for now."
                        )
                    print(f"Added {channel.name} from {channel.guild.name}")
                    # thread_dict maps thread_id to thread
                    thread_dict = {thread.id: thread for thread in channel.threads}
                    async for msg in channel.history(
                        limit=limit, oldest_first=oldest_first
                    ):
                        messages.append(msg)
                        # A message whose id matches a thread id has that
                        # thread's messages appended right after it.
                        if msg.id in thread_dict:
                            thread = thread_dict[msg.id]
                            async for thread_msg in thread.history(
                                limit=limit, oldest_first=oldest_first
                            ):
                                messages.append(thread_msg)
                except Exception as e:
                    print("Encountered error: " + str(e))
                finally:
                    # Closing the client is what lets `client.start` return.
                    await self.close()

        intents = discord.Intents.default()
        intents.message_content = True
        client = CustomClient(intents=intents)
        await client.start(self.discord_token)

        channel = client.get_channel(channel_id)
        if channel is None:
            # Fail with a clear message instead of `None.name`.
            raise ValueError(
                f"Channel {channel_id} could not be found by the Discord client."
            )

        msg_txt_list = [f"{m.author.display_name}: {m.content}" for m in messages]
        return ("\n\n".join(msg_txt_list), channel.name)

    async def load_data(
        self,
        channel_ids: List[int],
        limit: Optional[int] = None,
        oldest_first: bool = True,
    ) -> List[Document]:
        """Load data from the given Discord channels.

        Channels are read one after another, each producing one document.

        Args:
            channel_ids (List[int]): List of channel ids to read.
            limit (Optional[int]): Maximum number of messages to read.
            oldest_first (bool): Whether to read oldest messages first.
                Defaults to `True`.

        Returns:
            List[Document]: One document per channel, carrying
            `channel_id` and `channel_name` in its `extra_info`.

        Raises:
            ValueError: If a channel id is not an integer, or a channel
                cannot be resolved.
        """
        results: List[Document] = []
        for channel_id in channel_ids:
            if not isinstance(channel_id, int):
                raise ValueError(
                    f"Channel id {channel_id} must be an integer, "
                    f"not {type(channel_id)}."
                )
            (channel_content, channel_name) = await self.read_channel(
                channel_id, limit=limit, oldest_first=oldest_first
            )
            results.append(
                Document(
                    channel_content,
                    extra_info={"channel_id": channel_id, "channel_name": channel_name},
                )
            )
        return results
Loading…
Cancel
Save