Full usage metrics reporting, nodes param for search, default 2 for search

Kaveen Kumarasinghe 2 years ago
parent 0c862901f0
commit b4b654828c

@@ -840,6 +840,14 @@ class Commands(discord.Cog, name="Commands"):
         max_value=8,
         min_value=1,
     )
+    @discord.option(
+        name="nodes",
+        description="The higher the number, the more accurate the results, but more expensive",
+        required=False,
+        input_type=discord.SlashCommandOptionType.integer,
+        max_value=10,
+        min_value=1,
+    )
     @discord.guild_only()
-    async def search(self, ctx: discord.ApplicationContext, query: str, scope: int):
-        await self.search_cog.search_command(ctx, query, scope)
+    async def search(self, ctx: discord.ApplicationContext, query: str, scope: int, nodes: int):
+        await self.search_cog.search_command(ctx, query, scope, nodes)
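The new option registers an optional integer that py-cord bounds to [1, 10]; because required=False, the handler receives None when the user omits it, and the default is applied later in Search.search. A minimal stand-in for that behavior (the function name is hypothetical; the real check is done by the library before the handler runs):

from typing import Optional

def validate_nodes(value: Optional[int], lo: int = 1, hi: int = 10) -> Optional[int]:
    # required=False: an omitted option arrives as None and is resolved later
    if value is None:
        return None
    # min_value/max_value: out-of-range input is rejected up front
    if not lo <= value <= hi:
        raise ValueError(f"nodes must be between {lo} and {hi}")
    return value

assert validate_nodes(None) is None
assert validate_nodes(10) == 10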

@@ -20,13 +20,15 @@ class SearchService(discord.Cog, name="SearchService"):
         self,
         bot,
         gpt_model,
+        usage_service,
     ):
         super().__init__()
         self.bot = bot
-        self.model = Search(gpt_model)
+        self.usage_service = usage_service
+        self.model = Search(gpt_model, usage_service)
         # Make a mapping of all the country codes and their full country names:

-    async def search_command(self, ctx, query, search_scope):
+    async def search_command(self, ctx, query, search_scope, nodes):
         """Command handler for the search command"""
         user_api_key = None
         if USER_INPUT_API_KEYS:

@@ -45,7 +47,7 @@ class SearchService(discord.Cog, name="SearchService"):
         await ctx.defer()

-        response = await self.model.search(query, user_api_key, search_scope)
+        response = await self.model.search(query, user_api_key, search_scope, nodes)

         await ctx.respond(
             f"**Query:**\n\n{query.strip()}\n\n**Query response:**\n\n{response.response.strip()}"

@@ -171,7 +171,7 @@ async def main():
         EnvService.get_google_search_api_key()
         and EnvService.get_google_search_engine_id()
     ):
-        bot.add_cog(SearchService(bot, model))
+        bot.add_cog(SearchService(bot, model, usage_service))
         print("The Search service is enabled.")

     bot.add_cog(
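The wiring above threads a single UsageService instance from main() through the cog into the Search model, so every token count lands in one tally. A sketch of that shape; the UsageService internals here are hypothetical, only the shared-instance pattern comes from the diff:

class UsageService:
    # hypothetical internals; the real service persists and prices usage
    def __init__(self):
        self.total_tokens = 0

    async def update_usage(self, tokens: int) -> None:
        self.total_tokens += tokens

class Search:
    def __init__(self, gpt_model, usage_service: UsageService):
        self.model = gpt_model
        self.usage_service = usage_service  # same object the cog holds

class SearchService:
    def __init__(self, bot, gpt_model, usage_service: UsageService):
        self.bot = bot
        self.usage_service = usage_service
        self.model = Search(gpt_model, usage_service)  # pass it down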

@@ -27,7 +27,7 @@ from gpt_index import (
     LLMPredictor,
     QueryConfig,
     PromptHelper,
-    IndexStructType,
+    IndexStructType, OpenAIEmbedding,
 )

 from gpt_index.readers.web import DEFAULT_WEBSITE_EXTRACTOR
@@ -38,7 +38,7 @@ from services.environment_service import EnvService, app_root_path

 SHORT_TO_LONG_CACHE = {}

-def get_and_query(user_id, index_storage, query, response_mode, nodes, llm_predictor):
+def get_and_query(user_id, index_storage, query, response_mode, nodes, llm_predictor, embed_model):
     index: [GPTSimpleVectorIndex, ComposableGraph] = index_storage[
         user_id
     ].get_index_or_throw()

@@ -49,6 +49,7 @@ def get_and_query(user_id, index_storage, query, response_mode, nodes, llm_predictor):
             verbose=True,
             child_branch_factor=2,
             llm_predictor=llm_predictor,
+            embed_model=embed_model,
             prompt_helper=prompthelper,
         )
     else:

@@ -57,6 +58,7 @@ def get_and_query(user_id, index_storage, query, response_mode, nodes, llm_predictor):
             response_mode=response_mode,
             verbose=True,
             llm_predictor=llm_predictor,
+            embed_model=embed_model,
             similarity_top_k=nodes,
             prompt_helper=prompthelper,
         )
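similarity_top_k=nodes is where the slash-command option takes effect: the index ranks stored chunks by embedding similarity to the query and hands only the k best to the LLM, so a larger value means more context but more tokens billed. A toy version of that top-k selection in plain Python (not the gpt_index internals):

import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

def top_k_chunks(query_vec, chunks, k):
    # chunks: list of (text, embedding) pairs; keep the k most similar
    return sorted(chunks, key=lambda c: cosine(query_vec, c[1]), reverse=True)[:k]

chunks = [("doc a", [1.0, 0.0]), ("doc b", [0.9, 0.1]), ("doc c", [0.0, 1.0])]
print(top_k_chunks([1.0, 0.0], chunks, k=2))  # doc a and doc b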
@@ -331,9 +333,9 @@ class Index_handler:
             if isinstance(_index.docstore.get_document(doc_id), Document)
         ]
         llm_predictor = LLMPredictor(llm=OpenAI(model_name="text-davinci-003"))
-        tree_index = GPTTreeIndex(documents=documents, llm_predictor=llm_predictor)
-        print("The last token usage was ", llm_predictor.last_token_usage)
-        await self.usage_service.update_usage(llm_predictor.last_token_usage)
+        embedding_model = OpenAIEmbedding()
+        tree_index = GPTTreeIndex(documents=documents, llm_predictor=llm_predictor, embed_model=embedding_model)
+        await self.usage_service.update_usage(llm_predictor.last_token_usage + embedding_model.last_token_usage)

         # Now we have a list of tree indexes, we can compose them
         if not name:
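Both models track what their last call cost, and the hunk folds the two counters into a single update instead of reporting only the predictor's share (it also drops the debug print). A sketch of that accounting, assuming last_token_usage is the plain integer counter gpt_index exposed at the time:

async def report_build_cost(usage_service, llm_predictor, embedding_model):
    # completion tokens (tree summaries) plus embedding tokens, in one update
    total = llm_predictor.last_token_usage + embedding_model.last_token_usage
    await usage_service.update_usage(total)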
@@ -354,8 +356,10 @@ class Index_handler:
             if isinstance(_index.docstore.get_document(doc_id), Document)
         ]

+        embedding_model = OpenAIEmbedding()
         # Add everything into a simple vector index
-        simple_index = GPTSimpleVectorIndex(documents=documents)
+        simple_index = GPTSimpleVectorIndex(documents=documents, embed_model=embedding_model)
+        await self.usage_service.update_usage(embedding_model.last_token_usage)

         if not name:
             name = f"composed_index_{date.today().month}_{date.today().day}.json"
@@ -410,6 +414,7 @@ class Index_handler:
         try:
             llm_predictor = LLMPredictor(llm=OpenAI(model_name="text-davinci-003"))
+            embedding_model = OpenAIEmbedding()
             response = await self.loop.run_in_executor(
                 None,
                 partial(

@@ -420,6 +425,7 @@ class Index_handler:
                     response_mode,
                     nodes,
                     llm_predictor,
+                    embedding_model,
                 ),
             )
             print("The last token usage was ", llm_predictor.last_token_usage)

@@ -11,17 +11,19 @@ from gpt_index import (
     QuestionAnswerPrompt,
     GPTSimpleVectorIndex,
     BeautifulSoupWebReader,
-    Document,
+    Document, PromptHelper, LLMPredictor, OpenAIEmbedding,
 )

 from gpt_index.readers.web import DEFAULT_WEBSITE_EXTRACTOR
+from langchain import OpenAI
 from services.environment_service import EnvService
 from services.usage_service import UsageService


 class Search:
-    def __init__(self, gpt_model):
+    def __init__(self, gpt_model, usage_service):
         self.model = gpt_model
+        self.usage_service = usage_service
         self.google_search_api_key = EnvService.get_google_search_api_key()
         self.google_search_engine_id = EnvService.get_google_search_engine_id()
         self.loop = asyncio.get_running_loop()
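One subtlety in the constructor above: asyncio.get_running_loop() raises unless a loop is already running, so Search must be instantiated from async context (the bot's startup), never at import time. A minimal demonstration:

import asyncio

class NeedsLoop:
    def __init__(self):
        self.loop = asyncio.get_running_loop()  # RuntimeError outside a loop

async def main():
    obj = NeedsLoop()  # fine here: called while the loop is running
    print(obj.loop.is_running())  # True

asyncio.run(main())
# NeedsLoop() at module top level would raise: no running event loop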
@@ -55,7 +57,8 @@ class Search:
         else:
             return "An error occurred while searching."

-    async def search(self, query, user_api_key, search_scope):
+    async def search(self, query, user_api_key, search_scope, nodes):
+        DEFAULT_SEARCH_NODES = 2
         if not user_api_key:
             os.environ["OPENAI_API_KEY"] = self.openai_key
         else:
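The "default 2" promised by the commit title lives here: an omitted option arrives as None and `nodes or DEFAULT_SEARCH_NODES` resolves it to 2. The truthiness fallback is only safe because min_value=1 on the slash option keeps a falsy 0 from ever reaching this line:

DEFAULT_SEARCH_NODES = 2

for supplied in (None, 1, 5):
    print(supplied, "->", supplied or DEFAULT_SEARCH_NODES)
# None -> 2
# 1 -> 1
# 5 -> 5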
@@ -86,9 +89,14 @@ class Search:
         except Exception as e:
             traceback.print_exc()

+        prompthelper = PromptHelper(4096, 1024, 20)
         index = GPTSimpleVectorIndex(documents)
+        llm_predictor = LLMPredictor(llm=OpenAI(model_name="text-davinci-003"))
+        embedding_model = OpenAIEmbedding()

         # Now we can search the index for a query:
-        response = index.query(query, text_qa_template=self.qaprompt)
+        response = index.query(query, embed_model=embedding_model, llm_predictor=llm_predictor, prompt_helper=prompthelper, similarity_top_k=nodes or DEFAULT_SEARCH_NODES, text_qa_template=self.qaprompt)
+        await self.usage_service.update_usage(llm_predictor.last_token_usage + embedding_model.last_token_usage)

         return response
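PromptHelper(4096, 1024, 20) sizes the prompt budget for the query above: roughly a 4096-token context window, 1024 tokens reserved for the answer, and 20 tokens of overlap between adjacent chunks (assuming the positional order gpt_index used at the time: max_input_size, num_output, max_chunk_overlap). The arithmetic it implies:

MAX_INPUT_SIZE = 4096   # model context window
NUM_OUTPUT = 1024       # tokens reserved for the completion
CHUNK_OVERLAP = 20      # tokens shared between adjacent chunks

# rough budget left for retrieved chunks plus the QA template
available_for_context = MAX_INPUT_SIZE - NUM_OUTPUT
print(available_for_context)  # 3072 tokens for the top-k nodes and prompt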
