@ -2,8 +2,14 @@ import os
import traceback
import asyncio
import tempfile
from functools import partial
import discord
from functools import partial
from typing import List , Optional
from gpt_index . readers . base import BaseReader
from gpt_index . readers . schema . base import Document
from gpt_index . response . schema import Response
from gpt_index import GPTSimpleVectorIndex , SimpleDirectoryReader
@ -19,9 +25,12 @@ class Index_handler:
document = SimpleDirectoryReader ( file_path ) . load_data ( )
index = GPTSimpleVectorIndex ( document )
return index
def index_discord ( self , document ) :
index = GPTSimpleVectorIndex ( document )
return index
async def set_index ( self , ctx : discord . ApplicationContext , file : discord . Attachment , user_api_key ) :
async def set_ file_ index( self , ctx : discord . ApplicationContext , file : discord . Attachment , user_api_key ) :
if not user_api_key :
os . environ [ " OPENAI_API_KEY " ] = self . openai_key
else :
@ -46,7 +55,32 @@ class Index_handler:
await ctx . respond ( " Failed to set index " )
traceback . print_exc ( )
async def query ( self , ctx : discord . ApplicationContext , query , user_api_key ) :
async def set_discord_index ( self , ctx : discord . ApplicationContext , channel : discord . TextChannel , user_api_key , no_channel = False ) :
if not user_api_key :
os . environ [ " OPENAI_API_KEY " ] = self . openai_key
else :
os . environ [ " OPENAI_API_KEY " ] = user_api_key
try :
reader = DiscordReader ( )
if no_channel :
channel_ids : List [ int ] = [ ]
for c in ctx . guild . text_channels :
channel_ids . append ( c . id )
document = await reader . load_data ( channel_ids = channel_ids , limit = 300 , oldest_first = False )
else :
document = await reader . load_data ( channel_ids = [ channel . id ] , limit = 1000 , oldest_first = False )
index = await self . loop . run_in_executor ( None , partial ( self . index_discord , document ) )
self . index_storage [ ctx . user . id ] = index
await ctx . respond ( " Index set " )
except Exception :
await ctx . respond ( " Failed to set index " )
traceback . print_exc ( )
async def query ( self , ctx : discord . ApplicationContext , query : str , user_api_key ) :
if not user_api_key :
os . environ [ " OPENAI_API_KEY " ] = self . openai_key
else :
@ -57,5 +91,111 @@ class Index_handler:
return
index : GPTSimpleVectorIndex = self . index_storage [ ctx . user . id ]
response = await self . loop . run_in_executor ( None , partial ( index . query , query , verbose = True ) )
await ctx . respond ( f " Query response: { response } " )
try :
response : Response = await self . loop . run_in_executor ( None , partial ( index . query , query , verbose = True ) )
except Exception :
ctx . respond ( " You haven ' t set and index " , delete_after = 5 )
await ctx . respond ( f " **Query:** \n \n { query . strip ( ) } \n \n **Query response:** \n \n { response . response . strip ( ) } " )
#Set our own version of the DiscordReader class that's async
class DiscordReader ( BaseReader ) :
""" Discord reader.
Reads conversations from channels .
Args :
discord_token ( Optional [ str ] ) : Discord token . If not provided , we
assume the environment variable ` DISCORD_TOKEN ` is set .
"""
def __init__ ( self , discord_token : Optional [ str ] = None ) - > None :
""" Initialize with parameters. """
if discord_token is None :
discord_token = os . environ [ " DISCORD_TOKEN " ]
if discord_token is None :
raise ValueError (
" Must specify `discord_token` or set environment "
" variable `DISCORD_TOKEN`. "
)
self . discord_token = discord_token
async def read_channel ( self , channel_id : int , limit : Optional [ int ] , oldest_first : bool ) - > str :
""" Async read channel. """
messages : List [ discord . Message ] = [ ]
class CustomClient ( discord . Client ) :
async def on_ready ( self ) - > None :
try :
channel = client . get_channel ( channel_id )
print ( f " Added { channel . name } from { channel . guild . name } " )
# only work for text channels for now
if not isinstance ( channel , discord . TextChannel ) :
raise ValueError (
f " Channel { channel_id } is not a text channel. "
" Only text channels are supported for now. "
)
# thread_dict maps thread_id to thread
thread_dict = { }
for thread in channel . threads :
thread_dict [ thread . id ] = thread
async for msg in channel . history (
limit = limit , oldest_first = oldest_first
) :
messages . append ( msg )
if msg . id in thread_dict :
thread = thread_dict [ msg . id ]
async for thread_msg in thread . history (
limit = limit , oldest_first = oldest_first
) :
messages . append ( thread_msg )
except Exception as e :
print ( " Encountered error: " + str ( e ) )
finally :
await self . close ( )
intents = discord . Intents . default ( )
intents . message_content = True
client = CustomClient ( intents = intents )
await client . start ( self . discord_token )
msg_txt_list = [ f " { m . author . display_name } : { m . content } " for m in messages ]
channel = client . get_channel ( channel_id )
return ( " \n \n " . join ( msg_txt_list ) , channel . name )
async def load_data (
self ,
channel_ids : List [ int ] ,
limit : Optional [ int ] = None ,
oldest_first : bool = True ,
) - > List [ Document ] :
""" Load data from the input directory.
Args :
channel_ids ( List [ int ] ) : List of channel ids to read .
limit ( Optional [ int ] ) : Maximum number of messages to read .
oldest_first ( bool ) : Whether to read oldest messages first .
Defaults to ` True ` .
Returns :
List [ Document ] : List of documents .
"""
results : List [ Document ] = [ ]
for channel_id in channel_ids :
if not isinstance ( channel_id , int ) :
raise ValueError (
f " Channel id { channel_id } must be an integer, "
f " not { type ( channel_id ) } . "
)
( channel_content , channel_name ) = await self . read_channel ( channel_id , limit = limit , oldest_first = oldest_first )
results . append (
Document ( channel_content , extra_info = { " channel_id " : channel_id , " channel_name " : channel_name } )
)
return results