diff --git a/models/index_model.py b/models/index_model.py index fd21dea..3dc5516 100644 --- a/models/index_model.py +++ b/models/index_model.py @@ -136,20 +136,20 @@ class Index_handler: ) # TODO We need to do predictions below for token usage. - def index_file(self, file_path) -> GPTSimpleVectorIndex: + def index_file(self, file_path, embed_model) -> GPTSimpleVectorIndex: document = SimpleDirectoryReader(file_path).load_data() - index = GPTSimpleVectorIndex(document) + index = GPTSimpleVectorIndex(document, embed_model=embed_model) return index - def index_gdoc(self, doc_id) -> GPTSimpleVectorIndex: + def index_gdoc(self, doc_id, embed_model) -> GPTSimpleVectorIndex: document = GoogleDocsReader().load_data(doc_id) - index = GPTSimpleVectorIndex(document) + index = GPTSimpleVectorIndex(document, embed_model=embed_model) return index - def index_youtube_transcript(self, link): + def index_youtube_transcript(self, link, embed_model): documents = YoutubeTranscriptReader().load_data(ytlinks=[link]) index = GPTSimpleVectorIndex( - documents, + documents, embed_model=embed_model, ) return index @@ -160,17 +160,17 @@ class Index_handler: index = GPTSimpleVectorIndex.load_from_disk(file_path) return index - def index_discord(self, document) -> GPTSimpleVectorIndex: + def index_discord(self, document, embed_model) -> GPTSimpleVectorIndex: index = GPTSimpleVectorIndex( - document, + document, embed_model=embed_model, ) return index - def index_webpage(self, url) -> GPTSimpleVectorIndex: + def index_webpage(self, url, embed_model) -> GPTSimpleVectorIndex: documents = BeautifulSoupWebReader( website_extractor=DEFAULT_WEBSITE_EXTRACTOR ).load_data(urls=[url]) - index = GPTSimpleVectorIndex(documents) + index = GPTSimpleVectorIndex(documents, embed_model=embed_model) return index def reset_indexes(self, user_id): @@ -215,9 +215,11 @@ class Index_handler: suffix=suffix, dir=temp_path, delete=False ) as temp_file: await file.save(temp_file.name) + embedding_model = OpenAIEmbedding() index = await self.loop.run_in_executor( - None, partial(self.index_file, temp_path) + None, partial(self.index_file, temp_path, embedding_model) ) + await self.usage_service.update_usage(embedding_model.last_token_usage) file_name = file.filename self.index_storage[ctx.user.id].add_index(index, ctx.user.id, file_name) @@ -236,15 +238,17 @@ class Index_handler: # TODO Link validation try: + embedding_model = OpenAIEmbedding() # Check if the link contains youtube in it if "youtube" in link: index = await self.loop.run_in_executor( - None, partial(self.index_youtube_transcript, link) + None, partial(self.index_youtube_transcript, link, embedding_model) ) else: index = await self.loop.run_in_executor( - None, partial(self.index_webpage, link) + None, partial(self.index_webpage, link, embedding_model) ) + await self.usage_service.update_usage(embedding_model.last_token_usage) # Make the url look nice, remove https, useless stuff, random characters file_name = ( @@ -282,9 +286,11 @@ class Index_handler: document = await self.load_data( channel_ids=[channel.id], limit=1000, oldest_first=False ) + embedding_model = OpenAIEmbedding() index = await self.loop.run_in_executor( - None, partial(self.index_discord, document) + None, partial(self.index_discord, document, embedding_model) ) + await self.usage_service.update_usage(embedding_model.last_token_usage) self.index_storage[ctx.user.id].add_index(index, ctx.user.id, channel.name) await ctx.respond("Index set") except Exception: @@ -392,9 +398,11 @@ class Index_handler: document = await self.load_data( channel_ids=channel_ids, limit=3000, oldest_first=False ) + embedding_model = OpenAIEmbedding() index = await self.loop.run_in_executor( - None, partial(self.index_discord, document) + None, partial(self.index_discord, document, embedding_model) ) + await self.usage_service.update_usage(embedding_model.last_token_usage) Path(app_root_path() / "indexes" / str(ctx.guild.id)).mkdir( parents=True, exist_ok=True )