diff --git a/models/index_model.py b/models/index_model.py index 3dc5516..1d0deca 100644 --- a/models/index_model.py +++ b/models/index_model.py @@ -219,7 +219,7 @@ class Index_handler: index = await self.loop.run_in_executor( None, partial(self.index_file, temp_path, embedding_model) ) - await self.usage_service.update_usage(embedding_model.last_token_usage) + await self.usage_service.update_usage(embedding_model.last_token_usage, embeddings=True) file_name = file.filename self.index_storage[ctx.user.id].add_index(index, ctx.user.id, file_name) @@ -248,7 +248,7 @@ class Index_handler: index = await self.loop.run_in_executor( None, partial(self.index_webpage, link, embedding_model) ) - await self.usage_service.update_usage(embedding_model.last_token_usage) + await self.usage_service.update_usage(embedding_model.last_token_usage, embeddings=True) # Make the url look nice, remove https, useless stuff, random characters file_name = ( @@ -290,7 +290,7 @@ class Index_handler: index = await self.loop.run_in_executor( None, partial(self.index_discord, document, embedding_model) ) - await self.usage_service.update_usage(embedding_model.last_token_usage) + await self.usage_service.update_usage(embedding_model.last_token_usage, embeddings=True) self.index_storage[ctx.user.id].add_index(index, ctx.user.id, channel.name) await ctx.respond("Index set") except Exception: @@ -349,8 +349,12 @@ class Index_handler: embed_model=embedding_model, ) await self.usage_service.update_usage( - llm_predictor.last_token_usage + embedding_model.last_token_usage + llm_predictor.last_token_usage ) + await self.usage_service.update_usage( + embedding_model.last_token_usage, embeddings=True + ) + # Now we have a list of tree indexes, we can compose them if not name: @@ -376,7 +380,7 @@ class Index_handler: simple_index = GPTSimpleVectorIndex( documents=documents, embed_model=embedding_model ) - await self.usage_service.update_usage(embedding_model.last_token_usage) + await self.usage_service.update_usage(embedding_model.last_token_usage, embeddings=True) if not name: name = f"composed_index_{date.today().month}_{date.today().day}.json" @@ -402,7 +406,7 @@ class Index_handler: index = await self.loop.run_in_executor( None, partial(self.index_discord, document, embedding_model) ) - await self.usage_service.update_usage(embedding_model.last_token_usage) + await self.usage_service.update_usage(embedding_model.last_token_usage, embeddings=True) Path(app_root_path() / "indexes" / str(ctx.guild.id)).mkdir( parents=True, exist_ok=True ) @@ -434,6 +438,7 @@ class Index_handler: try: llm_predictor = LLMPredictor(llm=OpenAI(model_name="text-davinci-003")) embedding_model = OpenAIEmbedding() + embedding_model.last_token_usage = 0 response = await self.loop.run_in_executor( None, partial( @@ -449,6 +454,7 @@ class Index_handler: ) print("The last token usage was ", llm_predictor.last_token_usage) await self.usage_service.update_usage(llm_predictor.last_token_usage) + await self.usage_service.update_usage(embedding_model.last_token_usage, embeddings=True) await ctx.respond( f"**Query:**\n\n{query.strip()}\n\n**Query response:**\n\n{response.response.strip()}" ) diff --git a/models/search_model.py b/models/search_model.py index ea80a92..eb2fe51 100644 --- a/models/search_model.py +++ b/models/search_model.py @@ -100,6 +100,7 @@ class Search: llm_predictor = LLMPredictor(llm=OpenAI(model_name="text-davinci-003")) # Now we can search the index for a query: + embedding_model.last_token_usage = 0 response = index.query( query, verbose=True, @@ -110,7 +111,10 @@ class Search: text_qa_template=self.qaprompt, ) await self.usage_service.update_usage( - llm_predictor.last_token_usage + embedding_model.last_token_usage + llm_predictor.last_token_usage + ) + await self.usage_service.update_usage( + embedding_model.last_token_usage, embeddings=True ) return response diff --git a/services/usage_service.py b/services/usage_service.py index 2d43c86..bfa424b 100644 --- a/services/usage_service.py +++ b/services/usage_service.py @@ -14,9 +14,12 @@ class UsageService: f.close() self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") - async def update_usage(self, tokens_used): + async def update_usage(self, tokens_used, embeddings=False): tokens_used = int(tokens_used) - price = (tokens_used / 1000) * 0.02 + if not embeddings: + price = (tokens_used / 1000) * 0.02 # Just use the highest rate instead of model-based... I am overestimating on purpose. + else: + price = (tokens_used / 1000) * 0.0004 usage = await self.get_usage() print( f"Cost -> Old: {str(usage)} | New: {str(usage + float(price))}, used {str(float(price))} credits"