"""Conversational QA Chain"""
from __future__ import annotations
import os
import re
import time
import logging
from fastapi import FastAPI
from pydantic import BaseModel
from langchain.chat_models import ChatOpenAI, ChatAnthropic
from langchain.memory import ConversationTokenBufferMemory
from convo_qa_chain import ConvoRetrievalChain
from toolkit.together_api_llm import TogetherLLM
from toolkit.retrivers import MyRetriever
from toolkit.local_llm import load_local_llm
from toolkit.utils import (
    Config,
    choose_embeddings,
    load_embedding,
    load_pickle,
    check_device,
)
app = FastAPI()
# Load the config file
configs = Config("configparser.ini")
logger = logging.getLogger(__name__)
os.environ["OPENAI_API_KEY"] = configs.openai_api_key
os.environ["ANTHROPIC_API_KEY"] = configs.anthropic_api_key
embedding = choose_embeddings(configs.embedding_name)
db_store_path = configs.db_dir
# get models
def get_llm(llm_name: str, temperature: float, max_tokens: int):
    """Get the LLM model from the model name."""
    if not os.path.exists(configs.local_model_dir):
        os.makedirs(configs.local_model_dir)

    splits = llm_name.split("|")  # [provider, model_name, model_file]

    if "openai" in splits[0].lower():
        llm_model = ChatOpenAI(
            model=splits[1],
            temperature=temperature,
            max_tokens=max_tokens,
        )
    elif "anthropic" in splits[0].lower():
        llm_model = ChatAnthropic(
            model=splits[1],
            temperature=temperature,
            max_tokens_to_sample=max_tokens,
        )
    elif "together" in splits[0].lower():
        llm_model = TogetherLLM(
            model=splits[1],
            temperature=temperature,
            max_tokens=max_tokens,
        )
    elif "huggingface" in splits[0].lower():
        llm_model = load_local_llm(
            model_id=splits[1],
            model_basename=splits[-1],
            temperature=temperature,
            max_tokens=max_tokens,
            device_type=check_device(),
        )
    else:
        raise ValueError("Invalid Model Name")

    return llm_model
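
# Note: model_name in configparser.ini is parsed as "provider|model_name[|model_file]"
# (see get_llm above), e.g. "openai|gpt-3.5-turbo"; the example value is illustrative only.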
llm = get_llm(configs.model_name, configs.temperature, configs.max_llm_generation)
# load retrieval database
db_embedding_chunks_small = load_embedding(
    store_name=configs.embedding_name,
    embedding=embedding,
    suffix="chunks_small",
    path=db_store_path,
)
db_embedding_chunks_medium = load_embedding(
    store_name=configs.embedding_name,
    embedding=embedding,
    suffix="chunks_medium",
    path=db_store_path,
)
db_docs_chunks_small = load_pickle(
    prefix="docs_pickle", suffix="chunks_small", path=db_store_path
)
db_docs_chunks_medium = load_pickle(
    prefix="docs_pickle", suffix="chunks_medium", path=db_store_path
)
file_names = load_pickle(prefix="file", suffix="names", path=db_store_path)

# Initialize the retriever
my_retriever = MyRetriever(
    llm=llm,
    embedding_chunks_small=db_embedding_chunks_small,
    embedding_chunks_medium=db_embedding_chunks_medium,
    docs_chunks_small=db_docs_chunks_small,
    docs_chunks_medium=db_docs_chunks_medium,
    first_retrieval_k=configs.first_retrieval_k,
    second_retrieval_k=configs.second_retrieval_k,
    num_windows=configs.num_windows,
    retriever_weights=configs.retriever_weights,
)
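
# Descriptive note (inferred from the arguments above): retrieval runs over both the
# small- and medium-chunk stores; k values, window count, and weights come from configparser.ini.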
# Initialize the memory
memory = ConversationTokenBufferMemory(
    llm=llm,
    memory_key="chat_history",
    input_key="question",
    output_key="answer",
    return_messages=True,
    max_token_limit=configs.max_chat_history,
)

# Initialize the QA chain
qa = ConvoRetrievalChain.from_llm(
    llm,
    my_retriever,
    file_names=file_names,
    memory=memory,
    return_source_documents=False,
    return_generated_question=False,
)
class Question(BaseModel):
    question: str


@app.get("/chat/")
def chat_with(str1: str):
    resp = qa({"question": str1})
    answer = resp.get("answer", "")
    return {"message": answer}
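
# Sketch, not part of the original routes: a POST variant that uses the Question model
# defined above (the GET route does not use it). The path and response shape are
# assumptions chosen to mirror /chat/.
@app.post("/chat/")
def chat_with_body(payload: Question):
    resp = qa({"question": payload.question})
    return {"message": resp.get("answer", "")}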
# @app.get("/")
# def chat_with(str1):
# resp = qa({"question": str1})
# return {'message':resp}
'''
if __name__ == "__main__":
    while True:
        user_input = input("Human: ")
        start_time = time.time()
        user_input_ = re.sub(r"^Human: ", "", user_input)
        print("*" * 6)
        resp = qa({"question": user_input_})
        print()
        print(f"AI:{resp['answer']}")
        print(f"Time used: {time.time() - start_time}")
        print("-" * 60)
'''
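
# Sketch (assumption): run the API directly with uvicorn. Host and port are illustrative
# defaults, not values read from configparser.ini.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)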