From 8ecc8286aa95099dad07308b672c8208c7349a97 Mon Sep 17 00:00:00 2001 From: Kaveen Kumarasinghe Date: Mon, 27 Feb 2023 14:14:06 -0500 Subject: [PATCH] improve index responses --- models/embed_statics_model.py | 34 ++++++++++++++++++++++- models/index_model.py | 51 ++++++++++++++++++++++++++--------- 2 files changed, 72 insertions(+), 13 deletions(-) diff --git a/models/embed_statics_model.py b/models/embed_statics_model.py index af32922..178636f 100644 --- a/models/embed_statics_model.py +++ b/models/embed_statics_model.py @@ -84,9 +84,10 @@ class EmbedStatics: return embed @staticmethod - def get_index_set_success_embed(): + def get_index_set_success_embed(price="Unknown"): embed = discord.Embed( title="Index Added", + description=f"This index can now be queried and loaded with `/index query` and `/index load`\n\n||Total cost: {round(float(price), 6)}||", color=discord.Color.green(), ) # thumbnail of https://i.imgur.com/7JF0oGD.png @@ -242,3 +243,34 @@ class EmbedStatics: # thumbnail of https://i.imgur.com/VLJ32x7.png embed.set_thumbnail(url="https://i.imgur.com/VLJ32x7.png") return embed + + @staticmethod + def build_index_progress_embed(): + embed = discord.Embed( + title="Index Service", + description="Indexing...", + color=discord.Color.blurple(), + ) + embed.set_thumbnail(url="https://i.imgur.com/txHhNzL.png") + return embed + + @staticmethod + def build_index_query_progress_embed(query): + embed = discord.Embed( + title="Index Service", + description=f"Query:\n`{query}`\nQuerying...", + color=discord.Color.blurple(), + ) + embed.set_thumbnail(url="https://i.imgur.com/txHhNzL.png") + return embed + + @staticmethod + def build_index_query_success_embed(query, price="Unknown"): + embed = discord.Embed( + title="Index Service", + description=f"Query:\n`{query}`\nThe index query was successful.\n\n||Total cost: {round(float(price), 6)}||", + color=discord.Color.green(), + ) + # thumbnail of https://i.imgur.com/7JF0oGD.png + embed.set_thumbnail(url="https://i.imgur.com/7JF0oGD.png") + return embed diff --git a/models/index_model.py b/models/index_model.py index 44c8a64..f7df629 100644 --- a/models/index_model.py +++ b/models/index_model.py @@ -365,6 +365,10 @@ class Index_handler: ) ) return + + # Send indexing message + response = await ctx.respond(embed=EmbedStatics.build_index_progress_embed()) + async with aiofiles.tempfile.TemporaryDirectory() as temp_path: async with aiofiles.tempfile.NamedTemporaryFile( suffix=suffix, dir=temp_path, delete=False @@ -378,11 +382,17 @@ class Index_handler: embedding_model.last_token_usage, embeddings=True ) + try: + price = await self.usage_service.get_price(embedding_model.last_token_usage, embeddings=True) + except: + traceback.print_exc() + price = "Unknown" + file_name = file.filename self.index_storage[ctx.user.id].add_index(index, ctx.user.id, file_name) - await ctx.respond(embed=EmbedStatics.get_index_set_success_embed()) + await response.edit(embed=EmbedStatics.get_index_set_success_embed(str(price))) except Exception as e: - await ctx.respond(embed=EmbedStatics.get_index_set_failure_embed(str(e))) + await ctx.channel.send(embed=EmbedStatics.get_index_set_failure_embed(str(e))) traceback.print_exc() async def set_link_index( @@ -393,18 +403,20 @@ class Index_handler: else: os.environ["OPENAI_API_KEY"] = user_api_key + response = await ctx.respond(embed=EmbedStatics.build_index_progress_embed()) try: + embedding_model = OpenAIEmbedding() # Pre-emptively connect and get the content-type of the response try: async with aiohttp.ClientSession() as session: - async with session.get(link, timeout=2) as response: - print(response.status) - if response.status == 200: - content_type = response.headers.get("content-type") + async with session.get(link, timeout=2) as _response: + print(_response.status) + if _response.status == 200: + content_type = _response.headers.get("content-type") else: - await ctx.respond( + await response.edit( embed=EmbedStatics.get_index_set_failure_embed( "Invalid URL or could not connect to the provided URL." ) @@ -412,7 +424,7 @@ class Index_handler: return except Exception as e: traceback.print_exc() - await ctx.respond( + await response.edit( embed=EmbedStatics.get_index_set_failure_embed( "Invalid URL or could not connect to the provided URL. " + str(e) @@ -435,6 +447,12 @@ class Index_handler: embedding_model.last_token_usage, embeddings=True ) + try: + price = await self.usage_service.get_price(embedding_model.last_token_usage, embeddings=True) + except: + traceback.print_exc() + price = "Unknown" + # Make the url look nice, remove https, useless stuff, random characters file_name = ( link.replace("https://", "") @@ -451,16 +469,16 @@ class Index_handler: self.index_storage[ctx.user.id].add_index(index, ctx.user.id, file_name) except ValueError as e: - await ctx.respond(embed=EmbedStatics.get_index_set_failure_embed(str(e))) + await response.edit(embed=EmbedStatics.get_index_set_failure_embed(str(e))) traceback.print_exc() return except Exception as e: - await ctx.respond(embed=EmbedStatics.get_index_set_failure_embed(str(e))) + await response.edit(embed=EmbedStatics.get_index_set_failure_embed(str(e))) traceback.print_exc() return - await ctx.respond(embed=EmbedStatics.get_index_set_success_embed()) + await response.edit(embed=EmbedStatics.get_index_set_success_embed(price)) async def set_discord_index( self, @@ -697,6 +715,8 @@ class Index_handler: else: os.environ["OPENAI_API_KEY"] = user_api_key + ctx_response = await ctx.respond(embed=EmbedStatics.build_index_query_progress_embed(query)) + try: llm_predictor = LLMPredictor(llm=OpenAI(model_name="text-davinci-003")) embedding_model = OpenAIEmbedding() @@ -720,6 +740,12 @@ class Index_handler: await self.usage_service.update_usage( embedding_model.last_token_usage, embeddings=True ) + + try: + total_price = round(await self.usage_service.get_price(llm_predictor.last_token_usage) + await self.usage_service.get_price(embedding_model.last_token_usage, True), 6) + except: + total_price = "Unknown" + query_response_message = f"**Query:**\n\n`{query.strip()}`\n\n**Query response:**\n\n{response.response.strip()}" query_response_message = query_response_message.replace( "<|endofstatement|>", "" @@ -730,10 +756,11 @@ class Index_handler: timeout=None, author_check=False, ) + await ctx_response.edit(embed=EmbedStatics.build_index_query_success_embed(query,total_price)) await paginator.respond(ctx.interaction) except Exception: traceback.print_exc() - await ctx.respond( + await ctx_response.edit( embed=EmbedStatics.get_index_query_failure_embed( "Failed to send query. You may not have an index set, load an index with /index load" ),