"""Conversational QA Chain""" | |
from __future__ import annotations | |
import os | |
import re | |
import time | |
import logging | |
from fastapi import FastAPI | |
from pydantic import BaseModel | |
from langchain.chat_models import ChatOpenAI, ChatAnthropic | |
from langchain.memory import ConversationTokenBufferMemory | |
from convo_qa_chain import ConvoRetrievalChain | |
from toolkit.together_api_llm import TogetherLLM | |
from toolkit.retrivers import MyRetriever | |
from toolkit.local_llm import load_local_llm | |
from toolkit.utils import ( | |
Config, | |
choose_embeddings, | |
load_embedding, | |
load_pickle, | |
check_device, | |
) | |
app = FastAPI()

# Load the config file
configs = Config("configparser.ini")
logger = logging.getLogger(__name__)

os.environ["OPENAI_API_KEY"] = configs.openai_api_key
os.environ["ANTHROPIC_API_KEY"] = configs.anthropic_api_key

embedding = choose_embeddings(configs.embedding_name)
db_store_path = configs.db_dir
# Get models
def get_llm(llm_name: str, temperature: float, max_tokens: int):
    """Get the LLM model from the model name."""
    if not os.path.exists(configs.local_model_dir):
        os.makedirs(configs.local_model_dir)

    splits = llm_name.split("|")  # [provider, model_name, model_file]
    if "openai" in splits[0].lower():
        llm_model = ChatOpenAI(
            model=splits[1],
            temperature=temperature,
            max_tokens=max_tokens,
        )
    elif "anthropic" in splits[0].lower():
        llm_model = ChatAnthropic(
            model=splits[1],
            temperature=temperature,
            max_tokens_to_sample=max_tokens,
        )
    elif "together" in splits[0].lower():
        llm_model = TogetherLLM(
            model=splits[1],
            temperature=temperature,
            max_tokens=max_tokens,
        )
    elif "huggingface" in splits[0].lower():
        llm_model = load_local_llm(
            model_id=splits[1],
            model_basename=splits[-1],
            temperature=temperature,
            max_tokens=max_tokens,
            device_type=check_device(),
        )
    else:
        raise ValueError(f"Invalid model name: {llm_name}")
    return llm_model


llm = get_llm(configs.model_name, configs.temperature, configs.max_llm_generation)
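# Illustrative `model_name` values the parser above accepts (examples only; the
# real value comes from configparser.ini): "openai|gpt-3.5-turbo",
# "anthropic|claude-2", or "huggingface|<repo_id>|<model_file>" for a local model.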
# Load the retrieval databases (two chunk granularities)
db_embedding_chunks_small = load_embedding(
    store_name=configs.embedding_name,
    embedding=embedding,
    suffix="chunks_small",
    path=db_store_path,
)
db_embedding_chunks_medium = load_embedding(
    store_name=configs.embedding_name,
    embedding=embedding,
    suffix="chunks_medium",
    path=db_store_path,
)

db_docs_chunks_small = load_pickle(
    prefix="docs_pickle", suffix="chunks_small", path=db_store_path
)
db_docs_chunks_medium = load_pickle(
    prefix="docs_pickle", suffix="chunks_medium", path=db_store_path
)
file_names = load_pickle(prefix="file", suffix="names", path=db_store_path)
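# Quick sanity check of one store (a sketch assuming load_embedding returns a
# LangChain VectorStore; adjust if toolkit wraps it differently):
# docs = db_embedding_chunks_small.similarity_search("test query", k=2)
# print([doc.metadata for doc in docs])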
# Initialize the retriever
my_retriever = MyRetriever(
    llm=llm,
    embedding_chunks_small=db_embedding_chunks_small,
    embedding_chunks_medium=db_embedding_chunks_medium,
    docs_chunks_small=db_docs_chunks_small,
    docs_chunks_medium=db_docs_chunks_medium,
    first_retrieval_k=configs.first_retrieval_k,
    second_retrieval_k=configs.second_retrieval_k,
    num_windows=configs.num_windows,
    retriever_weights=configs.retriever_weights,
)
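# MyRetriever presumably ensembles hits from the small- and medium-chunk stores
# (weighted by retriever_weights) over two retrieval passes; see
# toolkit/retrivers.py for the exact behavior.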
# Initialize the memory
memory = ConversationTokenBufferMemory(
    llm=llm,
    memory_key="chat_history",
    input_key="question",
    output_key="answer",
    return_messages=True,
    max_token_limit=configs.max_chat_history,
)
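# ConversationTokenBufferMemory keeps only the most recent exchanges, pruning
# the oldest messages once the buffer exceeds max_chat_history tokens (counted
# with the llm's tokenizer).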
# Initialize the QA chain
qa = ConvoRetrievalChain.from_llm(
    llm,
    my_retriever,
    file_names=file_names,
    memory=memory,
    return_source_documents=False,
    return_generated_question=False,
)
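# Quick smoke test of the chain (question text is illustrative; mirrors the
# CLI loop quoted at the bottom of this file):
# resp = qa({"question": "What topics do the loaded documents cover?"})
# print(resp["answer"])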
class Question(BaseModel):
    question: str


# Expose the chain over HTTP; the request body is validated by the Question model.
@app.post("/chat")
def chat_with(payload: Question):
    resp = qa({"question": payload.question})
    answer = resp.get("answer", "")
    return {"message": answer}
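# Example request once the server is running (path and port are assumptions,
# based on the route above and uvicorn's default):
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"question": "Hello"}'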
# @app.get("/") | |
# def chat_with(str1): | |
# resp = qa({"question": str1}) | |
# return {'message':resp} | |
'''
if __name__ == "__main__":
    while True:
        user_input = input("Human: ")
        start_time = time.time()
        user_input_ = re.sub(r"^Human: ", "", user_input)
        print("*" * 6)
        resp = qa({"question": user_input_})
        print()
        print(f"AI:{resp['answer']}")
        print(f"Time used: {time.time() - start_time}")
        print("-" * 60)
'''