diff --git a/models/index_model.py b/models/index_model.py
index 5e7965b..3450c77 100644
--- a/models/index_model.py
+++ b/models/index_model.py
@@ -56,7 +56,7 @@
 RemoteReader = download_loader("RemoteReader")
 RemoteDepthReader = download_loader("RemoteDepthReader")
 
-async def get_and_query(
+def get_and_query(
     user_id,
     index_storage,
     query,
@@ -70,22 +70,24 @@ async def get_and_query(
         user_id
     ].get_index_or_throw()
     if isinstance(index, GPTTreeIndex):
-        response = await index.aquery(
+        response = index.query(
             query,
             child_branch_factor=child_branch_factor,
             llm_predictor=llm_predictor,
             refine_template=CHAT_REFINE_PROMPT,
             embed_model=embed_model,
+            use_async=True,
             # optimizer=SentenceEmbeddingOptimizer(threshold_cutoff=0.7)
         )
     else:
-        response = await index.aquery(
+        response = index.query(
             query,
             response_mode=response_mode,
             llm_predictor=llm_predictor,
             embed_model=embed_model,
             similarity_top_k=nodes,
             refine_template=CHAT_REFINE_PROMPT,
+            use_async=True,
             # optimizer=SentenceEmbeddingOptimizer(threshold_cutoff=0.7)
         )
     return response
@@ -922,15 +924,29 @@ class Index_handler:
         try:
             embedding_model = OpenAIEmbedding()
             embedding_model.last_token_usage = 0
-            response = await get_and_query(
-                ctx.user.id,
-                self.index_storage,
-                query,
-                response_mode,
-                nodes,
-                self.llm_predictor,
-                embedding_model,
-                child_branch_factor,
+            # response = await get_and_query(
+            #     ctx.user.id,
+            #     self.index_storage,
+            #     query,
+            #     response_mode,
+            #     nodes,
+            #     self.llm_predictor,
+            #     embedding_model,
+            #     child_branch_factor,
+            # )
+            response = await self.loop.run_in_executor(
+                None,
+                partial(
+                    get_and_query,
+                    ctx.user.id,
+                    self.index_storage,
+                    query,
+                    response_mode,
+                    nodes,
+                    self.llm_predictor,
+                    embedding_model,
+                    child_branch_factor,
+                ),
             )
             print("The last token usage was ", self.llm_predictor.last_token_usage)
             await self.usage_service.update_usage(
diff --git a/models/search_model.py b/models/search_model.py
index 031d53a..f1e1c44 100644
--- a/models/search_model.py
+++ b/models/search_model.py
@@ -1,7 +1,6 @@
 import asyncio
 import os
-import random
-import re
+
 import tempfile
 import traceback
 from datetime import datetime, date
@@ -9,7 +8,6 @@
 from functools import partial
 from pathlib import Path
 import discord
-from bs4 import BeautifulSoup
 import aiohttp
 from langchain.llms import OpenAIChat
 from llama_index import (
@@ -17,23 +15,18 @@
     GPTSimpleVectorIndex,
     BeautifulSoupWebReader,
     Document,
-    PromptHelper,
     LLMPredictor,
     OpenAIEmbedding,
     SimpleDirectoryReader,
-    GPTTreeIndex,
     MockLLMPredictor,
     MockEmbedding,
 )
 from llama_index.indices.knowledge_graph import GPTKnowledgeGraphIndex
-from llama_index.langchain_helpers.chatgpt import ChatGPTLLMPredictor
 from llama_index.prompts.chat_prompts import CHAT_REFINE_PROMPT
-from llama_index.prompts.prompt_type import PromptType
 from llama_index.readers.web import DEFAULT_WEBSITE_EXTRACTOR
 from langchain import OpenAI
 
 from services.environment_service import EnvService
-from services.usage_service import UsageService
 
 MAX_SEARCH_PRICE = EnvService.get_max_search_price()
@@ -455,24 +448,35 @@ class Search:
             embedding_model.last_token_usage = 0
 
             if not deep:
-                response = await index.aquery(
-                    query,
-                    embed_model=embedding_model,
-                    llm_predictor=llm_predictor,
-                    refine_template=CHAT_REFINE_PROMPT,
-                    similarity_top_k=nodes or DEFAULT_SEARCH_NODES,
-                    text_qa_template=self.qaprompt,
-                    response_mode=response_mode,
+                response = await self.loop.run_in_executor(
+                    None,
+                    partial(
+                        index.query,
+                        query,
+                        embed_model=embedding_model,
+                        llm_predictor=llm_predictor,
+                        refine_template=CHAT_REFINE_PROMPT,
+                        similarity_top_k=nodes or DEFAULT_SEARCH_NODES,
+                        text_qa_template=self.qaprompt,
+                        use_async=True,
+                        response_mode=response_mode,
+                    ),
                 )
             else:
-                response = await index.aquery(
-                    query,
-                    embed_model=embedding_model,
-                    llm_predictor=llm_predictor,
-                    refine_template=CHAT_REFINE_PROMPT,
-                    similarity_top_k=nodes or DEFAULT_SEARCH_NODES,
-                    text_qa_template=self.qaprompt,
-                    response_mode=response_mode,
+                response = await self.loop.run_in_executor(
+                    None,
+                    partial(
+                        index.query,
+                        query,
+                        embedding_mode="hybrid",
+                        llm_predictor=llm_predictor,
+                        refine_template=CHAT_REFINE_PROMPT,
+                        include_text=True,
+                        embed_model=embedding_model,
+                        use_async=True,
+                        similarity_top_k=nodes or DEFAULT_SEARCH_NODES,
+                        response_mode=response_mode,
+                    ),
                )
 
             await self.usage_service.update_usage(
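
Note on the pattern used throughout this patch: each awaited index.aquery(...) call is replaced by the synchronous index.query(..., use_async=True) dispatched through loop.run_in_executor, with functools.partial pre-binding the arguments, since run_in_executor only forwards positional arguments to the callable. Below is a minimal self-contained sketch of that pattern; blocking_query is a hypothetical stand-in for a blocking call such as index.query and is not code from this repository.

import asyncio
import time
from functools import partial


def blocking_query(query: str, *, top_k: int = 3) -> str:
    # Hypothetical stand-in for a synchronous, blocking call such as
    # index.query(...); running it directly inside a coroutine would
    # stall the bot's event loop for the duration of the call.
    time.sleep(1)
    return f"answer to {query!r} (top_k={top_k})"


async def main() -> None:
    loop = asyncio.get_running_loop()
    # None selects the loop's default ThreadPoolExecutor. partial()
    # pre-binds both positional and keyword arguments, because
    # run_in_executor() passes only positional arguments through.
    response = await loop.run_in_executor(
        None, partial(blocking_query, "what changed?", top_k=5)
    )
    print(response)


if __name__ == "__main__":
    asyncio.run(main())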