more token usage captures for indexes

Kaveen Kumarasinghe 2 years ago
parent 1e6e857871
commit bea1ff903d

@ -136,20 +136,20 @@ class Index_handler:
)
# TODO: We need to predict token usage for the indexing methods below.
def index_file(self, file_path) -> GPTSimpleVectorIndex:
def index_file(self, file_path, embed_model) -> GPTSimpleVectorIndex:
document = SimpleDirectoryReader(file_path).load_data()
index = GPTSimpleVectorIndex(document)
index = GPTSimpleVectorIndex(document, embed_model=embed_model)
return index
def index_gdoc(self, doc_id) -> GPTSimpleVectorIndex:
def index_gdoc(self, doc_id, embed_model) -> GPTSimpleVectorIndex:
document = GoogleDocsReader().load_data(doc_id)
index = GPTSimpleVectorIndex(document)
index = GPTSimpleVectorIndex(document, embed_model=embed_model)
return index
def index_youtube_transcript(self, link):
def index_youtube_transcript(self, link, embed_model):
documents = YoutubeTranscriptReader().load_data(ytlinks=[link])
index = GPTSimpleVectorIndex(
documents,
documents, embed_model=embed_model,
)
return index
@ -160,17 +160,17 @@ class Index_handler:
index = GPTSimpleVectorIndex.load_from_disk(file_path)
return index
def index_discord(self, document) -> GPTSimpleVectorIndex:
def index_discord(self, document, embed_model) -> GPTSimpleVectorIndex:
index = GPTSimpleVectorIndex(
document,
document, embed_model=embed_model,
)
return index
def index_webpage(self, url) -> GPTSimpleVectorIndex:
def index_webpage(self, url, embed_model) -> GPTSimpleVectorIndex:
documents = BeautifulSoupWebReader(
website_extractor=DEFAULT_WEBSITE_EXTRACTOR
).load_data(urls=[url])
index = GPTSimpleVectorIndex(documents)
index = GPTSimpleVectorIndex(documents, embed_model=embed_model)
return index
def reset_indexes(self, user_id):
@ -215,9 +215,11 @@ class Index_handler:
suffix=suffix, dir=temp_path, delete=False
) as temp_file:
await file.save(temp_file.name)
embedding_model = OpenAIEmbedding()
index = await self.loop.run_in_executor(
None, partial(self.index_file, temp_path)
None, partial(self.index_file, temp_path, embedding_model)
)
await self.usage_service.update_usage(embedding_model.last_token_usage)
file_name = file.filename
self.index_storage[ctx.user.id].add_index(index, ctx.user.id, file_name)
@ -236,15 +238,17 @@ class Index_handler:
# TODO: Validate the link before indexing it.
try:
embedding_model = OpenAIEmbedding()
# Check if the link contains youtube in it
if "youtube" in link:
index = await self.loop.run_in_executor(
None, partial(self.index_youtube_transcript, link)
None, partial(self.index_youtube_transcript, link, embedding_model)
)
else:
index = await self.loop.run_in_executor(
None, partial(self.index_webpage, link)
None, partial(self.index_webpage, link, embedding_model)
)
await self.usage_service.update_usage(embedding_model.last_token_usage)
# Make the URL look nice: remove "https", useless parts, and random characters
file_name = (
@ -282,9 +286,11 @@ class Index_handler:
document = await self.load_data(
channel_ids=[channel.id], limit=1000, oldest_first=False
)
embedding_model = OpenAIEmbedding()
index = await self.loop.run_in_executor(
None, partial(self.index_discord, document)
None, partial(self.index_discord, document, embedding_model)
)
await self.usage_service.update_usage(embedding_model.last_token_usage)
self.index_storage[ctx.user.id].add_index(index, ctx.user.id, channel.name)
await ctx.respond("Index set")
except Exception:
@ -392,9 +398,11 @@ class Index_handler:
document = await self.load_data(
channel_ids=channel_ids, limit=3000, oldest_first=False
)
embedding_model = OpenAIEmbedding()
index = await self.loop.run_in_executor(
None, partial(self.index_discord, document)
None, partial(self.index_discord, document, embedding_model)
)
await self.usage_service.update_usage(embedding_model.last_token_usage)
Path(app_root_path() / "indexes" / str(ctx.guild.id)).mkdir(
parents=True, exist_ok=True
)

Loading…
Cancel
Save