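"""Ingest documents from SOURCE_DIRECTORY into a persistent Chroma vector store.

Files are loaded in parallel worker processes, split into chunks, embedded with the
model named in EMBEDDING_MODEL_NAME, and written to PERSIST_DIRECTORY.
"""
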
import logging
import os
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed

import click
import torch
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.text_splitter import Language, RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

from constants import (
    CHROMA_SETTINGS,
    DOCUMENT_MAP,
    EMBEDDING_MODEL_NAME,
    INGEST_THREADS,
    PERSIST_DIRECTORY,
    SOURCE_DIRECTORY,
)


def load_single_document(file_path: str) -> Document:
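    """Load a single document using the loader registered for its file extension."""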
    file_extension = os.path.splitext(file_path)[1]
    loader_class = DOCUMENT_MAP.get(file_extension)
    if loader_class:
        loader = loader_class(file_path)
    else:
        raise ValueError(f"Document type is undefined for extension '{file_extension}'")
    return loader.load()[0]


def load_document_batch(filepaths):
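    """Load a batch of documents concurrently, using one thread per file."""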
    logging.info("Loading document batch")
    # Create a thread pool with one worker per file in the batch
    with ThreadPoolExecutor(len(filepaths)) as exe:
        # Load the files in parallel
        futures = [exe.submit(load_single_document, name) for name in filepaths]
        # Collect the loaded documents
        data_list = [future.result() for future in futures]
        # Return the documents together with their file paths
        return (data_list, filepaths)


def load_documents(source_dir: str) -> list[Document]:
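    """Load all supported documents under source_dir, walking nested folders."""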
    paths = []
    for root, _, files in os.walk(source_dir):
        for file_name in files:
            file_extension = os.path.splitext(file_name)[1]
            source_file_path = os.path.join(root, file_name)
            if file_extension in DOCUMENT_MAP:
                paths.append(source_file_path)

    # Use at most INGEST_THREADS workers, and at least one even if no files were found
    n_workers = min(INGEST_THREADS, max(len(paths), 1))
    # Guard against a zero chunk size (which would make range() raise) when there are no files
    chunksize = max(round(len(paths) / n_workers), 1)
    docs = []
    with ProcessPoolExecutor(n_workers) as executor:
        futures = []
        # Split the file paths into chunks and submit one batch per worker task
        for i in range(0, len(paths), chunksize):
            filepaths = paths[i : (i + chunksize)]
            future = executor.submit(load_document_batch, filepaths)
            futures.append(future)
        # Collect the loaded documents as each batch completes
        for future in as_completed(futures):
            contents, _ = future.result()
            docs.extend(contents)

    return docs


def split_documents(documents: list[Document]) -> tuple[list[Document], list[Document]]:
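    """Separate Python source files from all other documents, returning (text_docs, python_docs)."""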
    text_docs, python_docs = [], []
    for doc in documents:
        file_extension = os.path.splitext(doc.metadata["source"])[1]
        if file_extension == ".py":
            python_docs.append(doc)
        else:
            text_docs.append(doc)

    return text_docs, python_docs


@click.command()
@click.option(
    "--device_type",
    default="cuda" if torch.cuda.is_available() else "cpu",
    type=click.Choice(
        [
            "cpu",
            "cuda",
            "ipu",
            "xpu",
            "mkldnn",
            "opengl",
            "opencl",
            "ideep",
            "hip",
            "ve",
            "fpga",
            "ort",
            "xla",
            "lazy",
            "vulkan",
            "mps",
            "meta",
            "hpu",
            "mtia",
        ],
    ),
    help="Device to run on. (Default is cuda if available, otherwise cpu.)",
)
def main(device_type):
    # Load the documents and split them into chunks
    logging.info(f"Loading documents from {SOURCE_DIRECTORY}")
    documents = load_documents(SOURCE_DIRECTORY)
    text_documents, python_documents = split_documents(documents)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    python_splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.PYTHON, chunk_size=880, chunk_overlap=200
    )
    texts = text_splitter.split_documents(text_documents)
    texts.extend(python_splitter.split_documents(python_documents))
    logging.info(f"Loaded {len(documents)} documents from {SOURCE_DIRECTORY}")
    logging.info(f"Split into {len(texts)} chunks of text")

    # Create embeddings for the chunks on the selected device
    embeddings = HuggingFaceInstructEmbeddings(
        model_name=EMBEDDING_MODEL_NAME,
        model_kwargs={"device": device_type},
    )

    # Build the Chroma vector store from the chunks and write it to the persist directory
    db = Chroma.from_documents(
        texts,
        embeddings,
        persist_directory=PERSIST_DIRECTORY,
        client_settings=CHROMA_SETTINGS,
    )
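    # Assumed follow-up: older langchain/Chroma versions only flush the store to disk on an
    # explicit persist() call; newer releases auto-persist when persist_directory is set.
    db.persist()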


if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s - %(message)s", level=logging.INFO
    )
    main()