# Page banner captured along with this file (not part of the program):
# Spaces: Runtime error
import os

from dotenv import load_dotenv
from langchain_openai import OpenAIEmbeddings
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http import models

# Load environment variables from the local .env file at import time, so the
# os.getenv("QDRANT_URL") / os.getenv("QDRANT_API_KEY") lookups made when a
# Retriever is constructed can see them.
load_dotenv('.env')
class Retriever:
    """Similarity search over two Qdrant collections.

    ``vector_store`` is bound to the "siel-ai-assignment" collection and
    ``vector_store_user`` to the "siel-ai-user" collection. Both use the same
    Qdrant client and OpenAI embeddings.
    """

    def __init__(self):
        """Create the Qdrant client and bind a vector store to each collection.

        Connection settings are read from the ``QDRANT_URL`` and
        ``QDRANT_API_KEY`` environment variables.
        """
        qdrant_client = QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
        )

        # A single embeddings object is shared by both stores; the original
        # built two identical instances.
        embeddings = OpenAIEmbeddings()

        self.vector_store = QdrantVectorStore(
            client=qdrant_client,
            collection_name="siel-ai-assignment",
            embedding=embeddings,
        )
        self.vector_store_user = QdrantVectorStore(
            client=qdrant_client,
            collection_name="siel-ai-user",
            embedding=embeddings,
        )

        # Values matched against the "metadata.DOCUMENT_IS_ABOUT" payload key.
        # filter() relies on the LAST entry ('ONLINESITES') to split results.
        self.filters = [
            'Taxation-Goods-and-service-Tax',
            'Taxation-INCOME-TAX-LAW',
            'Direct Tax Laws and International Taxation',
            'Indirect Tax Laws',
            'INDIAN Income Tax ACTS',
            'ONLINESITES',
        ]

    @staticmethod
    def _threshold_retriever(store, k, score_threshold, payload_filter):
        """Return a score-thresholded similarity retriever built on *store*."""
        return store.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "k": k,
                "score_threshold": score_threshold,
                "filter": payload_filter,
            },
        )

    @staticmethod
    def _match_condition(key, value):
        """Return a Qdrant condition requiring payload *key* to equal *value*."""
        return models.FieldCondition(
            key=key,
            match=models.MatchValue(value=value),
        )

    def filter(self, query, *, k_online=7, k_other=17, score_threshold=0.7):
        """Search the assignment collection, split by document category.

        Two searches are run against ``vector_store``: one restricted to
        documents whose "metadata.DOCUMENT_IS_ABOUT" equals the last entry of
        ``self.filters``, and one excluding those documents.

        Args:
            query: Text to search for.
            k_online: Result limit for the search restricted to the last
                ``self.filters`` entry.
            k_other: Result limit for the search over everything else.
            score_threshold: Minimum score passed to both retrievers.

        Returns:
            The restricted search's results followed by the results of the
            excluding search, as one list.
        """
        is_last_category = self._match_condition(
            "metadata.DOCUMENT_IS_ABOUT", self.filters[-1]
        )
        only_last = self._threshold_retriever(
            self.vector_store,
            k_online,
            score_threshold,
            models.Filter(must=[is_last_category]),
        )
        all_but_last = self._threshold_retriever(
            self.vector_store,
            k_other,
            score_threshold,
            models.Filter(must_not=[is_last_category]),
        )
        return only_last.invoke(query) + all_but_last.invoke(query)

    # NOTE: the parameter name ``id`` shadows the builtin but is kept because
    # callers may pass it by keyword.
    def id_filter(self, query, id, *, k=10, score_threshold=0.7):
        """Search the user collection for documents tagged with *id*.

        Args:
            query: Text to search for.
            id: Value that the "metadata.ID" payload key must equal.
            k: Result limit passed to the retriever.
            score_threshold: Minimum score passed to the retriever.

        Returns:
            Whatever the retriever returns for *query* under that filter.
        """
        owned_by_id = models.Filter(
            must=[self._match_condition("metadata.ID", id)]
        )
        retriever = self._threshold_retriever(
            self.vector_store_user, k, score_threshold, owned_by_id
        )
        return retriever.invoke(query)

    def data_retrieve(self, query='', *, k=20):
        """Return up to *k* unfiltered matches from the assignment collection.

        Args:
            query: Text to search for; defaults to the empty string.
            k: Result limit passed to the vector store.

        Returns:
            The matched documents in the order the store returned them, with
            the scores dropped.
        """
        scored_docs = self.vector_store.similarity_search_with_score(query, k=k)
        return [doc for doc, _score in scored_docs]