File size: 3,333 Bytes
7661630
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c077f9
52addb4
7661630
 
52addb4
 
7661630
 
 
 
52addb4
 
 
7661630
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52addb4
7661630
 
52addb4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
from langchain_openai import OpenAIEmbeddings
from qdrant_client import QdrantClient
from langchain_qdrant import QdrantVectorStore
from qdrant_client.http import models

from dotenv import load_dotenv

# Load environment variables from the local '.env' file at import time, before
# Retriever.__init__ looks up QDRANT_URL / QDRANT_API_KEY via os.getenv.
load_dotenv('.env')

class Retriever:
    """Similarity-search helper over two Qdrant collections.

    ``vector_store`` wraps the ``siel-ai-assignment`` collection and
    ``vector_store_user`` wraps the ``siel-ai-user`` collection. Both share
    one Qdrant client and one ``OpenAIEmbeddings`` instance.
    """

    def __init__(self):
        """Connect to Qdrant and open both vector stores.

        The Qdrant URL and API key are read from the ``QDRANT_URL`` and
        ``QDRANT_API_KEY`` environment variables.
        """
        qdrant_client = QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
        )
        # Both stores embed queries the same way, so a single embeddings
        # object is shared instead of constructing two identical ones.
        embeddings = OpenAIEmbeddings()
        self.vector_store = QdrantVectorStore(
            client=qdrant_client,
            collection_name="siel-ai-assignment",
            embedding=embeddings,
        )
        self.vector_store_user = QdrantVectorStore(
            client=qdrant_client,
            collection_name="siel-ai-user",
            embedding=embeddings,
        )
        # Values matched against the "metadata.DOCUMENT_IS_ABOUT" payload key.
        # filter() treats the LAST entry specially, so keep 'ONLINESITES' last.
        self.filters = [
            'Taxation-Goods-and-service-Tax',
            'Taxation-INCOME-TAX-LAW',
            'Direct Tax Laws and International Taxation',
            'Indirect Tax Laws',
            'INDIAN Income Tax ACTS',
            'ONLINESITES',
        ]

    @staticmethod
    def _match_condition(key, value):
        """Build a Qdrant condition requiring payload ``key`` to equal ``value``."""
        return models.FieldCondition(
            key=key,
            match=models.MatchValue(value=value),
        )

    @staticmethod
    def _threshold_search(store, query, qdrant_filter, k, score_threshold):
        """Run a ``similarity_score_threshold`` retrieval against ``store``.

        Args:
            store: Vector store to query (one of the two built in ``__init__``).
            query: Text to search for.
            qdrant_filter: ``models.Filter`` restricting the candidate points.
            k: Value passed as the retriever's ``k`` search kwarg.
            score_threshold: Value passed as the retriever's
                ``score_threshold`` search kwarg.

        Returns:
            Whatever ``retriever.invoke(query)`` returns.
        """
        retriever = store.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "k": k,
                "score_threshold": score_threshold,
                "filter": qdrant_filter,
            },
        )
        return retriever.invoke(query)

    def filter(self, query, *, online_k=7, other_k=17, score_threshold=0.7):
        """Search the main collection in two passes and concatenate the hits.

        The first pass only considers points whose
        ``metadata.DOCUMENT_IS_ABOUT`` equals ``self.filters[-1]``
        ('ONLINESITES' as initialised); the second pass excludes exactly
        those points. Results of the first pass come first in the output.

        Args:
            query: Text to search for.
            online_k: ``k`` for the first pass (default 7).
            other_k: ``k`` for the second pass (default 17).
            score_threshold: Score threshold used by both passes (default 0.7).

        Returns:
            First-pass results followed by second-pass results.
        """
        condition = self._match_condition(
            "metadata.DOCUMENT_IS_ABOUT", self.filters[-1]
        )
        matching = self._threshold_search(
            self.vector_store,
            query,
            models.Filter(must=[condition]),
            online_k,
            score_threshold,
        )
        remaining = self._threshold_search(
            self.vector_store,
            query,
            models.Filter(must_not=[condition]),
            other_k,
            score_threshold,
        )
        return matching + remaining

    def id_filter(self, query, id, *, k=10, score_threshold=0.7):
        """Search the user collection, restricted to one ``metadata.ID`` value.

        Args:
            query: Text to search for.
            id: Value that the ``metadata.ID`` payload key must equal.
                NOTE(review): the expected type (str vs int) is not visible
                here; confirm against the code that writes the collection.
                The name shadows the builtin but is kept so keyword callers
                do not break.
            k: ``k`` search kwarg for the retriever (default 10).
            score_threshold: Score threshold for the retriever (default 0.7).

        Returns:
            The retriever's results for ``query``.
        """
        only_this_id = models.Filter(
            must=[self._match_condition("metadata.ID", id)]
        )
        return self._threshold_search(
            self.vector_store_user, query, only_this_id, k, score_threshold
        )

    def data_retrieve(self, query='', k=20):
        """Return the documents of an unfiltered search on the main collection.

        Args:
            query: Text to search for; defaults to the empty string.
            k: Number of results requested from the store (default 20).

        Returns:
            The document of each ``(document, score)`` pair returned by
            ``similarity_search_with_score``, in the same order; the scores
            are dropped.
        """
        scored_docs = self.vector_store.similarity_search_with_score(query, k=k)
        return [doc for doc, _score in scored_docs]