Full usage metrics reporting, nodes param for search, default 2 for search

Kaveen Kumarasinghe 2 years ago
parent 0c862901f0
commit b4b654828c

@@ -840,6 +840,14 @@ class Commands(discord.Cog, name="Commands"):
         max_value=8,
         min_value=1,
     )
+    @discord.option(
+        name="nodes",
+        description="The higher the number, the more accurate the results, but more expensive",
+        required=False,
+        input_type=discord.SlashCommandOptionType.integer,
+        max_value=10,
+        min_value=1,
+    )
     @discord.guild_only()
-    async def search(self, ctx: discord.ApplicationContext, query: str, scope: int):
-        await self.search_cog.search_command(ctx, query, scope)
+    async def search(self, ctx: discord.ApplicationContext, query: str, scope: int, nodes: int):
+        await self.search_cog.search_command(ctx, query, scope, nodes)
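A note on the new option: in py-cord, an option declared with `required=False` and no default arrives as `None` when the user omits it, which is why the service layer below falls back to a default node count. A minimal sketch of that behavior (the cog and command names here are illustrative, not the project's):

```python
import discord

class SearchDemo(discord.Cog):
    """Illustrative cog: an omitted optional integer option arrives as None."""

    @discord.slash_command(name="search", description="Search and summarize the web")
    @discord.option(
        name="nodes",
        description="Number of top-similarity nodes to query",
        required=False,
        input_type=discord.SlashCommandOptionType.integer,
        max_value=10,
        min_value=1,
    )
    async def search(self, ctx: discord.ApplicationContext, query: str, nodes: int):
        # When the user omits "nodes", py-cord passes None here; the
        # downstream code treats None as "use DEFAULT_SEARCH_NODES".
        await ctx.respond(f"query={query!r}, nodes={nodes}")
```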

@@ -20,13 +20,15 @@ class SearchService(discord.Cog, name="SearchService"):
         self,
         bot,
         gpt_model,
+        usage_service,
     ):
         super().__init__()
         self.bot = bot
-        self.model = Search(gpt_model)
+        self.usage_service = usage_service
+        self.model = Search(gpt_model, usage_service)
         # Make a mapping of all the country codes and their full country names:

-    async def search_command(self, ctx, query, search_scope):
+    async def search_command(self, ctx, query, search_scope, nodes):
         """Command handler for the search command"""
         user_api_key = None
         if USER_INPUT_API_KEYS:
@@ -45,7 +47,7 @@ class SearchService(discord.Cog, name="SearchService"):
         await ctx.defer()

-        response = await self.model.search(query, user_api_key, search_scope)
+        response = await self.model.search(query, user_api_key, search_scope, nodes)

         await ctx.respond(
             f"**Query:**\n\n{query.strip()}\n\n**Query response:**\n\n{response.response.strip()}"

@@ -171,7 +171,7 @@ async def main():
         EnvService.get_google_search_api_key()
         and EnvService.get_google_search_engine_id()
     ):
-        bot.add_cog(SearchService(bot, model))
+        bot.add_cog(SearchService(bot, model, usage_service))
         print("The Search service is enabled.")

     bot.add_cog(
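Taken together, the three hunks above thread a single `UsageService` instance from `main()` through `SearchService` into `Search`, so all token accounting lands in one object. A minimal sketch of that wiring, with stub class bodies standing in for the real implementations:

```python
class UsageService:
    """Stub: the real service persists token counts; only the shape matters here."""
    async def update_usage(self, tokens: int):
        print(f"recorded {tokens} tokens")

class Search:
    def __init__(self, gpt_model, usage_service):
        self.model = gpt_model
        self.usage_service = usage_service  # shared accounting object

class SearchService:
    def __init__(self, bot, gpt_model, usage_service):
        self.bot = bot
        self.usage_service = usage_service
        # The same instance is handed down to the model layer, so both the
        # cog and the Search object report into one ledger.
        self.model = Search(gpt_model, usage_service)

usage_service = UsageService()
service = SearchService(bot=None, gpt_model="text-davinci-003", usage_service=usage_service)
```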

@@ -27,7 +27,7 @@ from gpt_index import (
     LLMPredictor,
     QueryConfig,
     PromptHelper,
-    IndexStructType,
+    IndexStructType, OpenAIEmbedding,
 )

 from gpt_index.readers.web import DEFAULT_WEBSITE_EXTRACTOR
@@ -38,7 +38,7 @@ from services.environment_service import EnvService, app_root_path

 SHORT_TO_LONG_CACHE = {}

-def get_and_query(user_id, index_storage, query, response_mode, nodes, llm_predictor):
+def get_and_query(user_id, index_storage, query, response_mode, nodes, llm_predictor, embed_model):
     index: [GPTSimpleVectorIndex, ComposableGraph] = index_storage[
         user_id
     ].get_index_or_throw()
@@ -49,6 +49,7 @@ def get_and_query(user_id, index_storage, query, response_mode, nodes, llm_predi
             verbose=True,
             child_branch_factor=2,
             llm_predictor=llm_predictor,
+            embed_model=embed_model,
             prompt_helper=prompthelper,
         )
     else:
@@ -57,6 +58,7 @@ def get_and_query(user_id, index_storage, query, response_mode, nodes, llm_predi
             response_mode=response_mode,
             verbose=True,
             llm_predictor=llm_predictor,
+            embed_model=embed_model,
             similarity_top_k=nodes,
             prompt_helper=prompthelper,
         )
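The point of passing `embed_model` explicitly through both query paths is that a shared `OpenAIEmbedding` instance exposes `last_token_usage` after the call, so embedding spend can be metered alongside completion spend. A hedged sketch against the gpt_index 0.x API used in this diff (requires `OPENAI_API_KEY`; the document text and question are illustrative):

```python
from gpt_index import Document, GPTSimpleVectorIndex, LLMPredictor, OpenAIEmbedding
from langchain import OpenAI

documents = [Document("gpt-index composes tree and vector indexes over documents.")]

llm_predictor = LLMPredictor(llm=OpenAI(model_name="text-davinci-003"))
embed_model = OpenAIEmbedding()

index = GPTSimpleVectorIndex(documents, embed_model=embed_model)
response = index.query(
    "What does gpt-index do?",
    llm_predictor=llm_predictor,
    embed_model=embed_model,
    similarity_top_k=2,  # the "nodes" knob exposed to Discord users
)

# Both helpers track their own spend, which is what get_and_query() now
# forwards to UsageService.
print(llm_predictor.last_token_usage, embed_model.last_token_usage)
```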
@@ -331,9 +333,9 @@ class Index_handler:
                 if isinstance(_index.docstore.get_document(doc_id), Document)
             ]
             llm_predictor = LLMPredictor(llm=OpenAI(model_name="text-davinci-003"))
-            tree_index = GPTTreeIndex(documents=documents, llm_predictor=llm_predictor)
-            print("The last token usage was ", llm_predictor.last_token_usage)
-            await self.usage_service.update_usage(llm_predictor.last_token_usage)
+            embedding_model = OpenAIEmbedding()
+            tree_index = GPTTreeIndex(documents=documents, llm_predictor=llm_predictor, embed_model=embedding_model)
+            await self.usage_service.update_usage(llm_predictor.last_token_usage + embedding_model.last_token_usage)

             # Now we have a list of tree indexes, we can compose them
             if not name:
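Here the predictor's completion tokens and the embedding model's tokens are summed into a single `update_usage()` call. A self-contained sketch of that accounting pattern (token counts are invented for illustration; the real numbers come from `last_token_usage`):

```python
import asyncio

class UsageService:
    """Stand-in with the single-argument signature the call sites above assume."""
    def __init__(self):
        self.total_tokens = 0

    async def update_usage(self, tokens: int):
        self.total_tokens += tokens

async def record(usage: UsageService, llm_tokens: int, embed_tokens: int):
    # Completion and embedding usage are folded into one combined count.
    await usage.update_usage(llm_tokens + embed_tokens)

usage = UsageService()
asyncio.run(record(usage, llm_tokens=812, embed_tokens=143))
print(usage.total_tokens)  # 955
```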
@@ -354,8 +356,10 @@ class Index_handler:
                 if isinstance(_index.docstore.get_document(doc_id), Document)
             ]

+            embedding_model = OpenAIEmbedding()
             # Add everything into a simple vector index
-            simple_index = GPTSimpleVectorIndex(documents=documents)
+            simple_index = GPTSimpleVectorIndex(documents=documents, embed_model=embedding_model)
+            await self.usage_service.update_usage(embedding_model.last_token_usage)

             if not name:
                 name = f"composed_index_{date.today().month}_{date.today().day}.json"
@@ -410,6 +414,7 @@ class Index_handler:
         try:
             llm_predictor = LLMPredictor(llm=OpenAI(model_name="text-davinci-003"))
+            embedding_model = OpenAIEmbedding()
             response = await self.loop.run_in_executor(
                 None,
                 partial(
@@ -420,6 +425,7 @@ class Index_handler:
                     response_mode,
                     nodes,
                     llm_predictor,
+                    embedding_model,
                 ),
             )
             print("The last token usage was ", llm_predictor.last_token_usage)

@@ -11,17 +11,19 @@ from gpt_index import (
     QuestionAnswerPrompt,
     GPTSimpleVectorIndex,
     BeautifulSoupWebReader,
-    Document,
+    Document, PromptHelper, LLMPredictor, OpenAIEmbedding,
 )
 from gpt_index.readers.web import DEFAULT_WEBSITE_EXTRACTOR
 from langchain import OpenAI
 from services.environment_service import EnvService
+from services.usage_service import UsageService


 class Search:
-    def __init__(self, gpt_model):
+    def __init__(self, gpt_model, usage_service):
         self.model = gpt_model
+        self.usage_service = usage_service
         self.google_search_api_key = EnvService.get_google_search_api_key()
         self.google_search_engine_id = EnvService.get_google_search_engine_id()
         self.loop = asyncio.get_running_loop()
@@ -55,7 +57,8 @@ class Search:
         else:
             return "An error occurred while searching."

-    async def search(self, query, user_api_key, search_scope):
+    async def search(self, query, user_api_key, search_scope, nodes):
+        DEFAULT_SEARCH_NODES = 2
         if not user_api_key:
             os.environ["OPENAI_API_KEY"] = self.openai_key
         else:
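The `nodes or DEFAULT_SEARCH_NODES` fallback used in the next hunk relies on the omitted option arriving as `None`; since the slash option enforces `min_value=1`, a falsy `0` can never reach this code, so the `or` idiom is safe here. A tiny runnable check:

```python
DEFAULT_SEARCH_NODES = 2

def effective_nodes(nodes):
    # None (option omitted) falls back to the default; any 1..10 passes through.
    return nodes or DEFAULT_SEARCH_NODES

assert effective_nodes(None) == 2
assert effective_nodes(5) == 5
```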
@@ -86,9 +89,14 @@ class Search:
         except Exception as e:
             traceback.print_exc()

+        prompthelper = PromptHelper(4096, 1024, 20)
         index = GPTSimpleVectorIndex(documents)
+        llm_predictor = LLMPredictor(llm=OpenAI(model_name="text-davinci-003"))
+        embedding_model = OpenAIEmbedding()

         # Now we can search the index for a query:
-        response = index.query(query, text_qa_template=self.qaprompt)
+        response = index.query(query, embed_model=embedding_model, llm_predictor=llm_predictor, prompt_helper=prompthelper, similarity_top_k=nodes or DEFAULT_SEARCH_NODES, text_qa_template=self.qaprompt)
+        await self.usage_service.update_usage(llm_predictor.last_token_usage + embedding_model.last_token_usage)

         return response
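On the `PromptHelper(4096, 1024, 20)` call above: in gpt_index of this vintage the positional parameters are `max_input_size`, `num_output`, and `max_chunk_overlap`, so the prompt is capped at 4096 tokens with 1024 reserved for the completion and 20 tokens of overlap between context chunks. Spelled out with keywords (argument names per gpt_index 0.x; treat them as an assumption if your installed version differs):

```python
from gpt_index import PromptHelper

# Equivalent to PromptHelper(4096, 1024, 20) in the hunk above, with the
# positional arguments named.
prompt_helper = PromptHelper(
    max_input_size=4096,   # hard cap on tokens sent to the model
    num_output=1024,       # tokens reserved for the completion
    max_chunk_overlap=20,  # overlap between adjacent context chunks
)
```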
