# Medical_RAG/ingest.py
import torch
from langchain.embeddings.base import Embeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader, UnstructuredFileLoader
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
from transformers import AutoTokenizer, AutoModel
class ClinicalBertEmbeddings(Embeddings):
    """LangChain-compatible embeddings backed by ClinicalBERT with mean pooling."""

    def __init__(self, model_name: str = "medicalai/ClinicalBERT"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()

    def embed(self, text: str):
        # Tokenize, run the model without gradients, then mean-pool the token
        # embeddings into a single 768-dim sentence vector.
        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
        with torch.no_grad():
            outputs = self.model(**inputs)
        embeddings = self.mean_pooling(outputs, inputs["attention_mask"])
        return embeddings.squeeze().numpy()

    def mean_pooling(self, model_output, attention_mask):
        # First element of model_output contains the per-token embeddings.
        token_embeddings = model_output[0]
        # Zero out padding positions, then average over the sequence length.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def embed_documents(self, texts):
        # LangChain's Embeddings interface expects plain lists of floats.
        return [self.embed(text).tolist() for text in texts]

    def embed_query(self, text):
        return self.embed(text).tolist()
# Initialize our custom embeddings class
embeddings = ClinicalBertEmbeddings()
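# A minimal sanity check (optional; the query string is purely illustrative
# and not part of the ingestion pipeline):
# vec = embeddings.embed_query("chest pain radiating to the left arm")
# assert len(vec) == 768  # ClinicalBERT's hidden size, matching the Qdrant collection below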
# Load every PDF under Data/ with the Unstructured loader.
loader = DirectoryLoader("Data/", glob="**/*.pdf", show_progress=True, loader_cls=UnstructuredFileLoader)
documents = loader.load()

# Split into 700-character chunks with 70 characters of overlap so that
# context is preserved across chunk boundaries.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=700, chunk_overlap=70)
texts = text_splitter.split_documents(documents)
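# Optional: inspect the split output before indexing (illustrative only).
# print(f"{len(documents)} documents -> {len(texts)} chunks")
# print(texts[0].page_content[:200])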
print("Creating vector database...")
url = "http://localhost:6333"
client = QdrantClient(url=url, prefer_grpc=False)
collection_name = "vector_db"
try:
    # Reuse the collection if it already exists.
    client.get_collection(collection_name=collection_name)
    print(f"Collection '{collection_name}' already exists.")
except Exception as e:
    print(f"Collection '{collection_name}' does not exist ({e}). Creating a new one.")
    client.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(
            size=768,  # must match ClinicalBERT's hidden size
            distance=Distance.COSINE,
        ),
    )
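# Optional check (assuming the qdrant-client 1.x API): confirm the stored
# vector config matches what the embedder produces.
# info = client.get_collection(collection_name)
# print(info.config.params.vectors)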
# Embed the chunks with ClinicalBERT and upsert them into Qdrant.
qdrant = Qdrant.from_documents(
    texts,
    embeddings,
    url=url,
    prefer_grpc=False,
    collection_name=collection_name,
)
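# Example usage (hypothetical query; not part of ingestion):
# docs = qdrant.similarity_search("symptoms of myocardial infarction", k=3)
# for d in docs:
#     print(d.page_content[:120])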
# Build an in-memory BM25 keyword retriever over the same chunks. Note that
# it is not persisted, so any process using it must rebuild it from the texts.
keyword_retriever = BM25Retriever.from_documents(texts)
keyword_retriever.k = 3

print("Vector database created.")