diff --git a/cogs/commands.py b/cogs/commands.py
index 2fc30ca..ca8e3b4 100644
--- a/cogs/commands.py
+++ b/cogs/commands.py
@@ -810,7 +810,7 @@ class Commands(discord.Cog, name="Commands"):
         guild_ids=ALLOWED_GUILDS,
     )
     @discord.option(name="query", description="The query to search", required=True)
+    @discord.option(name="scope", description="How many top links to use for context", required=False, input_type=discord.SlashCommandOptionType.integer, max_value=8, min_value=1)
     @discord.guild_only()
-    async def search(self, ctx: discord.ApplicationContext, query: str):
-        await ctx.respond("Not implemented yet")
-        # await self.search_cog.search_command(ctx, query)
+    async def search(self, ctx: discord.ApplicationContext, query: str, scope: int):
+        await self.search_cog.search_command(ctx, query, scope)
diff --git a/cogs/search_service_cog.py b/cogs/search_service_cog.py
index daa4ac6..48fc950 100644
--- a/cogs/search_service_cog.py
+++ b/cogs/search_service_cog.py
@@ -6,10 +6,11 @@ import discord
 from models.deepl_model import TranslationModel
 from models.search_model import Search
 from services.environment_service import EnvService
-
+from services.text_service import TextService
 
 ALLOWED_GUILDS = EnvService.get_allowed_guilds()
-
+USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys()
+USER_KEY_DB = EnvService.get_api_db()
 
 class SearchService(discord.Cog, name="SearchService"):
     """Cog containing translation commands and retrieval of translation services"""
@@ -25,8 +26,24 @@ class SearchService(discord.Cog, name="SearchService"):
         self.model = Search(gpt_model, pinecone_service)
 
     # Make a mapping of all the country codes and their full country names:
-    async def search_command(self, ctx, query):
+    async def search_command(self, ctx, query, search_scope):
         """Command handler for the translation command"""
+        user_api_key = None
+        if USER_INPUT_API_KEYS:
+            user_api_key = await TextService.get_user_api_key(
+                ctx.user.id, ctx, USER_KEY_DB
+            )
+            if not user_api_key:
+                return
+
+        if not EnvService.get_google_search_api_key() or not EnvService.get_google_search_engine_id():
+            await ctx.send("The search service is not enabled.")
+            return
+
         await ctx.defer()
-        await self.model.search(query)
-        await ctx.respond("ok")
+
+        response = await self.model.search(query, user_api_key, search_scope)
+
+        await ctx.respond(
+            f"**Query:**\n\n{query.strip()}\n\n**Query response:**\n\n{response.response.strip()}"
+        )
diff --git a/models/search_model.py b/models/search_model.py
index bcdaede..03ea3f7 100644
--- a/models/search_model.py
+++ b/models/search_model.py
@@ -1,7 +1,13 @@
+import asyncio
+import os
 import random
 import re
+from functools import partial
+
 from bs4 import BeautifulSoup
 import aiohttp
+from gpt_index import QuestionAnswerPrompt, GPTSimpleVectorIndex, BeautifulSoupWebReader, Document
+from gpt_index.readers.web import DEFAULT_WEBSITE_EXTRACTOR
 
 from services.environment_service import EnvService
 from services.usage_service import UsageService
@@ -13,8 +19,24 @@ class Search:
         self.pinecone_service = pinecone_service
         self.google_search_api_key = EnvService.get_google_search_api_key()
         self.google_search_engine_id = EnvService.get_google_search_engine_id()
+        self.loop = asyncio.get_running_loop()
+        self.qaprompt = QuestionAnswerPrompt(
+            "You are formulating the response to a search query given the search prompt and the context. Context information is below. The text '<|endofstatement|>' is used to separate chat entries and make it easier for you to understand the context\n"
+            "---------------------\n"
+            "{context_str}"
+            "\n---------------------\n"
+            "Never say '<|endofstatement|>'\n"
+            "Given the context information and not prior knowledge, "
+            "answer the question, say that you were unable to answer the question if there is not sufficient context to formulate a decisive answer. The search query was: {query_str}\n"
+        )
+        self.openai_key = os.getenv("OPENAI_TOKEN")
 
+    def index_webpage(self, url) -> list[Document]:
+        documents = BeautifulSoupWebReader(
+            website_extractor=DEFAULT_WEBSITE_EXTRACTOR
+        ).load_data(urls=[url])
+        return documents
+
-    async def get_links(self, query):
+    async def get_links(self, query, search_scope=5):
         """Search the web for a query"""
         async with aiohttp.ClientSession() as session:
             async with session.get(
@@ -23,60 +45,30 @@ class Search:
                 if response.status == 200:
                     data = await response.json()
                     # Return a list of the top 5 links
-                    return [item["link"] for item in data["items"][:5]]
+                    return [item["link"] for item in data["items"][:search_scope]]
                 else:
                     return "An error occurred while searching."
 
-    async def search(self, query):
+    async def search(self, query, user_api_key, search_scope):
+        if not user_api_key:
+            os.environ["OPENAI_API_KEY"] = self.openai_key
+        else:
+            os.environ["OPENAI_API_KEY"] = user_api_key
+
         # Get the links for the query
-        links = await self.get_links(query)
+        links = await self.get_links(query, search_scope=search_scope)
 
         # For each link, crawl the page and get all the text that's not HTML garbage.
        # Concatenate all the text for a given website into one string and save it into an array:
-        texts = []
+        documents = []
         for link in links:
-            async with aiohttp.ClientSession() as session:
-                async with session.get(link, timeout=5) as response:
-                    if response.status == 200:
-                        soup = BeautifulSoup(await response.read(), "html.parser")
-                        # Find all the content between <p> tags and join them together and then append to texts
-                        texts.append(" ".join([p.text for p in soup.find_all("p")]))
-                    else:
-                        pass
-        print("Finished retrieving text content from the links")
+            document = await self.loop.run_in_executor(None, partial(self.index_webpage, link))
+            [documents.append(doc) for doc in document]
 
-        # For each text in texts, split it up into 500 character chunks and create embeddings for it
-        # The pinecone service uses conversation_id, but we can use it here too to keep track of the "search", each
-        # conversation_id represents a unique search.
-        conversation_id = random.randint(0, 100000000)
-        for text in texts:
-            # Split the text into 150 character chunks without using re
-            chunks = [text[i : i + 500] for i in range(0, len(text), 500)]
-            # Create embeddings for each chunk
-            for chunk in chunks:
-                # Create an embedding for the chunk
-                embedding = await self.model.send_embedding_request(chunk)
-                # Upsert the embedding for the conversation ID
-                self.pinecone_service.upsert_conversation_embedding(
-                    self.model, conversation_id, chunk, 0
-                )
-        print("Finished creating embeddings for the text")
+        index = GPTSimpleVectorIndex(documents)
 
-        # Now that we have all the embeddings for the search, we can embed the query and then
-        # query pinecone for the top 5 results
-        query_embedding = await self.model.send_embedding_request(query)
-        results = self.pinecone_service.get_n_similar(
-            conversation_id, query_embedding, n=3
-        )
-        # Get only the first elements of each result
-        results = [result[0] for result in results]
+        # Now we can search the index for a query:
+        response = index.query(query, text_qa_template=self.qaprompt)
 
-        # Construct a query for GPT3 to use these results to answer the query
-        GPT_QUERY = f"This is a search query. I want to know the answer to the query: {query}. Here are some results from the web: {[str(result) for result in results]}. \n\n Answer:"
-        # Generate the answer
-        # Use the tokenizer to determine token amount of the query
-        await self.model.send_request(
-            GPT_QUERY, UsageService.count_tokens_static(GPT_QUERY)
-        )
+        return response
-        print(texts)