From 37a20a7e37686205e8f9a0835c5ae1f2c966b0e6 Mon Sep 17 00:00:00 2001 From: Rene Teigen Date: Tue, 31 Jan 2023 13:01:35 +0000 Subject: [PATCH] Added discord channel and whole server indexing A channel index does 1000 messages Whole server index does 300 messages per channel Might need to add more restrictions since all the index commands are expensive --- cogs/commands.py | 17 ++++- cogs/index_service_cog.py | 18 ++++- models/index_model.py | 150 ++++++++++++++++++++++++++++++++++++-- 3 files changed, 176 insertions(+), 9 deletions(-) diff --git a/cogs/commands.py b/cogs/commands.py index 241a7b2..aae8354 100644 --- a/cogs/commands.py +++ b/cogs/commands.py @@ -503,15 +503,26 @@ class Commands(discord.Cog, name="Commands"): @add_to_group("index") @discord.slash_command( - name="set", + name="set_file", description="Set an index to query from", guild_ids=ALLOWED_GUILDS ) @discord.guild_only() - @discord.option(name="file", description="A file to create the index from", required=True, input_type=discord.Attachment) - async def set(self, ctx:discord.ApplicationContext, file: discord.Attachment): + @discord.option(name="file", description="A file to create the index from", required=True, input_type=discord.SlashCommandOptionType.attachment) + async def set_file(self, ctx:discord.ApplicationContext, file: discord.Attachment): await self.index_cog.set_index_command(ctx, file) + @add_to_group("index") + @discord.slash_command( + name="set_discord", + description="Set a index from a discord channel", + guild_ids=ALLOWED_GUILDS + ) + @discord.guild_only() + @discord.option(name="channel", description="A channel to create the index from", required=False, input_type=discord.SlashCommandOptionType.channel) + async def set_discord(self, ctx:discord.ApplicationContext, channel: discord.TextChannel): + await self.index_cog.set_discord_command(ctx, channel) + @add_to_group("index") @discord.slash_command( diff --git a/cogs/index_service_cog.py b/cogs/index_service_cog.py index f3dada1..3e33582 100644 --- a/cogs/index_service_cog.py +++ b/cogs/index_service_cog.py @@ -27,7 +27,23 @@ class IndexService(discord.Cog, name="IndexService"): return await ctx.defer(ephemeral=True) - await self.index_handler.set_index(ctx, file, user_api_key=user_api_key) + await self.index_handler.set_file_index(ctx, file, user_api_key=user_api_key) + + + async def set_discord_command(self, ctx, channel: discord.TextChannel = None): + """Command handler to set a channel as your personal index""" + + user_api_key = None + if USER_INPUT_API_KEYS: + user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx, USER_KEY_DB) + if not user_api_key: + return + + await ctx.defer(ephemeral=True) + if not channel: + await self.index_handler.set_discord_index(ctx, channel, user_api_key=user_api_key, no_channel=True) + return + await self.index_handler.set_discord_index(ctx, channel, user_api_key=user_api_key) async def query_command(self, ctx, query): diff --git a/models/index_model.py b/models/index_model.py index 2d6d2af..5c49781 100644 --- a/models/index_model.py +++ b/models/index_model.py @@ -2,8 +2,14 @@ import os import traceback import asyncio import tempfile -from functools import partial import discord +from functools import partial +from typing import List, Optional + + +from gpt_index.readers.base import BaseReader +from gpt_index.readers.schema.base import Document +from gpt_index.response.schema import Response from gpt_index import GPTSimpleVectorIndex, SimpleDirectoryReader @@ -19,9 +25,12 @@ class Index_handler: document = SimpleDirectoryReader(file_path).load_data() index = GPTSimpleVectorIndex(document) return index + def index_discord(self, document): + index = GPTSimpleVectorIndex(document) + return index - async def set_index(self, ctx: discord.ApplicationContext, file: discord.Attachment, user_api_key): + async def set_file_index(self, ctx: discord.ApplicationContext, file: discord.Attachment, user_api_key): if not user_api_key: os.environ["OPENAI_API_KEY"] = self.openai_key else: @@ -46,7 +55,32 @@ class Index_handler: await ctx.respond("Failed to set index") traceback.print_exc() - async def query(self, ctx: discord.ApplicationContext, query, user_api_key): + + async def set_discord_index(self, ctx: discord.ApplicationContext, channel: discord.TextChannel, user_api_key, no_channel=False): + if not user_api_key: + os.environ["OPENAI_API_KEY"] = self.openai_key + else: + os.environ["OPENAI_API_KEY"] = user_api_key + + try: + reader = DiscordReader() + if no_channel: + channel_ids:List[int] = [] + for c in ctx.guild.text_channels: + channel_ids.append(c.id) + document = await reader.load_data(channel_ids=channel_ids, limit=300, oldest_first=False) + else: + document = await reader.load_data(channel_ids=[channel.id], limit=1000, oldest_first=False) + index = await self.loop.run_in_executor(None, partial(self.index_discord, document)) + self.index_storage[ctx.user.id] = index + await ctx.respond("Index set") + except Exception: + await ctx.respond("Failed to set index") + traceback.print_exc() + + + + async def query(self, ctx: discord.ApplicationContext, query:str, user_api_key): if not user_api_key: os.environ["OPENAI_API_KEY"] = self.openai_key else: @@ -57,5 +91,111 @@ class Index_handler: return index: GPTSimpleVectorIndex = self.index_storage[ctx.user.id] - response = await self.loop.run_in_executor(None, partial(index.query, query, verbose=True)) - await ctx.respond(f"Query response: {response}") \ No newline at end of file + try: + response: Response = await self.loop.run_in_executor(None, partial(index.query, query, verbose=True)) + except Exception: + ctx.respond("You haven't set and index", delete_after=5) + await ctx.respond(f"**Query:**\n\n{query.strip()}\n\n**Query response:**\n\n{response.response.strip()}") + + +#Set our own version of the DiscordReader class that's async + +class DiscordReader(BaseReader): + """Discord reader. + + Reads conversations from channels. + + Args: + discord_token (Optional[str]): Discord token. If not provided, we + assume the environment variable `DISCORD_TOKEN` is set. + + """ + + def __init__(self, discord_token: Optional[str] = None) -> None: + """Initialize with parameters.""" + if discord_token is None: + discord_token = os.environ["DISCORD_TOKEN"] + if discord_token is None: + raise ValueError( + "Must specify `discord_token` or set environment " + "variable `DISCORD_TOKEN`." + ) + + self.discord_token = discord_token + + async def read_channel(self, channel_id: int, limit: Optional[int], oldest_first: bool) -> str: + """Async read channel.""" + + messages: List[discord.Message] = [] + + class CustomClient(discord.Client): + async def on_ready(self) -> None: + try: + channel = client.get_channel(channel_id) + print(f"Added {channel.name} from {channel.guild.name}") + # only work for text channels for now + if not isinstance(channel, discord.TextChannel): + raise ValueError( + f"Channel {channel_id} is not a text channel. " + "Only text channels are supported for now." + ) + # thread_dict maps thread_id to thread + thread_dict = {} + for thread in channel.threads: + thread_dict[thread.id] = thread + + async for msg in channel.history( + limit=limit, oldest_first=oldest_first + ): + messages.append(msg) + if msg.id in thread_dict: + thread = thread_dict[msg.id] + async for thread_msg in thread.history( + limit=limit, oldest_first=oldest_first + ): + messages.append(thread_msg) + except Exception as e: + print("Encountered error: " + str(e)) + finally: + await self.close() + + intents = discord.Intents.default() + intents.message_content = True + client = CustomClient(intents=intents) + await client.start(self.discord_token) + + msg_txt_list = [f"{m.author.display_name}: {m.content}" for m in messages] + channel = client.get_channel(channel_id) + + return ("\n\n".join(msg_txt_list), channel.name) + + async def load_data( + self, + channel_ids: List[int], + limit: Optional[int] = None, + oldest_first: bool = True, + ) -> List[Document]: + """Load data from the input directory. + + Args: + channel_ids (List[int]): List of channel ids to read. + limit (Optional[int]): Maximum number of messages to read. + oldest_first (bool): Whether to read oldest messages first. + Defaults to `True`. + + Returns: + List[Document]: List of documents. + + """ + results: List[Document] = [] + for channel_id in channel_ids: + if not isinstance(channel_id, int): + raise ValueError( + f"Channel id {channel_id} must be an integer, " + f"not {type(channel_id)}." + ) + (channel_content, channel_name) = await self.read_channel(channel_id, limit=limit, oldest_first=oldest_first) + results.append( + Document(channel_content, extra_info={"channel_id": channel_id, "channel_name": channel_name}) + ) + return results \ No newline at end of file