Spaces:
Runtime error
Runtime error
File size: 3,613 Bytes
d1a829e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
from openai import OpenAI
import cohere
from qdrant_client import models
from src.prompts import RAG_CONTEXT_TEMPLATE
class Retriever:
    """Retrieve hotel descriptions from a vector database and build a RAG context.

    Retrieval is performed in four steps:

    1. Create an embedding for the query with the OpenAI embeddings API.
    2. Fetch up to ``n`` candidate documents from the database using the query
       embedding together with optional payload filters (mixed retrieval).
    3. Rerank the candidates with Cohere and keep the top ``k`` of them,
       where ``k << n`` (reranking).
    4. Format the selected documents into a single context string.
    """

    def __init__(self, embedding_model, llm_model, rerank_model, db_client,
                 db_collection='hotels', max_retrieved_docs=13):
        """Initialise the retriever.

        Args:
            embedding_model (str): embedding model name passed to the OpenAI
                ``embeddings.create`` call.
            llm_model (str): LLM model name; stored on the instance but not
                used by any method of this class.
            rerank_model (str): Cohere rerank model name.
            db_client: client exposing a Qdrant-style ``search`` method
                (presumably ``qdrant_client.QdrantClient`` -- confirm).
            db_collection (str): collection to search. Defaults to ``'hotels'``.
            max_retrieved_docs (int): minimum number of candidates fetched from
                the database before reranking. Defaults to 13.
        """
        self.db_collection = db_collection
        self.db_client = db_client
        self.rerank_model = rerank_model
        # Both clients are built without explicit credentials; presumably they
        # read API keys from the environment -- confirm deployment config.
        self.openai_client = OpenAI()
        self.co = cohere.Client()
        self.embedding_model = embedding_model
        self.llm_model = llm_model
        self.max_retrieved_docs = max_retrieved_docs

    def _get_documents(self, query, top_k, city, price, rating):
        """Retrieve up to ``top_k`` documents for a query, applying payload filters.

        Each truthy filter argument adds a ``must`` condition; falsy values
        (``None``, ``''``, ``0``) mean "do not filter on this field".

        Args:
            query (str): user query to embed.
            top_k (int): maximum number of documents to retrieve.
            city (str): exact-match filter on the ``city`` payload field.
            price (str): exact-match filter on the ``price`` payload field
                (price range).
            rating (float): inclusive lower bound on the ``rating`` payload field.

        Returns:
            list: the points returned by ``db_client.search``.
        """
        embedding = self.openai_client.embeddings.create(input=query, model=self.embedding_model)

        conditions = []
        if city:
            conditions.append(models.FieldCondition(key="city", match=models.MatchValue(value=city)))
        if price:
            conditions.append(models.FieldCondition(key="price", match=models.MatchValue(value=price)))
        if rating:
            conditions.append(models.FieldCondition(key="rating", range=models.Range(gte=rating)))

        # Send a filter only when at least one condition was requested.
        query_filter = models.Filter(must=conditions) if conditions else None

        # NOTE(review): recent qdrant-client releases deprecate ``search`` in
        # favour of ``query_points``; confirm against the pinned version.
        return self.db_client.search(
            collection_name=self.db_collection,
            query_vector=embedding.data[0].embedding,
            limit=top_k,
            query_filter=query_filter,
        )

    def _get_context(self, docs):
        """Format documents into a single context string.

        Documents are numbered from 1 in the order given. Each one is rendered
        with ``RAG_CONTEXT_TEMPLATE`` using its ``hotel_name`` and
        ``description`` payload fields, and the pieces are concatenated.

        Args:
            docs (list): documents exposing a ``payload`` mapping.

        Returns:
            str: the concatenated context (empty string for no documents).
        """
        return ''.join(
            RAG_CONTEXT_TEMPLATE.format(
                id=i,
                hotel_name=doc.payload['hotel_name'],
                description=doc.payload['description'],
            )
            for i, doc in enumerate(docs, 1)
        )

    def _reranker(self, docs, query, top_k):
        """Rerank documents with Cohere and keep the ``top_k`` most relevant.

        Args:
            docs (list): candidate documents exposing a ``payload`` mapping
                with a ``description`` field.
            query (str): user query.
            top_k (int): number of documents to keep.

        Returns:
            list: up to ``top_k`` documents from ``docs``, in the order of the
            reranker's results. An empty list if ``docs`` is empty.
        """
        if not docs:
            # Nothing to rerank; avoid a pointless (and possibly rejected) API call.
            return []
        texts = [doc.payload['description'] for doc in docs]
        response = self.co.rerank(
            query=query,
            documents=texts,
            top_n=min(top_k, len(texts)),
            model=self.rerank_model,
        )
        # Newer Cohere SDKs return a response object that is not subscriptable;
        # the hits live in ``.results`` (older SDKs expose the same attribute).
        hits = getattr(response, 'results', response)
        return [docs[hit.index] for hit in list(hits)[:top_k]]

    def __call__(self, query, top_k=3, city=None, price=None, rating=None):
        """Run the full retrieval pipeline for a query.

        Args:
            query (str): user query.
            top_k (int): number of documents to keep after reranking.
            city (str): optional exact-match city filter.
            price (str): optional exact-match price-range filter.
            rating (float): optional inclusive minimum rating.

        Returns:
            str: formatted context for the selected hotels, or the message
            ``'There are no such hotels'`` when the database search returns
            no documents.
        """
        # Over-fetch candidates so the reranker has more than ``top_k`` to choose from.
        candidates = self._get_documents(
            query,
            top_k=max(self.max_retrieved_docs, top_k),
            city=city,
            price=price,
            rating=rating,
        )
        if not candidates:
            return 'There are no such hotels'
        selected = self._reranker(candidates, query, top_k)
        return self._get_context(selected)
|