diff --git a/cogs/commands.py b/cogs/commands.py index 89b7a4c..049d37b 100644 --- a/cogs/commands.py +++ b/cogs/commands.py @@ -849,5 +849,7 @@ class Commands(discord.Cog, name="Commands"): min_value=1, ) @discord.guild_only() - async def search(self, ctx: discord.ApplicationContext, query: str, scope: int, nodes:int): + async def search( + self, ctx: discord.ApplicationContext, query: str, scope: int, nodes: int + ): await self.search_cog.search_command(ctx, query, scope, nodes) diff --git a/cogs/search_service_cog.py b/cogs/search_service_cog.py index f84b84a..7237161 100644 --- a/cogs/search_service_cog.py +++ b/cogs/search_service_cog.py @@ -47,7 +47,7 @@ class SearchService(discord.Cog, name="SearchService"): await ctx.defer() - response = await self.model.search(query, user_api_key, search_scope,nodes) + response = await self.model.search(query, user_api_key, search_scope, nodes) await ctx.respond( f"**Query:**\n\n{query.strip()}\n\n**Query response:**\n\n{response.response.strip()}" diff --git a/models/index_model.py b/models/index_model.py index 0d9265e..fd21dea 100644 --- a/models/index_model.py +++ b/models/index_model.py @@ -27,7 +27,8 @@ from gpt_index import ( LLMPredictor, QueryConfig, PromptHelper, - IndexStructType, OpenAIEmbedding, + IndexStructType, + OpenAIEmbedding, ) from gpt_index.readers.web import DEFAULT_WEBSITE_EXTRACTOR @@ -38,7 +39,9 @@ from services.environment_service import EnvService, app_root_path SHORT_TO_LONG_CACHE = {} -def get_and_query(user_id, index_storage, query, response_mode, nodes, llm_predictor, embed_model): +def get_and_query( + user_id, index_storage, query, response_mode, nodes, llm_predictor, embed_model +): index: [GPTSimpleVectorIndex, ComposableGraph] = index_storage[ user_id ].get_index_or_throw() @@ -334,8 +337,14 @@ class Index_handler: ] llm_predictor = LLMPredictor(llm=OpenAI(model_name="text-davinci-003")) embedding_model = OpenAIEmbedding() - tree_index = GPTTreeIndex(documents=documents, llm_predictor=llm_predictor, embed_model=embedding_model) - await self.usage_service.update_usage(llm_predictor.last_token_usage+embedding_model.last_token_usage) + tree_index = GPTTreeIndex( + documents=documents, + llm_predictor=llm_predictor, + embed_model=embedding_model, + ) + await self.usage_service.update_usage( + llm_predictor.last_token_usage + embedding_model.last_token_usage + ) # Now we have a list of tree indexes, we can compose them if not name: @@ -358,7 +367,9 @@ class Index_handler: embedding_model = OpenAIEmbedding() # Add everything into a simple vector index - simple_index = GPTSimpleVectorIndex(documents=documents, embed_model=embedding_model) + simple_index = GPTSimpleVectorIndex( + documents=documents, embed_model=embedding_model + ) await self.usage_service.update_usage(embedding_model.last_token_usage) if not name: @@ -425,7 +436,7 @@ class Index_handler: response_mode, nodes, llm_predictor, - embedding_model + embedding_model, ), ) print("The last token usage was ", llm_predictor.last_token_usage) diff --git a/models/search_model.py b/models/search_model.py index 6bb92b9..57d8362 100644 --- a/models/search_model.py +++ b/models/search_model.py @@ -11,7 +11,10 @@ from gpt_index import ( QuestionAnswerPrompt, GPTSimpleVectorIndex, BeautifulSoupWebReader, - Document, PromptHelper, LLMPredictor, OpenAIEmbedding, + Document, + PromptHelper, + LLMPredictor, + OpenAIEmbedding, ) from gpt_index.readers.web import DEFAULT_WEBSITE_EXTRACTOR from langchain import OpenAI @@ -96,7 +99,17 @@ class Search: llm_predictor = LLMPredictor(llm=OpenAI(model_name="text-davinci-003")) embedding_model = OpenAIEmbedding() # Now we can search the index for a query: - response = index.query(query,verbose=True,embed_model=embedding_model,llm_predictor=llm_predictor,prompt_helper=prompthelper, similarity_top_k=nodes or DEFAULT_SEARCH_NODES, text_qa_template=self.qaprompt) - await self.usage_service.update_usage(llm_predictor.last_token_usage + embedding_model.last_token_usage) + response = index.query( + query, + verbose=True, + embed_model=embedding_model, + llm_predictor=llm_predictor, + prompt_helper=prompthelper, + similarity_top_k=nodes or DEFAULT_SEARCH_NODES, + text_qa_template=self.qaprompt, + ) + await self.usage_service.update_usage( + llm_predictor.last_token_usage, embedding_model.last_token_usage + ) return response