diff --git a/README.md b/README.md
index 2a74921..669b731 100644
--- a/README.md
+++ b/README.md
@@ -491,4 +491,4 @@ The health check endpoint will then be present in your bot's console when it is
 
 ### Custom Bot Name
 
-Add a line `CUSTOM_BOT_NAME=` to your `.env` to give your bot a custom name in conversations.
\ No newline at end of file
+Add a line `CUSTOM_BOT_NAME=` to your `.env` to give your bot a custom name in conversations.
diff --git a/cogs/search_service_cog.py b/cogs/search_service_cog.py
index 7237161..977fa81 100644
--- a/cogs/search_service_cog.py
+++ b/cogs/search_service_cog.py
@@ -1,7 +1,9 @@
 import traceback
 
 import aiohttp
+import re
 
 import discord
+from discord.ext import pages
 
 from models.deepl_model import TranslationModel
 from models.search_model import Search
@@ -26,9 +28,38 @@ class SearchService(discord.Cog, name="SearchService"):
         self.bot = bot
         self.usage_service = usage_service
         self.model = Search(gpt_model, usage_service)
+        self.EMBED_CUTOFF = 2000
 
     # Make a mapping of all the country codes and their full country names:
-    async def search_command(self, ctx, query, search_scope, nodes):
+    async def paginate_embed(self, response_text):
+        """Split a response text into EMBED_CUTOFF-sized chunks and return them as a list of embed pages."""
+
+        response_text = [
+            response_text[i : i + self.EMBED_CUTOFF]
+            for i in range(0, len(response_text), self.EMBED_CUTOFF)
+        ]
+        pages = []
+        # Build one embed page per chunk
+        for count, chunk in enumerate(response_text, start=1):
+            page = discord.Embed(
+                title=f"Page {count}",
+                description=chunk,
+            )
+            pages.append(page)
+
+        return pages
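+
+    # Rough illustration of the chunking above (numbers hypothetical): a 5000-character
+    # response with EMBED_CUTOFF = 2000 becomes three pages of 2000, 2000, and 1000
+    # characters, titled "Page 1" through "Page 3".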
+
+    async def search_command(
+        self, ctx: discord.ApplicationContext, query, search_scope, nodes
+    ):
         """Command handler for the translation command"""
         user_api_key = None
         if USER_INPUT_API_KEYS:
@@ -47,8 +78,41 @@ class SearchService(discord.Cog, name="SearchService"):
 
         await ctx.defer()
 
-        response = await self.model.search(query, user_api_key, search_scope, nodes)
+        try:
+            response = await self.model.search(query, user_api_key, search_scope, nodes)
+        except ValueError:
+            await ctx.respond(
+                "The Google Search API returned an error. Check the console for more details.",
+                ephemeral=True,
+            )
+            return
+        except Exception:
+            await ctx.respond(
+                "An error occurred. Check the console for more details.", ephemeral=True
+            )
+            traceback.print_exc()
+            return
+
+        url_extract_pattern = "https?:\\/\\/(?:www\\.)?[-a-zA-Z0-9@:%._\\+~#=]{1,256}\\.[a-zA-Z0-9()]{1,6}\\b(?:[-a-zA-Z0-9()@:%_\\+.~#?&\\/=]*)"
+        urls = re.findall(
+            url_extract_pattern,
+            str(response.get_formatted_sources(length=200)),
+            flags=re.IGNORECASE,
+        )
+        urls = "\n".join(f"<{url}>" for url in urls)
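+        # Wrapping each URL in <> suppresses Discord's automatic link previews,
+        # keeping the sources list compact.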
-        await ctx.respond(
-            f"**Query:**\n\n{query.strip()}\n\n**Query response:**\n\n{response.response.strip()}"
+        query_response_message = f"**Query:**\n\n`{query.strip()}`\n\n**Query response:**\n\n{response.response.strip()}\n\n**Sources:**\n{urls}"
+        query_response_message = query_response_message.replace(
+            "<|endofstatement|>", ""
         )
+
+        # If the response is too long, let's paginate using the discord pagination helper
+        embed_pages = await self.paginate_embed(query_response_message)
+        paginator = pages.Paginator(
+            pages=embed_pages,
+            timeout=None,
+            author_check=False,
+        )
+
+        await paginator.respond(ctx.interaction)
diff --git a/models/index_model.py b/models/index_model.py
index 23418cd..dd59a49 100644
--- a/models/index_model.py
+++ b/models/index_model.py
@@ -1,14 +1,18 @@
 import os
+import tempfile
 import traceback
 import asyncio
 from collections import defaultdict
 
+import aiohttp
 import discord
 import aiofiles
 from functools import partial
 from typing import List, Optional
 from pathlib import Path
 from datetime import date
+
+from discord.ext import pages
 from langchain import OpenAI
 
 from gpt_index.readers import YoutubeTranscriptReader
@@ -45,7 +49,6 @@
     index: [GPTSimpleVectorIndex, ComposableGraph] = index_storage[
         user_id
    ].get_index_or_throw()
-    prompthelper = PromptHelper(4096, 500, 20)
     if isinstance(index, GPTTreeIndex):
         response = index.query(
             query,
@@ -53,7 +56,6 @@
             child_branch_factor=2,
             llm_predictor=llm_predictor,
             embed_model=embed_model,
-            prompt_helper=prompthelper,
         )
     else:
         response = index.query(
@@ -63,7 +65,6 @@
             llm_predictor=llm_predictor,
             embed_model=embed_model,
             similarity_top_k=nodes,
-            prompt_helper=prompthelper,
         )
     return response
@@ -134,6 +135,33 @@ class Index_handler:
             "Given the context information and not prior knowledge, "
             "answer the question: {query_str}\n"
         )
+        self.EMBED_CUTOFF = 2000
+
+    async def paginate_embed(self, response_text):
+        """Split a response text into EMBED_CUTOFF-sized chunks and return them as a list of embed pages."""
+
+        response_text = [
+            response_text[i : i + self.EMBED_CUTOFF]
+            for i in range(0, len(response_text), self.EMBED_CUTOFF)
+        ]
+        pages = []
+        # Build one embed page per chunk; the first page gets a descriptive title
+        for count, chunk in enumerate(response_text, start=1):
+            page = discord.Embed(
+                title="Index Query Results" if count == 1 else f"Page {count}",
+                description=chunk,
+            )
+            pages.append(page)
+
+        return pages
 
     # TODO We need to do predictions below for token usage.
     def index_file(self, file_path, embed_model) -> GPTSimpleVectorIndex:
@@ -141,6 +169,22 @@ class Index_handler:
         index = GPTSimpleVectorIndex(document, embed_model=embed_model)
         return index
 
+    async def index_web_pdf(self, url, embed_model) -> GPTSimpleVectorIndex:
+        print("Indexing a WEB PDF")
+        async with aiohttp.ClientSession() as session:
+            async with session.get(url) as response:
+                if response.status == 200:
+                    data = await response.read()
+                    # A ".pdf" suffix helps SimpleDirectoryReader pick the PDF parser
+                    f = tempfile.NamedTemporaryFile(suffix=".pdf", delete=False)
+                    f.write(data)
+                    f.close()
+                else:
+                    raise ValueError("An error occurred while downloading the PDF.")
+
+        document = SimpleDirectoryReader(input_files=[f.name]).load_data()
+        index = GPTSimpleVectorIndex(document, embed_model=embed_model)
+        return index
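+
+    # Note: index_web_pdf mirrors Search.index_pdf in models/search_model.py; both
+    # download the PDF to a named temp file so that SimpleDirectoryReader can parse
+    # it from disk rather than from the raw response bytes.
+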
     def index_gdoc(self, doc_id, embed_model) -> GPTSimpleVectorIndex:
         document = GoogleDocsReader().load_data(doc_id)
         index = GPTSimpleVectorIndex(document, embed_model=embed_model)
@@ -243,11 +287,29 @@ class Index_handler:
         # TODO Link validation
         try:
             embedding_model = OpenAIEmbedding()
+
+            # Pre-emptively connect and get the content-type of the response
+            try:
+                async with aiohttp.ClientSession() as session:
+                    async with session.get(link, timeout=2) as response:
+                        print(response.status)
+                        if response.status == 200:
+                            content_type = response.headers.get("content-type", "")
+                        else:
+                            await ctx.respond("Failed to get link", ephemeral=True)
+                            return
+            except Exception:
+                traceback.print_exc()
+                await ctx.respond("Failed to get link", ephemeral=True)
+                return
+
             # Check if the link contains youtube in it
             if "youtube" in link:
                 index = await self.loop.run_in_executor(
                     None, partial(self.index_youtube_transcript, link, embedding_model)
                 )
+            elif "pdf" in content_type:
+                index = await self.index_web_pdf(link, embedding_model)
             else:
                 index = await self.loop.run_in_executor(
                     None, partial(self.index_webpage, link, embedding_model)
                 )
@@ -349,13 +411,21 @@ class Index_handler:
             for doc_id in [docmeta for docmeta in _index.docstore.docs.keys()]
             if isinstance(_index.docstore.get_document(doc_id), Document)
         ]
-        llm_predictor = LLMPredictor(llm=OpenAI(model_name="text-davinci-003"))
+        llm_predictor = LLMPredictor(
+            llm=OpenAI(model_name="text-davinci-003", max_tokens=-1)
+        )
         embedding_model = OpenAIEmbedding()
-        tree_index = GPTTreeIndex(
-            documents=documents,
-            llm_predictor=llm_predictor,
-            embed_model=embedding_model,
+
+        tree_index = await self.loop.run_in_executor(
+            None,
+            partial(
+                GPTTreeIndex,
+                documents=documents,
+                llm_predictor=llm_predictor,
+                embed_model=embedding_model,
+            ),
         )
+
         await self.usage_service.update_usage(llm_predictor.last_token_usage)
         await self.usage_service.update_usage(
             embedding_model.last_token_usage, embeddings=True
@@ -381,10 +451,16 @@ class Index_handler:
         ]
 
         embedding_model = OpenAIEmbedding()
-        # Add everything into a simple vector index
-        simple_index = GPTSimpleVectorIndex(
-            documents=documents, embed_model=embedding_model
+
+        simple_index = await self.loop.run_in_executor(
+            None,
+            partial(
+                GPTSimpleVectorIndex,
+                documents=documents,
+                embed_model=embedding_model,
+            ),
         )
+
         await self.usage_service.update_usage(
             embedding_model.last_token_usage, embeddings=True
         )
@@ -466,9 +542,17 @@ class Index_handler:
             await self.usage_service.update_usage(
                 embedding_model.last_token_usage, embeddings=True
             )
-            await ctx.respond(
-                f"**Query:**\n\n{query.strip()}\n\n**Query response:**\n\n{response.response.strip()}"
+            query_response_message = f"**Query:**\n\n`{query.strip()}`\n\n**Query response:**\n\n{response.response.strip()}"
+            query_response_message = query_response_message.replace(
+                "<|endofstatement|>", ""
+            )
+            embed_pages = await self.paginate_embed(query_response_message)
+            paginator = pages.Paginator(
+                pages=embed_pages,
+                timeout=None,
+                author_check=False,
             )
+            await paginator.respond(ctx.interaction)
         except Exception:
             traceback.print_exc()
             await ctx.respond(
@@ -665,7 +749,7 @@ class ComposeModal(discord.ui.View):
                 )
             else:
                 composing_message = await interaction.response.send_message(
-                    "Composing indexes, this may take a long time...",
+                    "Composing indexes, this may take a long time; you will be DMed when it's ready!",
                     ephemeral=True,
                     delete_after=120,
                 )
@@ -679,9 +763,17 @@
                     else True,
                 )
                 await interaction.followup.send(
-                    "Composed indexes", ephemeral=True, delete_after=10
+                    "Composed indexes", ephemeral=True, delete_after=180
                 )
+
+                # Try to direct message the user that their composed index is ready
+                try:
+                    await self.index_cog.bot.get_user(self.user_id).send(
+                        "Your composed index is ready! You can now load it in the server with /index load."
+                    )
+                except discord.Forbidden:
+                    pass
+
                 try:
                     await composing_message.delete()
                 except:
diff --git a/models/search_model.py b/models/search_model.py
index 1c444c4..a3f71d9 100644
--- a/models/search_model.py
+++ b/models/search_model.py
@@ -2,9 +2,11 @@ import asyncio
 import os
 import random
 import re
+import tempfile
 import traceback
 from functools import partial
 
+import discord
 from bs4 import BeautifulSoup
 import aiohttp
 from gpt_index import (
@@ -15,11 +17,12 @@
     PromptHelper,
     LLMPredictor,
     OpenAIEmbedding,
+    SimpleDirectoryReader,
 )
 from gpt_index.readers.web import DEFAULT_WEBSITE_EXTRACTOR
 from langchain import OpenAI
 
-from services.environment_service import EnvService
+from services.environment_service import EnvService, app_root_path
 from services.usage_service import UsageService
@@ -40,6 +43,7 @@ class Search:
             "answer the question, say that you were unable to answer the question if there is not sufficient context to formulate a decisive answer. The search query was: {query_str}\n"
         )
         self.openai_key = os.getenv("OPENAI_TOKEN")
+        self.EMBED_CUTOFF = 2000
 
     def index_webpage(self, url) -> list[Document]:
         documents = BeautifulSoupWebReader(
             website_extractor=DEFAULT_WEBSITE_EXTRACTOR
         ).load_data(urls=[url])
         return documents
 
-    async def get_links(self, query, search_scope=3):
+    async def index_pdf(self, url) -> list[Document]:
+        # Download the PDF at the url and save it to a tempfile
+        async with aiohttp.ClientSession() as session:
+            async with session.get(url) as response:
+                if response.status == 200:
+                    data = await response.read()
+                    # A ".pdf" suffix helps SimpleDirectoryReader pick the PDF parser
+                    f = tempfile.NamedTemporaryFile(suffix=".pdf", delete=False)
+                    f.write(data)
+                    f.close()
+                else:
+                    raise ValueError("An error occurred while downloading the PDF.")
+
+        # Load the PDF back from the temp file path so SimpleDirectoryReader can parse it
+        documents = SimpleDirectoryReader(input_files=[f.name]).load_data()
+        print("Loaded the PDF document data")
+
+        # Delete the temporary file
+        os.remove(f.name)
+
+        return documents
+
+    async def get_links(self, query, search_scope=2):
         """Search the web for a query"""
         async with aiohttp.ClientSession() as session:
             async with session.get(
             ) as response:
                 if response.status == 200:
                     data = await response.json()
-                    # Return a list of the top 5 links
-                    return [item["link"] for item in data["items"][:search_scope]], [item["link"] for item in data["items"]]
+                    # Return the top search_scope links along with the full list of links
+                    return (
+                        [item["link"] for item in data["items"][:search_scope]],
+                        [item["link"] for item in data["items"]],
+                    )
                 else:
-                    return "An error occurred while searching."
+                    print(
+                        "The Google Search API returned an error: "
+                        + str(response.status)
+                    )
+                    return ["An error occurred while searching.", None]
 
     async def search(self, query, user_api_key, search_scope, nodes):
         DEFAULT_SEARCH_NODES = 1
@@ -79,7 +109,9 @@
             query_refined_text = query
 
         # Get the links for the query
-        links, all_links = await self.get_links(query_refined_text, search_scope=search_scope)
+        links, all_links = await self.get_links(query, search_scope=search_scope)
+        if all_links is None:
+            raise ValueError("The Google Search API returned an error.")
 
         # For each link, crawl the page and get all the text that's not HTML garbage.
         # Concatenate all the text for a given website into one string and save it into an array:
@@ -87,6 +119,7 @@
         for link in links:
             # First, attempt a connection with a timeout of 3 seconds to the link, if the timeout occurs, don't
             # continue to the document loading.
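+            # pdf is flipped to True below when the response's Content-Type header
+            # reports "application/pdf"; those links are parsed with index_pdf
+            # instead of the webpage reader.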
+            pdf = False
             try:
                 async with aiohttp.ClientSession() as session:
                     async with session.get(link, timeout=2) as response:
@@ -103,8 +136,14 @@
                             try:
                                 print("Adding redirect")
                                 links.append(response.url)
+                                continue
                             except:
-                                pass
+                                continue
+                        else:
+                            # Detect if the link is a PDF; if it is, we load it differently
+                            if response.headers.get("Content-Type") == "application/pdf":
+                                print("Found a PDF at the link " + link)
+                                pdf = True
             except:
                 traceback.print_exc()
@@ -121,30 +160,47 @@
                 continue
 
             try:
-                document = await self.loop.run_in_executor(
-                    None, partial(self.index_webpage, link)
-                )
+                if not pdf:
+                    document = await self.loop.run_in_executor(
+                        None, partial(self.index_webpage, link)
+                    )
+                else:
+                    document = await self.index_pdf(link)
                 [documents.append(doc) for doc in document]
             except Exception as e:
                 traceback.print_exc()
 
-        prompthelper = PromptHelper(4096, 1024, 20)
-
         embedding_model = OpenAIEmbedding()
-        index = GPTSimpleVectorIndex(documents, embed_model=embedding_model)
-        await self.usage_service.update_usage(embedding_model.last_token_usage, embeddings=True)
+
+        index = await self.loop.run_in_executor(
+            None, partial(GPTSimpleVectorIndex, documents, embed_model=embedding_model)
+        )
+
+        await self.usage_service.update_usage(
+            embedding_model.last_token_usage, embeddings=True
+        )
+
+        llm_predictor = LLMPredictor(
+            llm=OpenAI(model_name="text-davinci-003", max_tokens=-1)
+        )
 
         # Now we can search the index for a query:
         embedding_model.last_token_usage = 0
-        response = index.query(
-            query,
-            verbose=True,
-            embed_model=embedding_model,
-            llm_predictor=llm_predictor,
-            prompt_helper=prompthelper,
-            similarity_top_k=nodes or DEFAULT_SEARCH_NODES,
-            text_qa_template=self.qaprompt,
+
+        response = await self.loop.run_in_executor(
+            None,
+            partial(
+                index.query,
+                query,
+                verbose=True,
+                embed_model=embedding_model,
+                llm_predictor=llm_predictor,
+                similarity_top_k=nodes or DEFAULT_SEARCH_NODES,
+                text_qa_template=self.qaprompt,
+            ),
         )
+
+        await self.usage_service.update_usage(llm_predictor.last_token_usage)
         await self.usage_service.update_usage(
             embedding_model.last_token_usage, embeddings=True