File size: 4,136 Bytes
3ac9dae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
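# Build (or load) a FAISS/Chroma index over a directory of PDFs, then run a sample
# retrieval query. The compute device (GPU if available, else CPU) comes from
# get_device_types() in app_modules.init.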
import os
from timeit import default_timer as timer
from typing import List
from langchain.document_loaders import PyPDFDirectoryLoader
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.chroma import Chroma
from langchain.vectorstores.faiss import FAISS
from app_modules.init import *
def load_documents(source_pdfs_path, urls) -> List:
    """Load the PDFs under ``source_pdfs_path`` and attach matching source URLs.

    For every loaded document, the first URL in ``urls`` that ends with the
    document's file name is stored in ``doc.metadata["url"]``.

    Args:
        source_pdfs_path: Directory passed to ``PyPDFDirectoryLoader``.
        urls: Optional list of URL strings; ``None`` or an empty list skips
            URL matching entirely.

    Returns:
        The list returned by the loader, with ``url`` metadata added to each
        document for which a matching URL was found.
    """
    # NOTE(review): silent_errors=True presumably makes the loader skip PDFs it
    # cannot parse instead of raising -- confirm against the langchain version.
    loader = PyPDFDirectoryLoader(source_pdfs_path, silent_errors=True)
    documents = loader.load()
    if not urls:
        return documents

    for doc in documents:
        # os.path.basename honours the platform's path separator(s); the former
        # split("/") left the full path intact on Windows, so no URL ever matched.
        filename = os.path.basename(doc.metadata["source"])
        # NOTE(review): a plain suffix test also accepts ".../xfoo.pdf" for
        # "foo.pdf"; kept as-is for compatibility with existing URL lists.
        url = next((u for u in urls if u.endswith(filename)), None)
        if url is not None:
            doc.metadata["url"] = url
    return documents
def split_chunks(documents: List, chunk_size, chunk_overlap) -> List:
    """Break ``documents`` into smaller, overlapping chunks.

    Delegates to langchain's ``RecursiveCharacterTextSplitter`` configured with
    ``chunk_size`` and ``chunk_overlap`` and returns whatever its
    ``split_documents`` call produces.
    """
    splitter_options = {"chunk_size": chunk_size, "chunk_overlap": chunk_overlap}
    splitter = RecursiveCharacterTextSplitter(**splitter_options)
    pieces = splitter.split_documents(documents)
    return pieces
def generate_index(
    chunks: List, embeddings: HuggingFaceInstructEmbeddings
) -> VectorStore:
    """Embed ``chunks`` into a new vector store persisted at ``index_path``.

    The backend is chosen by the module-level ``using_faiss`` flag:
    Chroma (created with ``persist_directory`` and flushed via ``persist()``)
    when it is false, FAISS (written with ``save_local``) when it is true.
    Both ``using_faiss`` and ``index_path`` are module globals that must be
    defined before this function is called.

    Returns:
        The freshly built vector store.
    """
    if not using_faiss:
        store = Chroma.from_documents(
            documents=chunks, embedding=embeddings, persist_directory=index_path
        )
        store.persist()
        return store

    store = FAISS.from_documents(documents=chunks, embedding=embeddings)
    store.save_local(index_path)
    return store
# Constants (read from the environment)
device_type, hf_pipeline_device_type = get_device_types()
hf_embeddings_model_name = (
    os.environ.get("HF_EMBEDDINGS_MODEL_NAME") or "hkunlp/instructor-xl"
)
index_path = os.environ.get("FAISS_INDEX_PATH") or os.environ.get(
    "CHROMADB_INDEX_PATH"
)
# Use the same truthiness test as index_path: previously an empty
# FAISS_INDEX_PATH fell through to the Chroma path yet still selected FAISS.
using_faiss = bool(os.environ.get("FAISS_INDEX_PATH"))
source_pdfs_path = os.environ.get("SOURCE_PDFS_PATH")
source_urls = os.environ.get("SOURCE_URLS")
# CHUNK_SIZE is the intended name; the historical misspelling CHUNCK_SIZE is
# still honoured so existing .env files keep working.
chunk_size = os.environ.get("CHUNK_SIZE") or os.environ.get("CHUNCK_SIZE")
chunk_overlap = os.environ.get("CHUNK_OVERLAP")

# Fail fast, before the (slow) embedding model load, if there is nowhere to
# put or find the index; otherwise os.path.isdir(None) raises an opaque TypeError.
if not index_path:
    raise SystemExit(
        "Set FAISS_INDEX_PATH or CHROMADB_INDEX_PATH to choose the index location."
    )

start = timer()
embeddings = HuggingFaceInstructEmbeddings(
    model_name=hf_embeddings_model_name, model_kwargs={"device": device_type}
)
end = timer()
print(f"Completed in {end - start:.3f}s")

start = timer()
if not os.path.isdir(index_path):
    print(
        f"The index persist directory {index_path} is not present. Creating a new one."
    )
    # Validate before creating the directory: a crash after mkdir would leave an
    # empty directory that the next run would try to load as an index.
    if chunk_size is None or chunk_overlap is None:
        raise SystemExit(
            "CHUNK_SIZE (or CHUNCK_SIZE) and CHUNK_OVERLAP must be set to build a new index."
        )
    os.makedirs(index_path)

    if source_urls is not None:
        # SOURCE_URLS names a text file with one URL per line; blank lines are dropped.
        with open(source_urls, "r", encoding="utf-8") as url_file:
            source_urls = [url for url in (line.strip() for line in url_file) if url]

    # NOTE(review): the count printed here is the number of URLs, not PDF files.
    print(
        f"Loading {'' if source_urls is None else str(len(source_urls)) + ' '}PDF files from {source_pdfs_path}"
    )
    sources = load_documents(source_pdfs_path, source_urls)
    print(f"Splitting {len(sources)} PDF pages in to chunks ...")
    chunks = split_chunks(
        sources, chunk_size=int(chunk_size), chunk_overlap=int(chunk_overlap)
    )
    print(f"Generating index for {len(chunks)} chunks ...")
    index = generate_index(chunks, embeddings)
else:
    print(f"The index persist directory {index_path} is present. Loading index ...")
    index = (
        FAISS.load_local(index_path, embeddings)
        if using_faiss
        else Chroma(embedding_function=embeddings, persist_directory=index_path)
    )

# Smoke-test the index with a trivial retrieval.
query = "hi"
print(f"Load relevant documents for standalone question: {query}")
start2 = timer()
docs = index.as_retriever().get_relevant_documents(query)
print(f"Completed in {timer() - start2:.3f}s")
print(docs)

end = timer()
print(f"Completed in {end - start:.3f}s")
|