Add option for response mode on queries

Rene Teigen 2 years ago
parent 37a20a7e37
commit dd9cb0ce4c

@ -532,9 +532,16 @@ class Commands(discord.Cog, name="Commands"):
) )
@discord.guild_only() @discord.guild_only()
@discord.option(name="query", description="What to query the index", required=True) @discord.option(name="query", description="What to query the index", required=True)
async def query(self, ctx:discord.ApplicationContext, query: str): @discord.option(
await self.index_cog.query_command(ctx, query) name="response_mode",
description="Response mode",
guild_ids=ALLOWED_GUILDS,
required=False,
default="default",
choices=["default", "compact", "tree_summarize"]
)
async def query(self, ctx:discord.ApplicationContext, query: str, response_mode: str):
await self.index_cog.query_command(ctx, query, response_mode)
# #

@ -46,7 +46,7 @@ class IndexService(discord.Cog, name="IndexService"):
await self.index_handler.set_discord_index(ctx, channel, user_api_key=user_api_key) await self.index_handler.set_discord_index(ctx, channel, user_api_key=user_api_key)
async def query_command(self, ctx, query): async def query_command(self, ctx, query, response_mode):
"""Command handler to query your index""" """Command handler to query your index"""
user_api_key = None user_api_key = None
if USER_INPUT_API_KEYS: if USER_INPUT_API_KEYS:
@ -55,4 +55,4 @@ class IndexService(discord.Cog, name="IndexService"):
return return
await ctx.defer() await ctx.defer()
await self.index_handler.query(ctx, query, user_api_key=user_api_key) await self.index_handler.query(ctx, query, response_mode, user_api_key)

@ -80,22 +80,18 @@ class Index_handler:
async def query(self, ctx: discord.ApplicationContext, query:str, user_api_key): async def query(self, ctx: discord.ApplicationContext, query:str, response_mode, user_api_key):
if not user_api_key: if not user_api_key:
os.environ["OPENAI_API_KEY"] = self.openai_key os.environ["OPENAI_API_KEY"] = self.openai_key
else: else:
os.environ["OPENAI_API_KEY"] = user_api_key os.environ["OPENAI_API_KEY"] = user_api_key
if not self.index_storage[ctx.user.id]:
await ctx.respond("You need to set an index", ephemeral=True, delete_after=5)
return
index: GPTSimpleVectorIndex = self.index_storage[ctx.user.id]
try: try:
response: Response = await self.loop.run_in_executor(None, partial(index.query, query, verbose=True)) index: GPTSimpleVectorIndex = self.index_storage[ctx.user.id]
response: Response = await self.loop.run_in_executor(None, partial(index.query, query, verbose=True, response_mode=response_mode))
await ctx.respond(f"**Query:**\n\n{query.strip()}\n\n**Query response:**\n\n{response.response.strip()}")
except Exception: except Exception:
ctx.respond("You haven't set and index", delete_after=5) await ctx.respond("You haven't set and index", delete_after=10)
await ctx.respond(f"**Query:**\n\n{query.strip()}\n\n**Query response:**\n\n{response.response.strip()}")
#Set our own version of the DiscordReader class that's async #Set our own version of the DiscordReader class that's async
@ -147,13 +143,16 @@ class DiscordReader(BaseReader):
async for msg in channel.history( async for msg in channel.history(
limit=limit, oldest_first=oldest_first limit=limit, oldest_first=oldest_first
): ):
messages.append(msg) if msg.author.bot:
if msg.id in thread_dict: pass
thread = thread_dict[msg.id] else:
async for thread_msg in thread.history( messages.append(msg)
limit=limit, oldest_first=oldest_first if msg.id in thread_dict:
): thread = thread_dict[msg.id]
messages.append(thread_msg) async for thread_msg in thread.history(
limit=limit, oldest_first=oldest_first
):
messages.append(thread_msg)
except Exception as e: except Exception as e:
print("Encountered error: " + str(e)) print("Encountered error: " + str(e))
finally: finally:
@ -164,8 +163,8 @@ class DiscordReader(BaseReader):
client = CustomClient(intents=intents) client = CustomClient(intents=intents)
await client.start(self.discord_token) await client.start(self.discord_token)
msg_txt_list = [f"{m.author.display_name}: {m.content}" for m in messages]
channel = client.get_channel(channel_id) channel = client.get_channel(channel_id)
msg_txt_list = [f"user:{m.author.display_name}, content:{m.content}" for m in messages]
return ("\n\n".join(msg_txt_list), channel.name) return ("\n\n".join(msg_txt_list), channel.name)

Loading…
Cancel
Save