from llama_index.core import Settings
from llama_index.core.memory import ChatMemoryBuffer

from api.embedding_models.base import get_embedding_model
from api.llms.base import get_LLM
from api.vector_index.base import get_vector_index

# Engine defaults.
QUERY_ENGINE_MODE = "tree_summarize"  # response synthesis mode for query answers
CHAT_ENGINE_MODE = "context"          # retrieve-then-chat mode for the chat engine
TOP_K = 3                             # number of chunks retrieved per request
MEMORY_TOKEN_LIMIT = 8000             # token budget for the chat memory buffer


class QueryEngine:
    """One-shot RAG pipeline: retrieve the top-k chunks and synthesize an answer."""

    def __init__(self,
                 embedding_model="BAAI/bge-m3",
                 llm="aya:8b",
                 vector_index="chroma",
                 force_new_db=False):
        # The factory helpers return config objects carrying the prompt
        # templates; they are also expected to register the chosen models
        # on the global LlamaIndex Settings.
        self.embed_config = get_embedding_model(embedding_model)
        self.llm_config = get_LLM(llm)
        self.index = get_vector_index(vector_index, force_new_db)
        self.engine = self.index.as_query_engine(
            text_qa_template=self.llm_config.query_context_template,
            response_mode=QUERY_ENGINE_MODE,
            similarity_top_k=TOP_K,
            streaming=True,
        )

    def query(self, user_input):
        # The engine is built with streaming=True, so this returns a
        # StreamingResponse; iterate response.response_gen (or call
        # get_response()) to materialize the answer.
        return self.engine.query(user_input)

    def query_streaming(self, user_input):
        # Identical to query(): streaming is already enabled on the engine.
        # Kept as a separate method to mirror ChatEngine's interface.
        return self.engine.query(user_input)
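
# Minimal usage sketch (an assumption, not part of the module: it presumes
# the Chroma index is already populated and the "aya:8b" model is reachable,
# e.g. through a local Ollama server):
#
#     engine = QueryEngine()
#     response = engine.query_streaming("What does the corpus say about X?")
#     for token in response.response_gen:
#         print(token, end="", flush=True)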


class ChatEngine:
    """Conversational RAG pipeline: retrieved context plus chat history in memory."""

    def __init__(self,
                 embedding_model="BAAI/bge-m3",
                 llm="gpt4o_mini",
                 vector_index="chroma",
                 force_new_db=False):
        self.embed_config = get_embedding_model(embedding_model)
        self.llm_config = get_LLM(llm)
        self.index = get_vector_index(vector_index, force_new_db)
        self.engine = self.index.as_chat_engine(
            llm=Settings.llm,  # assumes get_LLM() registered the model on Settings
            chat_mode=CHAT_ENGINE_MODE,
            verbose=False,
            memory=ChatMemoryBuffer.from_defaults(token_limit=MEMORY_TOKEN_LIMIT),
            system_prompt=self.llm_config.system_prompt,
            context_template=self.llm_config.chat_context_template,
            response_mode=QUERY_ENGINE_MODE,
            similarity_top_k=TOP_K,
            streaming=True,
        )

    def query(self, user_input):
        # Blocking chat turn: returns the full reply in one piece.
        return self.engine.chat(user_input)

    def query_streaming(self, user_input):
        # Streaming chat turn: the returned response exposes a response_gen
        # generator that yields tokens as they arrive.
        return self.engine.stream_chat(user_input)
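

if __name__ == "__main__":
    # Smoke-test sketch, not part of the public API: it assumes a populated
    # vector store and reachable embedding/LLM backends; adjust the defaults
    # above for your own setup.
    chat = ChatEngine()
    print(chat.query("Summarize the indexed documents.").response)
    streamed = chat.query_streaming("What are the key caveats?")
    for token in streamed.response_gen:
        print(token, end="", flush=True)
    print()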