diff --git a/cogs/search_service_cog.py b/cogs/search_service_cog.py
index 7ab7a39..977fa81 100644
--- a/cogs/search_service_cog.py
+++ b/cogs/search_service_cog.py
@@ -57,7 +57,9 @@ class SearchService(discord.Cog, name="SearchService"):
 
         return pages
 
-    async def search_command(self, ctx: discord.ApplicationContext, query, search_scope, nodes):
+    async def search_command(
+        self, ctx: discord.ApplicationContext, query, search_scope, nodes
+    ):
         """Command handler for the translation command"""
         user_api_key = None
         if USER_INPUT_API_KEYS:
@@ -79,10 +81,15 @@ class SearchService(discord.Cog, name="SearchService"):
         try:
             response = await self.model.search(query, user_api_key, search_scope, nodes)
         except ValueError:
-            await ctx.respond("The Google Search API returned an error. Check the console for more details.", ephemeral=True)
+            await ctx.respond(
+                "The Google Search API returned an error. Check the console for more details.",
+                ephemeral=True,
+            )
             return
         except Exception:
-            await ctx.respond("An error occurred. Check the console for more details.", ephemeral=True)
+            await ctx.respond(
+                "An error occurred. Check the console for more details.", ephemeral=True
+            )
             traceback.print_exc()
             return
 
@@ -95,7 +102,9 @@ class SearchService(discord.Cog, name="SearchService"):
         urls = "\n".join(f"<{url}>" for url in urls)
 
         query_response_message = f"**Query:**`\n\n{query.strip()}`\n\n**Query response:**\n\n{response.response.strip()}\n\n**Sources:**\n{urls}"
-        query_response_message = query_response_message.replace("<|endofstatement|>", "")
+        query_response_message = query_response_message.replace(
+            "<|endofstatement|>", ""
+        )
 
         # If the response is too long, lets paginate using the discord pagination
         # helper
@@ -107,4 +116,3 @@ class SearchService(discord.Cog, name="SearchService"):
         )
 
         await paginator.respond(ctx.interaction)
-
diff --git a/models/index_model.py b/models/index_model.py
index b83dee5..dd59a49 100644
--- a/models/index_model.py
+++ b/models/index_model.py
@@ -185,7 +185,6 @@ class Index_handler:
         index = GPTSimpleVectorIndex(document, embed_model=embed_model)
         return index
 
-
     def index_gdoc(self, doc_id, embed_model) -> GPTSimpleVectorIndex:
         document = GoogleDocsReader().load_data(doc_id)
         index = GPTSimpleVectorIndex(document, embed_model=embed_model)
@@ -304,9 +303,6 @@ class Index_handler:
             await ctx.respond("Failed to get link", ephemeral=True)
             return
 
-
-
-
         # Check if the link contains youtube in it
         if "youtube" in link:
             index = await self.loop.run_in_executor(
@@ -415,11 +411,19 @@ class Index_handler:
             for doc_id in [docmeta for docmeta in _index.docstore.docs.keys()]
             if isinstance(_index.docstore.get_document(doc_id), Document)
         ]
-        llm_predictor = LLMPredictor(llm=OpenAI(model_name="text-davinci-003", max_tokens=-1))
+        llm_predictor = LLMPredictor(
+            llm=OpenAI(model_name="text-davinci-003", max_tokens=-1)
+        )
         embedding_model = OpenAIEmbedding()
 
         tree_index = await self.loop.run_in_executor(
-            None, partial(GPTTreeIndex, documents=documents, llm_predictor=llm_predictor, embed_model=embedding_model)
+            None,
+            partial(
+                GPTTreeIndex,
+                documents=documents,
+                llm_predictor=llm_predictor,
+                embed_model=embedding_model,
+            ),
         )
 
         await self.usage_service.update_usage(llm_predictor.last_token_usage)
@@ -449,7 +453,12 @@ class Index_handler:
         embedding_model = OpenAIEmbedding()
 
         simple_index = await self.loop.run_in_executor(
-            None, partial(GPTSimpleVectorIndex, documents=documents, embed_model=embedding_model)
+            None,
+            partial(
+                GPTSimpleVectorIndex,
+                documents=documents,
+                embed_model=embedding_model,
+            ),
         )
 
         await self.usage_service.update_usage(
@@ -533,8 +542,10 @@ class Index_handler:
         await self.usage_service.update_usage(
             embedding_model.last_token_usage, embeddings=True
         )
-        query_response_message=f"**Query:**\n\n`{query.strip()}`\n\n**Query response:**\n\n{response.response.strip()}"
-        query_response_message = query_response_message.replace("<|endofstatement|>", "")
+        query_response_message = f"**Query:**\n\n`{query.strip()}`\n\n**Query response:**\n\n{response.response.strip()}"
+        query_response_message = query_response_message.replace(
+            "<|endofstatement|>", ""
+        )
         embed_pages = await self.paginate_embed(query_response_message)
         paginator = pages.Paginator(
             pages=embed_pages,
@@ -763,7 +774,6 @@ class ComposeModal(discord.ui.View):
                 except discord.Forbidden:
                     pass
 
-
                 try:
                     await composing_message.delete()
                 except:
diff --git a/models/search_model.py b/models/search_model.py
index d977819..9befc62 100644
--- a/models/search_model.py
+++ b/models/search_model.py
@@ -16,7 +16,8 @@ from gpt_index import (
     Document,
     PromptHelper,
     LLMPredictor,
-    OpenAIEmbedding, SimpleDirectoryReader,
+    OpenAIEmbedding,
+    SimpleDirectoryReader,
 )
 from gpt_index.readers.web import DEFAULT_WEBSITE_EXTRACTOR
 from langchain import OpenAI
@@ -50,7 +51,6 @@ class Search:
         ).load_data(urls=[url])
         return documents
 
-
    async def index_pdf(self, url) -> list[Document]:
         # Download the PDF at the url and save it to a tempfile
         async with aiohttp.ClientSession() as session:
@@ -79,11 +79,15 @@ class Search:
                 if response.status == 200:
                     data = await response.json()
                     # Return a list of the top 2 links
-                    return ([item["link"] for item in data["items"][:search_scope]], [
-                        item["link"] for item in data["items"]
-                    ])
+                    return (
+                        [item["link"] for item in data["items"][:search_scope]],
+                        [item["link"] for item in data["items"]],
+                    )
                 else:
-                    print("The Google Search API returned an error: " + str(response.status))
+                    print(
+                        "The Google Search API returned an error: "
+                        + str(response.status)
+                    )
                     return ["An error occurred while searching.", None]
 
     async def search(self, query, user_api_key, search_scope, nodes):
@@ -157,17 +161,32 @@ class Search:
 
         embedding_model = OpenAIEmbedding()
 
-        index = await self.loop.run_in_executor(None, partial(GPTSimpleVectorIndex, documents, embed_model=embedding_model))
+        index = await self.loop.run_in_executor(
+            None, partial(GPTSimpleVectorIndex, documents, embed_model=embedding_model)
+        )
 
         await self.usage_service.update_usage(
             embedding_model.last_token_usage, embeddings=True
         )
 
-        llm_predictor = LLMPredictor(llm=OpenAI(model_name="text-davinci-003", max_tokens=-1))
+        llm_predictor = LLMPredictor(
+            llm=OpenAI(model_name="text-davinci-003", max_tokens=-1)
+        )
         # Now we can search the index for a query:
         embedding_model.last_token_usage = 0
 
-        response = await self.loop.run_in_executor(None, partial(index.query, query, verbose=True, embed_model=embedding_model, llm_predictor=llm_predictor, similarity_top_k=nodes or DEFAULT_SEARCH_NODES, text_qa_template=self.qaprompt))
+        response = await self.loop.run_in_executor(
+            None,
+            partial(
+                index.query,
+                query,
+                verbose=True,
+                embed_model=embedding_model,
+                llm_predictor=llm_predictor,
+                similarity_top_k=nodes or DEFAULT_SEARCH_NODES,
+                text_qa_template=self.qaprompt,
+            ),
+        )
 
         await self.usage_service.update_usage(llm_predictor.last_token_usage)
         await self.usage_service.update_usage(