File size: 3,613 Bytes
d1a829e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from openai import OpenAI
import cohere
from qdrant_client import models
from src.prompts import RAG_CONTEXT_TEMPLATE


class Retriever:
    """Retrieve documents from a vector database and build an LLM context from them.

    Retrieval is performed in the following steps:
        1. Create an embedding for the query (OpenAI embeddings API).
        2. Fetch n candidate documents from the database using the query embedding
           plus optional payload filters on city / price / rating.
        3. Rerank the candidates with Cohere and keep the top k, where k << n.
        4. Format the selected documents into a single context string.
    """

    def __init__(self, embedding_model, llm_model, rerank_model, db_client, db_collection='hotels',
                 max_retrieved_docs=13):
        """
        Args:
            embedding_model (str): model name passed to the OpenAI embeddings API.
            llm_model (str): LLM model name; stored on the instance, not used by this class.
            rerank_model (str): model name passed to Cohere's rerank endpoint.
            db_client: vector-database client exposing ``search(collection_name=...,
                query_vector=..., limit=..., query_filter=...)`` (presumably a
                ``qdrant_client.QdrantClient`` -- confirm against caller).
            db_collection (str): collection to search. Defaults to 'hotels'.
            max_retrieved_docs (int): minimum number of candidates fetched from the
                database before reranking. Defaults to 13.
        """
        self.db_collection = db_collection
        self.db_client = db_client
        self.rerank_model = rerank_model
        # Both clients are constructed without arguments; presumably they pick up
        # their API keys from the environment -- verify deployment config.
        self.openai_client = OpenAI()
        self.co = cohere.Client()
        self.embedding_model = embedding_model
        self.llm_model = llm_model
        self.max_retrieved_docs = max_retrieved_docs

    def _get_documents(self, query, top_k, city, price, rating):
        """Embed the query and fetch the top_k nearest documents that match the filters.

        Args:
            query (str): query text to embed.
            top_k (int): number of documents to retrieve.
            city (str): exact-match filter on the "city" payload field; skipped if falsy.
            price (str): exact-match filter on the "price" payload field; skipped if falsy.
            rating (float): inclusive lower bound on the "rating" payload field;
                skipped if falsy (None or 0).

        Returns:
            list: the hits returned by ``db_client.search``.
        """
        embedding = self.openai_client.embeddings.create(input=query, model=self.embedding_model)
        conditions = []
        if city:
            conditions.append(models.FieldCondition(key="city", match=models.MatchValue(value=city)))
        if price:
            conditions.append(models.FieldCondition(key="price", match=models.MatchValue(value=price)))
        if rating:
            conditions.append(models.FieldCondition(key="rating", range=models.Range(gte=rating)))
        # All conditions must hold (logical AND).
        return self.db_client.search(
            collection_name=self.db_collection,
            query_vector=embedding.data[0].embedding,
            limit=top_k,
            query_filter=models.Filter(
                must=conditions
            ),
        )

    def _get_context(self, docs):
        """Format the given documents into a single context string.

        Each document is rendered with ``RAG_CONTEXT_TEMPLATE`` using a 1-based ``id``
        and its ``hotel_name`` and ``description`` payload fields.

        Args:
            docs (list): documents whose ``payload`` dict holds 'hotel_name' and 'description'.

        Returns:
            str: the concatenated rendered entries ('' for an empty list).
        """
        return ''.join(
            RAG_CONTEXT_TEMPLATE.format(id=i, hotel_name=doc.payload['hotel_name'],
                                        description=doc.payload['description'])
            for i, doc in enumerate(docs, 1)
        )

    def _reranker(self, docs, query, top_k):
        """Rerank documents with Cohere against the query and keep the top_k best.

        Args:
            docs (list): candidate documents whose ``payload['description']`` is scored.
            query (str): query text.
            top_k (int): number of documents to keep.

        Returns:
            list: the selected documents, in reranked order ([] if ``docs`` is empty).
        """
        if not docs:
            return []
        texts = [doc.payload['description'] for doc in docs]
        response = self.co.rerank(query=query, documents=texts, top_n=top_k, model=self.rerank_model)
        # cohere>=5 returns a RerankResponse that is not subscriptable; the ranked
        # hits live in `.results` (the legacy 4.x Reranking object also has it).
        hits = getattr(response, 'results', response)
        # Each hit's `.index` points back into the `texts` / `docs` list.
        return [docs[hit.index] for hit in list(hits)[:top_k]]

    def __call__(self, query, top_k=3, city=None, price=None, rating=None):
        """Run the retrieve -> rerank -> format pipeline for a query.

        Args:
            query (str): query text.
            top_k (int): number of documents to include in the context. Defaults to 3.
            city (str): optional city filter.
            price (str): optional price filter.
            rating (float): optional minimum rating filter.

        Returns:
            str: the formatted context for the top_k reranked documents, or
                'There are no such hotels' when the database returns no matches.
        """
        # Over-fetch: at least max_retrieved_docs candidates (and never fewer than
        # top_k) so the reranker has a pool to choose from.
        candidates = self._get_documents(query, top_k=max(self.max_retrieved_docs, top_k),
                                         city=city, price=price, rating=rating)
        if len(candidates) == 0:
            return 'There are no such hotels'
        selected = self._reranker(candidates, query, top_k)
        return self._get_context(selected)