diff --git a/cogs/commands.py b/cogs/commands.py index ec4f154..b2b3851 100644 --- a/cogs/commands.py +++ b/cogs/commands.py @@ -848,8 +848,14 @@ class Commands(discord.Cog, name="Commands"): max_value=4, min_value=1, ) + @discord.option( + name="deep", + description="Do a more intensive, long-running search", + required=False, + input_type=discord.SlashCommandOptionType.boolean, + ) @discord.guild_only() async def search( - self, ctx: discord.ApplicationContext, query: str, scope: int, nodes: int + self, ctx: discord.ApplicationContext, query: str, scope: int, nodes: int, deep: bool ): - await self.search_cog.search_command(ctx, query, scope, nodes) + await self.search_cog.search_command(ctx, query, scope, nodes, deep) diff --git a/cogs/search_service_cog.py b/cogs/search_service_cog.py index 2b6326a..3307d65 100644 --- a/cogs/search_service_cog.py +++ b/cogs/search_service_cog.py @@ -68,7 +68,7 @@ class SearchService(discord.Cog, name="SearchService"): return pages async def search_command( - self, ctx: discord.ApplicationContext, query, search_scope, nodes, redo=None + self, ctx: discord.ApplicationContext, query, search_scope, nodes, deep, redo=None ): """Command handler for the translation command""" user_api_key = None @@ -90,7 +90,7 @@ class SearchService(discord.Cog, name="SearchService"): try: response, refined_text = await self.model.search( - ctx, query, user_api_key, search_scope, nodes + ctx, query, user_api_key, search_scope, nodes, deep ) except ValueError: await ctx.respond( diff --git a/gpt3discord.py b/gpt3discord.py index cc3309b..3c9a65d 100644 --- a/gpt3discord.py +++ b/gpt3discord.py @@ -31,7 +31,7 @@ from services.environment_service import EnvService from models.openai_model import Model -__version__ = "10.3.1" +__version__ = "10.3.2" PID_FILE = Path("bot.pid") diff --git a/models/index_model.py b/models/index_model.py index 851d1d5..6e9c126 100644 --- a/models/index_model.py +++ b/models/index_model.py @@ -33,7 +33,7 @@ from gpt_index import ( QueryConfig, PromptHelper, IndexStructType, - OpenAIEmbedding, + OpenAIEmbedding, GithubRepositoryReader, ) from gpt_index.readers.web import DEFAULT_WEBSITE_EXTRACTOR @@ -53,7 +53,6 @@ def get_and_query( if isinstance(index, GPTTreeIndex): response = index.query( query, - verbose=True, child_branch_factor=2, llm_predictor=llm_predictor, embed_model=embed_model, @@ -62,7 +61,6 @@ def get_and_query( response = index.query( query, response_mode=response_mode, - verbose=True, llm_predictor=llm_predictor, embed_model=embed_model, similarity_top_k=nodes, @@ -185,6 +183,23 @@ class Index_handler: ) return index + def index_github_repository(self, link, embed_model): + print("indexing github repo") + # Extract the "owner" and the "repo" name from the github link. + owner = link.split("/")[3] + repo = link.split("/")[4] + + try: + documents = GithubRepositoryReader(owner=owner, repo=repo).load_data(branch="main") + except KeyError: + documents = GithubRepositoryReader(owner=owner, repo=repo).load_data(branch="master") + + index = GPTSimpleVectorIndex( + documents, + embed_model=embed_model, + ) + return index + def index_load_file(self, file_path) -> [GPTSimpleVectorIndex, ComposableGraph]: if "composed_deep" in str(file_path): index = GPTTreeIndex.load_from_disk(file_path) @@ -335,6 +350,10 @@ class Index_handler: index = await self.loop.run_in_executor( None, partial(self.index_youtube_transcript, link, embedding_model) ) + elif "github" in link: + index = await self.loop.run_in_executor( + None, partial(self.index_github_repository, link, embedding_model) + ) else: index = await self.index_webpage(link, embedding_model) await self.usage_service.update_usage( @@ -359,6 +378,7 @@ class Index_handler: except Exception: await ctx.respond("Failed to set index") traceback.print_exc() + return await ctx.respond("Index set") diff --git a/models/search_model.py b/models/search_model.py index 5180939..64c9306 100644 --- a/models/search_model.py +++ b/models/search_model.py @@ -20,6 +20,7 @@ from gpt_index import ( OpenAIEmbedding, SimpleDirectoryReader, ) +from gpt_index.indices.knowledge_graph import GPTKnowledgeGraphIndex from gpt_index.readers.web import DEFAULT_WEBSITE_EXTRACTOR from langchain import OpenAI @@ -168,6 +169,7 @@ class Search: user_api_key, search_scope, nodes, + deep, redo=None, ): DEFAULT_SEARCH_NODES = 1 @@ -284,10 +286,31 @@ class Search: embedding_model = OpenAIEmbedding() - index = await self.loop.run_in_executor( - None, partial(GPTSimpleVectorIndex, documents, embed_model=embedding_model) + llm_predictor = LLMPredictor( + llm=OpenAI(model_name="text-davinci-003", max_tokens=-1) ) + if not deep: + index = await self.loop.run_in_executor( + None, partial(GPTSimpleVectorIndex, documents, embed_model=embedding_model) + ) + else: + print("Doing a deep search") + llm_predictor_deep = LLMPredictor( + llm=OpenAI(model_name="text-davinci-002", temperature=0, max_tokens=-1) + ) + index = await self.loop.run_in_executor( + None, partial(GPTKnowledgeGraphIndex, documents, chunk_size_limit=512, max_triplets_per_chunk=2, embed_model=embedding_model, llm_predictor=llm_predictor_deep) + ) + await self.usage_service.update_usage( + embedding_model.last_token_usage, embeddings=True + ) + await self.usage_service.update_usage( + llm_predictor_deep.last_token_usage, embeddings=False + ) + + + if ctx: await self.try_edit( in_progress_message, self.build_search_indexed_embed(query_refined_text) @@ -297,25 +320,37 @@ class Search: embedding_model.last_token_usage, embeddings=True ) - llm_predictor = LLMPredictor( - llm=OpenAI(model_name="text-davinci-003", max_tokens=-1) - ) # Now we can search the index for a query: embedding_model.last_token_usage = 0 - response = await self.loop.run_in_executor( - None, - partial( - index.query, - query, - verbose=True, - embed_model=embedding_model, - llm_predictor=llm_predictor, - similarity_top_k=nodes or DEFAULT_SEARCH_NODES, - text_qa_template=self.qaprompt, - ), - ) + if not deep: + response = await self.loop.run_in_executor( + None, + partial( + index.query, + query, + embed_model=embedding_model, + llm_predictor=llm_predictor, + similarity_top_k=nodes or DEFAULT_SEARCH_NODES, + text_qa_template=self.qaprompt, + ), + ) + else: + response = await self.loop.run_in_executor( + None, + partial( + index.query, + query, + include_text=True, + response_mode="tree_summarize", + embed_model=embedding_model, + llm_predictor=llm_predictor, + text_qa_template=self.qaprompt, + ), + ) + + await self.usage_service.update_usage(llm_predictor.last_token_usage) await self.usage_service.update_usage( diff --git a/requirements.txt b/requirements.txt index 7acee75..1ee5fbd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ sqlitedict==2.1.0 backoff==2.2.1 flask==2.2.2 beautifulsoup4==4.11.1 -gpt-index==0.3.5 +gpt-index==0.4.4 PyPDF2==3.0.1 youtube_transcript_api==0.5.0 sentencepiece==0.1.97