Spaces:
Build error
Build error
import os | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_community.document_loaders import DirectoryLoader,UnstructuredFileLoader | |
from langchain_community.vectorstores import Qdrant | |
from langchain_community.embeddings import SentenceTransformerEmbeddings | |
from langchain_community.retrievers import BM25Retriever | |
from qdrant_client import QdrantClient | |
from qdrant_client.http.exceptions import ResponseHandlingException | |
from glob import glob | |
from llama_index.vector_stores.qdrant import QdrantVectorStore | |
from transformers import AutoTokenizer, AutoModel | |
from sentence_transformers import models, SentenceTransformer | |
from langchain.embeddings.base import Embeddings | |
from qdrant_client.models import VectorParams | |
import torch | |
# from llama_index import SimpleDirectoryReader, StorageContext | |
class ClinicalBertEmbeddings(Embeddings): | |
def __init__(self, model_name: str = "medicalai/ClinicalBERT"): | |
self.tokenizer = AutoTokenizer.from_pretrained(model_name) | |
self.model = AutoModel.from_pretrained(model_name) | |
self.model.eval() | |
def embed(self, text: str): | |
inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512) | |
with torch.no_grad(): | |
outputs = self.model(**inputs) | |
embeddings = self.mean_pooling(outputs, inputs['attention_mask']) | |
return embeddings.squeeze().numpy() | |
def mean_pooling(self, model_output, attention_mask): | |
token_embeddings = model_output[0] # First element of model_output contains all token embeddings | |
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() | |
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) | |
def embed_documents(self, texts): | |
return [self.embed(text) for text in texts] | |
def embed_query(self, text): | |
return self.embed(text) | |
# Initialize our custom embeddings class | |
embeddings = ClinicalBertEmbeddings() | |
# # Load ClinicalBERT | |
# tokenizer = AutoTokenizer.from_pretrained("medicalai/ClinicalBERT") | |
# model = AutoModel.from_pretrained("medicalai/ClinicalBERT") | |
# # Use mean pooling | |
# pooling_model = models.Pooling(model.config.hidden_size, pooling_mode_mean_tokens=True) | |
# # Create custom SentenceTransformer model | |
# # sentence_transformer_model = SentenceTransformer(modules=[model, pooling_model]) | |
# # Use mean pooling | |
# sentence_transformer_model = SentenceTransformer(modules=[model]) | |
# embeddings = SentenceTransformerEmbeddings(model=sentence_transformer_model) | |
print(embeddings) | |
# Tokenize sentences | |
# loading from directory | |
loader = DirectoryLoader("Data/", glob="**/*.pdf", show_progress=True, loader_cls=UnstructuredFileLoader) | |
documents = loader.load() | |
text_splitter = RecursiveCharacterTextSplitter(chunk_size=700, chunk_overlap=70) | |
texts = text_splitter.split_documents(documents) | |
# Get elements | |
print("Create Vector Database") | |
url = "http://localhost:6333" | |
client = QdrantClient(url=url, prefer_grpc=False) | |
collection_name = "vector_db" | |
try: | |
# Check if the collection already exists | |
collection_info = client.get_collection(collection_name=collection_name) | |
print(f"Collection '{collection_name}' already exists.") | |
except Exception as e: | |
print(f"Collection '{collection_name}' does not exist. Creating a new one.") | |
print(f"Error: {e}") | |
client.create_collection( | |
collection_name=collection_name, | |
vectors_config= VectorParams( | |
size=768, | |
distance="Cosine" | |
) | |
) | |
qdrant = Qdrant.from_documents( | |
texts, | |
embeddings, | |
url=url, | |
prefer_grpc=False, | |
collection_name=collection_name, | |
# optimizer_config=optimizer_config | |
) | |
keyword_retriever = BM25Retriever.from_documents(texts) | |
keyword_retriever.k = 3 | |
print("vector database created.................") | |