# vectorsearch/helpers.py
# Source: Hugging Face file view (author: JPBianchi; commit 30ffb9e, "temp before HF pull"; 5.17 kB; scanned: no virus)
from typing import List, Tuple, Dict, Any
import time
from tqdm.notebook import tqdm
from rich import print
from retrieval_evaluation import calc_hit_rate_scores, calc_mrr_scores, record_results, add_params
from llama_index.finetuning import EmbeddingQAFinetuneDataset
from weaviate_interface import WeaviateClient
def _first_rank_by_doc_id(response: List[Dict[str, Any]]) -> Dict[str, int]:
    '''
    Map each doc_id in a search response to the 1-based rank of its FIRST occurrence.

    Each result is indexed by result['doc_id'] (assumed: the response is a list
    of result dicts that include 'doc_id'). setdefault keeps the earliest rank
    when a doc_id repeats. Reciprocal rank is defined by the first relevant hit,
    so a later duplicate must not overwrite it.
    '''
    ranks: Dict[str, int] = {}
    for rank, result in enumerate(response, 1):
        ranks.setdefault(result['doc_id'], rank)
    return ranks


def retrieval_evaluation(dataset: EmbeddingQAFinetuneDataset,
                         class_name: str,
                         retriever: WeaviateClient,
                         retrieve_limit: int=5,
                         chunk_size: int=256,
                         hnsw_config_keys: List[str] | None=None,
                         display_properties: List[str] | None=None,
                         dir_outpath: str='./eval_results',
                         include_miss_info: bool=False,
                         user_def_params: Dict[str, Any] | None=None
                         ) -> Dict[str, str|int|float] | Tuple[Dict[str, str|int|float], List[Dict[str, Any]]]:
    '''
    Given a dataset and a retriever, evaluate the performance of the retriever.

    For every query, runs a keyword search and a vector search against
    class_name. It then checks whether the query's first relevant doc_id appears
    in each response, accumulating hit counts and reciprocal ranks. The raw
    counts are turned into final kw/vector hit-rate and MRR scores. A text file
    with the results is automatically saved in the dir_outpath directory.
    If include_miss_info is True, the kw and vector responses (and their
    queries) for queries that got no hit at all are also returned, for deeper
    analysis.

    Args:
    -----
    dataset: EmbeddingQAFinetuneDataset
        Dataset to be used for evaluation. Its .queries maps query_id -> query
        text, and only the first entry of .relevant_docs[query_id] is scored.
    class_name: str
        Name of Class on Weaviate host to be used for retrieval
    retriever: WeaviateClient
        WeaviateClient object to be used for retrieval
    retrieve_limit: int=5
        Number of documents to retrieve from Weaviate host
    chunk_size: int=256
        Number of tokens used to chunk text. This value is purely for results
        recording purposes and does not affect results.
    hnsw_config_keys: List[str]=None
        HNSW config keys passed to add_params (presumably to record those
        index settings). Defaults to ['maxConnections', 'efConstruction', 'ef'].
    display_properties: List[str]=None
        List of properties to be returned from Weaviate host for display in
        response. Results are indexed by 'doc_id', so it should be included.
        Defaults to ['doc_id', 'guest', 'content'].
    dir_outpath: str='./eval_results'
        Directory path for saving results. Directory will be created if it does not
        already exist.
    include_miss_info: bool=False
        Option to include queries and their associated kw and vector response values
        for queries that are "total misses" (no hit in either response).
    user_def_params : dict=None
        Option for user to pass in a dictionary of user-defined parameters and their values.

    Returns:
    --------
    results_dict, or (results_dict, miss_info) when include_miss_info is True.
    miss_info is a list of {'query', 'kw_response', 'vector_response'} dicts.
    '''
    # None sentinels instead of list defaults: a shared default list would leak
    # any mutation made by add_params or the search calls into later calls.
    if hnsw_config_keys is None:
        hnsw_config_keys = ['maxConnections', 'efConstruction', 'ef']
    if display_properties is None:
        display_properties = ['doc_id', 'guest', 'content']

    results_dict = {'n': retrieve_limit,
                    'Retriever': retriever.model_name_or_path,
                    'chunk_size': chunk_size,
                    'kw_hit_rate': 0,
                    'kw_mrr': 0,
                    'vector_hit_rate': 0,
                    'vector_mrr': 0,
                    'total_misses': 0,
                    'total_questions': 0
                    }
    # add hnsw configs and user defined params (if any)
    results_dict = add_params(retriever, class_name, results_dict, user_def_params, hnsw_config_keys)

    start = time.perf_counter()
    miss_info: List[Dict[str, Any]] = []
    for query_id, q in tqdm(dataset.queries.items(), desc='Queries'):
        # Incremented before the searches run. A query whose search raises is
        # still counted here, but in neither the hit nor the miss counters.
        results_dict['total_questions'] += 1

        # Only the remote calls and response/dataset lookups can raise.
        # Scoring happens in `else`.
        try:
            kw_response = retriever.keyword_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
            vector_response = retriever.vector_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
            # 1-based position of each doc_id, used for reciprocal rank
            kw_ranks = _first_rank_by_doc_id(kw_response)
            vector_ranks = _first_rank_by_doc_id(vector_response)
            # only the first relevant doc is used for scoring
            doc_id = dataset.relevant_docs[query_id][0]
        except Exception as e:
            print(f'Query {query_id!r} failed and was skipped: {e}')
            continue
        else:
            hit = False
            if doc_id in kw_ranks:
                results_dict['kw_hit_rate'] += 1
                results_dict['kw_mrr'] += 1 / kw_ranks[doc_id]
                hit = True
            if doc_id in vector_ranks:
                results_dict['vector_hit_rate'] += 1
                results_dict['vector_mrr'] += 1 / vector_ranks[doc_id]
                hit = True
            # neither retriever found the relevant doc: record it for analysis
            if not hit:
                results_dict['total_misses'] += 1
                miss_info.append({'query': q, 'kw_response': kw_response, 'vector_response': vector_response})

    # use raw counts to calculate final scores (helpers presumably normalize
    # by total_questions in place)
    calc_hit_rate_scores(results_dict)
    calc_mrr_scores(results_dict)
    elapsed = time.perf_counter() - start
    print(f'Total Processing Time: {round(elapsed/60, 2)} minutes')
    record_results(results_dict, chunk_size, dir_outpath=dir_outpath, as_text=True)
    if include_miss_info:
        return results_dict, miss_info
    return results_dict