# Chatbot / retriever.py
# Last commit by edithram23: "roll back" (52addb4)
# raw | history | blame — 3.33 kB
import os
from langchain_openai import OpenAIEmbeddings
from qdrant_client import QdrantClient
from langchain_qdrant import QdrantVectorStore
from qdrant_client.http import models
from dotenv import load_dotenv
# Pull QDRANT_URL / QDRANT_API_KEY (and any other variables defined there)
# from the .env file into the process environment before Retriever reads them.
load_dotenv(dotenv_path=".env")
class Retriever:
    """Similarity search over the project's two Qdrant collections.

    Attributes:
        vector_store: LangChain store over the ``"siel-ai-assignment"`` collection.
        vector_store_user: LangChain store over the ``"siel-ai-user"`` collection.
        filters: Known ``metadata.DOCUMENT_IS_ABOUT`` values; ``filter`` treats
            the last entry as its own search bucket.

    The Qdrant connection is configured from the ``QDRANT_URL`` and
    ``QDRANT_API_KEY`` environment variables (read with ``os.getenv``, so
    they are ``None`` when unset).
    """

    # Default relevance-score cut-off passed to the threshold searches below.
    DEFAULT_SCORE_THRESHOLD = 0.7

    def __init__(self):
        qdrant_client = QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
        )
        # Both collections use the same default OpenAIEmbeddings configuration,
        # so a single embeddings client is shared instead of building two.
        embeddings = OpenAIEmbeddings()
        self.vector_store = QdrantVectorStore(
            client=qdrant_client,
            collection_name="siel-ai-assignment",
            embedding=embeddings,
        )
        self.vector_store_user = QdrantVectorStore(
            client=qdrant_client,
            collection_name="siel-ai-user",
            embedding=embeddings,
        )
        self.filters = ['Taxation-Goods-and-service-Tax',
                        'Taxation-INCOME-TAX-LAW',
                        'Direct Tax Laws and International Taxation',
                        'Indirect Tax Laws',
                        'INDIAN Income Tax ACTS',
                        'ONLINESITES']

    @staticmethod
    def _threshold_retriever(store, k, score_threshold, qdrant_filter):
        """Build a ``similarity_score_threshold`` retriever on *store*.

        Args:
            store: The LangChain vector store to search.
            k: Maximum number of documents to return.
            score_threshold: Relevance-score cut-off handed to LangChain.
            qdrant_filter: ``models.Filter`` restricting which points are searched.
        """
        return store.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "k": k,
                "score_threshold": score_threshold,
                "filter": qdrant_filter,
            },
        )

    def _about_last_filter(self):
        """Return a fresh condition matching DOCUMENT_IS_ABOUT == ``self.filters[-1]``."""
        return models.FieldCondition(
            key="metadata.DOCUMENT_IS_ABOUT",
            match=models.MatchValue(value=self.filters[-1]),
        )

    def filter(self, query, *, k_online=7, k_other=17,
               score_threshold=DEFAULT_SCORE_THRESHOLD):
        """Search ``vector_store`` in two disjoint buckets and concatenate the hits.

        Bucket one holds documents whose ``metadata.DOCUMENT_IS_ABOUT`` equals
        ``self.filters[-1]`` ('ONLINESITES' by default) and contributes at most
        *k_online* hits. Bucket two holds every other document and contributes
        at most *k_other* hits.

        Args:
            query: Text to embed and search for.
            k_online: Hit cap for the ``self.filters[-1]`` bucket.
            k_other: Hit cap for the remaining documents.
            score_threshold: Relevance-score cut-off for both buckets.

        Returns:
            List of LangChain documents, bucket-one hits first.
        """
        online = self._threshold_retriever(
            self.vector_store, k_online, score_threshold,
            models.Filter(must=[self._about_last_filter()]),
        )
        others = self._threshold_retriever(
            self.vector_store, k_other, score_threshold,
            models.Filter(must_not=[self._about_last_filter()]),
        )
        return online.invoke(query) + others.invoke(query)

    def id_filter(self, query, id, *, k=10,
                  score_threshold=DEFAULT_SCORE_THRESHOLD):
        """Search ``vector_store_user`` restricted to points whose ``metadata.ID`` equals *id*.

        Args:
            query: Text to embed and search for.
            id: Value that ``metadata.ID`` must match exactly. (Shadows the
                builtin, but the name is kept because callers may pass it by keyword.)
            k: Maximum number of documents to return.
            score_threshold: Relevance-score cut-off.

        Returns:
            List of LangChain documents.
        """
        id_condition = models.FieldCondition(
            key="metadata.ID",
            match=models.MatchValue(value=id),
        )
        retriever = self._threshold_retriever(
            self.vector_store_user, k, score_threshold,
            models.Filter(must=[id_condition]),
        )
        return retriever.invoke(query)

    def data_retrieve(self, query='', *, k=20):
        """Return the top *k* documents from ``vector_store`` for *query*, with no filter and no score cut-off.

        Scores from ``similarity_search_with_score`` are discarded; only the
        documents are returned, in the order the store yields them.
        """
        hits = self.vector_store.similarity_search_with_score(query, k=k)
        return [doc for doc, _score in hits]