import os

from dotenv import load_dotenv

from langchain.chains import RetrievalQA
from langchain.document_loaders import Docx2txtLoader, PyPDFLoader, TextLoader
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.llms import GooglePalm
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS

vector_index_path = "assets/vectordb/faiss_index"


def load_env_variables():
    load_dotenv()  # take environment variables from .env


def create_vector_db(filename, instructor_embeddings):
    # Pick a document loader based on the file extension
    if filename.endswith(".pdf"):
        loader = PyPDFLoader(file_path=filename)
    elif filename.endswith(".doc") or filename.endswith(".docx"):
        loader = Docx2txtLoader(filename)
    elif filename.lower().endswith(".txt"):
        loader = TextLoader(filename)
    else:
        raise ValueError(f"Unsupported file type: {filename}")

    # Split the loaded documents into overlapping chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=10)
    splits = text_splitter.split_documents(loader.load())

    # Create a FAISS instance for the vector database from the chunks
    vectordb = FAISS.from_documents(documents=splits,
                                    embedding=instructor_embeddings)

    # Save the vector database locally
    vectordb.save_local(vector_index_path)


def get_qa_chain(instructor_embeddings, llm):
    # Load the vector database from the local folder
    vectordb = FAISS.load_local(vector_index_path, instructor_embeddings)

    # Create a retriever for querying the vector database
    retriever = vectordb.as_retriever(search_type="similarity")

    prompt_template = """
    You are a question-answering agent and you must strictly follow the prompt template below.
    Given the following context and a question, generate an answer based on this context only.
    In the answer, reuse as much text as possible from the "response" section of the source document context, with minimal changes.
    Keep answers brief and well-structured. Do not give one-word answers.
    If the answer is not found in the context, state "I don't know." Do not try to make up an answer.

    CONTEXT: {context}

    QUESTION: {question}"""

    PROMPT = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )

    chain = RetrievalQA.from_chain_type(llm=llm,
                                        chain_type="stuff",  # or "map_reduce"
                                        retriever=retriever,
                                        input_key="query",
                                        return_source_documents=True,  # return source documents from the vector db
                                        chain_type_kwargs={"prompt": PROMPT},
                                        verbose=True)
    return chain


def load_model_params():
    load_env_variables()

    # Create the Google PaLM LLM
    llm = GooglePalm(google_api_key=os.environ["GOOGLE_API_KEY"], temperature=0.1)

    # Initialize instructor embeddings using the Hugging Face model
    instructor_embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-large")

    return llm, instructor_embeddings


def document_parser(instructor_embeddings, llm):
    chain = get_qa_chain(instructor_embeddings=instructor_embeddings, llm=llm)
    return chain
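

# --- Usage sketch (an assumption, not part of the original module) ---
# A minimal example of wiring the pieces together: load the models, index a
# document, and query the resulting chain. "sample.pdf" and the question are
# hypothetical placeholders; this assumes GOOGLE_API_KEY is set in the .env file.
if __name__ == "__main__":
    llm, instructor_embeddings = load_model_params()

    # Build and persist the FAISS index from a local document (hypothetical path)
    create_vector_db("sample.pdf", instructor_embeddings)

    # Create the QA chain and ask a question against the indexed document
    chain = document_parser(instructor_embeddings=instructor_embeddings, llm=llm)
    result = chain({"query": "What does the document say about pricing?"})
    print(result["result"])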