diff --git a/README.md b/README.md
index 2a74921..669b731 100644
--- a/README.md
+++ b/README.md
@@ -491,4 +491,4 @@ The health check endpoint will then be present in your bot's console when it is
### Custom Bot Name
-Add a line `CUSTOM_BOT_NAME=` to your `.env` to give your bot a custom name in conversations.
\ No newline at end of file
+Add a line `CUSTOM_BOT_NAME=` to your `.env` to give your bot a custom name in conversations.
diff --git a/cogs/search_service_cog.py b/cogs/search_service_cog.py
index 7237161..977fa81 100644
--- a/cogs/search_service_cog.py
+++ b/cogs/search_service_cog.py
@@ -1,7 +1,9 @@
import traceback
import aiohttp
+import re
import discord
+from discord.ext import pages
from models.deepl_model import TranslationModel
from models.search_model import Search
@@ -26,9 +28,38 @@ class SearchService(discord.Cog, name="SearchService"):
self.bot = bot
self.usage_service = usage_service
self.model = Search(gpt_model, usage_service)
+ self.EMBED_CUTOFF = 2000
# Make a mapping of all the country codes and their full country names:
- async def search_command(self, ctx, query, search_scope, nodes):
+ async def paginate_embed(self, response_text):
+ """Given a response text make embed pages and return a list of the pages. Codex makes it a codeblock in the embed"""
+
+ response_text = [
+ response_text[i : i + self.EMBED_CUTOFF]
+ for i in range(0, len(response_text), self.EMBED_CUTOFF)
+ ]
+ pages = []
+ first = False
+        # Build one embed page per chunk
+ for count, chunk in enumerate(response_text, start=1):
+ if not first:
+ page = discord.Embed(
+ title=f"Page {count}",
+ description=chunk,
+ )
+ first = True
+ else:
+ page = discord.Embed(
+ title=f"Page {count}",
+ description=chunk,
+ )
+ pages.append(page)
+
+ return pages
+
+ async def search_command(
+ self, ctx: discord.ApplicationContext, query, search_scope, nodes
+ ):
"""Command handler for the translation command"""
user_api_key = None
if USER_INPUT_API_KEYS:
@@ -47,8 +78,41 @@ class SearchService(discord.Cog, name="SearchService"):
await ctx.defer()
- response = await self.model.search(query, user_api_key, search_scope, nodes)
+ try:
+ response = await self.model.search(query, user_api_key, search_scope, nodes)
+ except ValueError:
+ await ctx.respond(
+ "The Google Search API returned an error. Check the console for more details.",
+ ephemeral=True,
+ )
+ return
+ except Exception:
+ await ctx.respond(
+ "An error occurred. Check the console for more details.", ephemeral=True
+ )
+ traceback.print_exc()
+ return
+
+ url_extract_pattern = "https?:\\/\\/(?:www\\.)?[-a-zA-Z0-9@:%._\\+~#=]{1,256}\\.[a-zA-Z0-9()]{1,6}\\b(?:[-a-zA-Z0-9()@:%_\\+.~#?&\\/=]*)"
+ urls = re.findall(
+ url_extract_pattern,
+ str(response.get_formatted_sources(length=200)),
+ flags=re.IGNORECASE,
+ )
+ urls = "\n".join(f"<{url}>" for url in urls)
- await ctx.respond(
- f"**Query:**\n\n{query.strip()}\n\n**Query response:**\n\n{response.response.strip()}"
+        query_response_message = f"**Query:**\n\n`{query.strip()}`\n\n**Query response:**\n\n{response.response.strip()}\n\n**Sources:**\n{urls}"
+ query_response_message = query_response_message.replace(
+ "<|endofstatement|>", ""
)
+
+ # If the response is too long, lets paginate using the discord pagination
+ # helper
+ embed_pages = await self.paginate_embed(query_response_message)
+ paginator = pages.Paginator(
+ pages=embed_pages,
+ timeout=None,
+ author_check=False,
+ )
+
+ await paginator.respond(ctx.interaction)
diff --git a/models/index_model.py b/models/index_model.py
index 23418cd..dd59a49 100644
--- a/models/index_model.py
+++ b/models/index_model.py
@@ -1,14 +1,18 @@
import os
+import tempfile
import traceback
import asyncio
from collections import defaultdict
+import aiohttp
import discord
import aiofiles
from functools import partial
from typing import List, Optional
from pathlib import Path
from datetime import date
+
+from discord.ext import pages
from langchain import OpenAI
from gpt_index.readers import YoutubeTranscriptReader
@@ -45,7 +49,6 @@ def get_and_query(
index: [GPTSimpleVectorIndex, ComposableGraph] = index_storage[
user_id
].get_index_or_throw()
- prompthelper = PromptHelper(4096, 500, 20)
if isinstance(index, GPTTreeIndex):
response = index.query(
query,
@@ -53,7 +56,6 @@ def get_and_query(
child_branch_factor=2,
llm_predictor=llm_predictor,
embed_model=embed_model,
- prompt_helper=prompthelper,
)
else:
response = index.query(
@@ -63,7 +65,6 @@ def get_and_query(
llm_predictor=llm_predictor,
embed_model=embed_model,
similarity_top_k=nodes,
- prompt_helper=prompthelper,
)
return response
@@ -134,6 +135,33 @@ class Index_handler:
"Given the context information and not prior knowledge, "
"answer the question: {query_str}\n"
)
+ self.EMBED_CUTOFF = 2000
+
+ async def paginate_embed(self, response_text):
+ """Given a response text make embed pages and return a list of the pages. Codex makes it a codeblock in the embed"""
+
+ response_text = [
+ response_text[i : i + self.EMBED_CUTOFF]
+ for i in range(0, len(response_text), self.EMBED_CUTOFF)
+ ]
+ pages = []
+ first = False
+        # Build one embed page per chunk
+ for count, chunk in enumerate(response_text, start=1):
+ if not first:
+ page = discord.Embed(
+ title=f"Index Query Results",
+ description=chunk,
+ )
+ first = True
+ else:
+ page = discord.Embed(
+ title=f"Page {count}",
+ description=chunk,
+ )
+ pages.append(page)
+
+ return pages
# TODO We need to do predictions below for token usage.
def index_file(self, file_path, embed_model) -> GPTSimpleVectorIndex:
@@ -141,6 +169,22 @@ class Index_handler:
index = GPTSimpleVectorIndex(document, embed_model=embed_model)
return index
+ async def index_web_pdf(self, url, embed_model) -> GPTSimpleVectorIndex:
+ print("Indexing a WEB PDF")
+ async with aiohttp.ClientSession() as session:
+ async with session.get(url) as response:
+ if response.status == 200:
+ data = await response.read()
+ f = tempfile.NamedTemporaryFile(delete=False)
+ f.write(data)
+ f.close()
+ else:
+ return "An error occurred while downloading the PDF."
+
+ document = SimpleDirectoryReader(input_files=[f.name]).load_data()
+ index = GPTSimpleVectorIndex(document, embed_model=embed_model)
+ return index
+
def index_gdoc(self, doc_id, embed_model) -> GPTSimpleVectorIndex:
document = GoogleDocsReader().load_data(doc_id)
index = GPTSimpleVectorIndex(document, embed_model=embed_model)
@@ -243,11 +287,29 @@ class Index_handler:
# TODO Link validation
try:
embedding_model = OpenAIEmbedding()
+
+ # Pre-emptively connect and get the content-type of the response
+ try:
+ async with aiohttp.ClientSession() as session:
+ async with session.get(link, timeout=2) as response:
+ print(response.status)
+ if response.status == 200:
+ content_type = response.headers.get("content-type")
+ else:
+ await ctx.respond("Failed to get link", ephemeral=True)
+ return
+ except Exception:
+ traceback.print_exc()
+ await ctx.respond("Failed to get link", ephemeral=True)
+ return
+
# Check if the link contains youtube in it
if "youtube" in link:
index = await self.loop.run_in_executor(
None, partial(self.index_youtube_transcript, link, embedding_model)
)
+ elif "pdf" in content_type:
+ index = await self.index_web_pdf(link, embedding_model)
else:
index = await self.loop.run_in_executor(
None, partial(self.index_webpage, link, embedding_model)
@@ -349,13 +411,21 @@ class Index_handler:
for doc_id in [docmeta for docmeta in _index.docstore.docs.keys()]
if isinstance(_index.docstore.get_document(doc_id), Document)
]
- llm_predictor = LLMPredictor(llm=OpenAI(model_name="text-davinci-003"))
+ llm_predictor = LLMPredictor(
+ llm=OpenAI(model_name="text-davinci-003", max_tokens=-1)
+ )
embedding_model = OpenAIEmbedding()
- tree_index = GPTTreeIndex(
- documents=documents,
- llm_predictor=llm_predictor,
- embed_model=embedding_model,
+
+ tree_index = await self.loop.run_in_executor(
+ None,
+ partial(
+ GPTTreeIndex,
+ documents=documents,
+ llm_predictor=llm_predictor,
+ embed_model=embedding_model,
+ ),
)
+
await self.usage_service.update_usage(llm_predictor.last_token_usage)
await self.usage_service.update_usage(
embedding_model.last_token_usage, embeddings=True
@@ -381,10 +451,16 @@ class Index_handler:
]
embedding_model = OpenAIEmbedding()
- # Add everything into a simple vector index
- simple_index = GPTSimpleVectorIndex(
- documents=documents, embed_model=embedding_model
+
+ simple_index = await self.loop.run_in_executor(
+ None,
+ partial(
+ GPTSimpleVectorIndex,
+ documents=documents,
+ embed_model=embedding_model,
+ ),
)
+
await self.usage_service.update_usage(
embedding_model.last_token_usage, embeddings=True
)
@@ -466,9 +542,17 @@ class Index_handler:
await self.usage_service.update_usage(
embedding_model.last_token_usage, embeddings=True
)
- await ctx.respond(
- f"**Query:**\n\n{query.strip()}\n\n**Query response:**\n\n{response.response.strip()}"
+ query_response_message = f"**Query:**\n\n`{query.strip()}`\n\n**Query response:**\n\n{response.response.strip()}"
+ query_response_message = query_response_message.replace(
+ "<|endofstatement|>", ""
+ )
+ embed_pages = await self.paginate_embed(query_response_message)
+ paginator = pages.Paginator(
+ pages=embed_pages,
+ timeout=None,
+ author_check=False,
)
+ await paginator.respond(ctx.interaction)
except Exception:
traceback.print_exc()
await ctx.respond(
@@ -665,7 +749,7 @@ class ComposeModal(discord.ui.View):
)
else:
composing_message = await interaction.response.send_message(
- "Composing indexes, this may take a long time...",
+ "Composing indexes, this may take a long time, you will be DMed when it's ready!",
ephemeral=True,
delete_after=120,
)
@@ -679,9 +763,17 @@ class ComposeModal(discord.ui.View):
else True,
)
await interaction.followup.send(
- "Composed indexes", ephemeral=True, delete_after=10
+ "Composed indexes", ephemeral=True, delete_after=180
)
+ # Try to direct message the user that their composed index is ready
+ try:
+ await self.index_cog.bot.get_user(self.user_id).send(
+ f"Your composed index is ready! You can load it with /index load now in the server."
+ )
+ except discord.Forbidden:
+ pass
+
try:
await composing_message.delete()
except:
diff --git a/models/search_model.py b/models/search_model.py
index 1c444c4..a3f71d9 100644
--- a/models/search_model.py
+++ b/models/search_model.py
@@ -2,9 +2,11 @@ import asyncio
import os
import random
import re
+import tempfile
import traceback
from functools import partial
+import discord
from bs4 import BeautifulSoup
import aiohttp
from gpt_index import (
@@ -15,11 +17,12 @@ from gpt_index import (
PromptHelper,
LLMPredictor,
OpenAIEmbedding,
+ SimpleDirectoryReader,
)
from gpt_index.readers.web import DEFAULT_WEBSITE_EXTRACTOR
from langchain import OpenAI
-from services.environment_service import EnvService
+from services.environment_service import EnvService, app_root_path
from services.usage_service import UsageService
@@ -40,6 +43,7 @@ class Search:
"answer the question, say that you were unable to answer the question if there is not sufficient context to formulate a decisive answer. The search query was: {query_str}\n"
)
self.openai_key = os.getenv("OPENAI_TOKEN")
+ self.EMBED_CUTOFF = 2000
def index_webpage(self, url) -> list[Document]:
documents = BeautifulSoupWebReader(
@@ -47,7 +51,26 @@ class Search:
).load_data(urls=[url])
return documents
- async def get_links(self, query, search_scope=3):
+ async def index_pdf(self, url) -> list[Document]:
+ # Download the PDF at the url and save it to a tempfile
+ async with aiohttp.ClientSession() as session:
+ async with session.get(url) as response:
+ if response.status == 200:
+ data = await response.read()
+ f = tempfile.NamedTemporaryFile(delete=False)
+ f.write(data)
+ f.close()
+ else:
+ return "An error occurred while downloading the PDF."
+        # Load the downloaded PDF from the temp file path with
+        # SimpleDirectoryReader so gpt_index can parse its contents
+ documents = SimpleDirectoryReader(input_files=[f.name]).load_data()
+ print("Loaded the PDF document data")
+
+        # NOTE(review): the temp file (delete=False) is never removed here — TODO clean up
+ return documents
+
+ async def get_links(self, query, search_scope=2):
"""Search the web for a query"""
async with aiohttp.ClientSession() as session:
async with session.get(
@@ -55,10 +78,17 @@ class Search:
) as response:
if response.status == 200:
data = await response.json()
- # Return a list of the top 5 links
- return [item["link"] for item in data["items"][:search_scope]], [item["link"] for item in data["items"]]
+                    # Return the top `search_scope` links plus the full list of result links
+ return (
+ [item["link"] for item in data["items"][:search_scope]],
+ [item["link"] for item in data["items"]],
+ )
else:
- return "An error occurred while searching."
+ print(
+ "The Google Search API returned an error: "
+ + str(response.status)
+ )
+ return ["An error occurred while searching.", None]
async def search(self, query, user_api_key, search_scope, nodes):
DEFAULT_SEARCH_NODES = 1
@@ -79,7 +109,9 @@ class Search:
query_refined_text = query
# Get the links for the query
- links, all_links = await self.get_links(query_refined_text, search_scope=search_scope)
+ links, all_links = await self.get_links(query, search_scope=search_scope)
+ if all_links is None:
+ raise ValueError("The Google Search API returned an error.")
# For each link, crawl the page and get all the text that's not HTML garbage.
# Concatenate all the text for a given website into one string and save it into an array:
@@ -87,6 +119,7 @@ class Search:
for link in links:
# First, attempt a connection with a timeout of 3 seconds to the link, if the timeout occurs, don't
# continue to the document loading.
+ pdf = False
try:
async with aiohttp.ClientSession() as session:
async with session.get(link, timeout=2) as response:
@@ -103,8 +136,14 @@ class Search:
try:
print("Adding redirect")
links.append(response.url)
+ continue
except:
- pass
+ continue
+ else:
+ # Detect if the link is a PDF, if it is, we load it differently
+ if response.headers["Content-Type"] == "application/pdf":
+ print("Found a PDF at the link " + link)
+ pdf = True
except:
traceback.print_exc()
@@ -121,30 +160,47 @@ class Search:
continue
try:
- document = await self.loop.run_in_executor(
- None, partial(self.index_webpage, link)
- )
+ if not pdf:
+ document = await self.loop.run_in_executor(
+ None, partial(self.index_webpage, link)
+ )
+ else:
+ document = await self.index_pdf(link)
[documents.append(doc) for doc in document]
except Exception as e:
traceback.print_exc()
- prompthelper = PromptHelper(4096, 1024, 20)
-
embedding_model = OpenAIEmbedding()
- index = GPTSimpleVectorIndex(documents, embed_model=embedding_model)
- await self.usage_service.update_usage(embedding_model.last_token_usage, embeddings=True)
+
+
+ index = await self.loop.run_in_executor(
+ None, partial(GPTSimpleVectorIndex, documents, embed_model=embedding_model)
+ )
+
+ await self.usage_service.update_usage(
+ embedding_model.last_token_usage, embeddings=True
+ )
+
+ llm_predictor = LLMPredictor(
+ llm=OpenAI(model_name="text-davinci-003", max_tokens=-1)
+ )
# Now we can search the index for a query:
embedding_model.last_token_usage = 0
- response = index.query(
- query,
- verbose=True,
- embed_model=embedding_model,
- llm_predictor=llm_predictor,
- prompt_helper=prompthelper,
- similarity_top_k=nodes or DEFAULT_SEARCH_NODES,
- text_qa_template=self.qaprompt,
+
+ response = await self.loop.run_in_executor(
+ None,
+ partial(
+ index.query,
+ query,
+ verbose=True,
+ embed_model=embedding_model,
+ llm_predictor=llm_predictor,
+ similarity_top_k=nodes or DEFAULT_SEARCH_NODES,
+ text_qa_template=self.qaprompt,
+ ),
)
+
await self.usage_service.update_usage(llm_predictor.last_token_usage)
await self.usage_service.update_usage(
embedding_model.last_token_usage, embeddings=True