Spaces:
Sleeping
Sleeping
from typing import List, Tuple, Dict, Any | |
import time | |
from tqdm.notebook import tqdm | |
from rich import print | |
from retrieval_evaluation import calc_hit_rate_scores, calc_mrr_scores, record_results, add_params | |
from llama_index.finetuning import EmbeddingQAFinetuneDataset | |
from weaviate_interface import WeaviateClient | |
def _first_rank_by_doc_id(response: List[Dict[str, Any]]) -> Dict[str, int]:
    '''
    Map each doc_id in a ranked retrieval response to the 1-based position of
    its FIRST occurrence. Keeping the first occurrence matters for MRR, which
    is defined on the rank of the first relevant result; a plain dict
    comprehension would silently keep the last occurrence instead.

    Args:
    -----
    response: list of result dicts (presumably one per retrieved object, in
        rank order — confirm against WeaviateClient); each must have 'doc_id'.
    '''
    ranks: Dict[str, int] = {}
    for rank, result in enumerate(response, 1):
        ranks.setdefault(result['doc_id'], rank)
    return ranks


def retrieval_evaluation(dataset: EmbeddingQAFinetuneDataset,
                         class_name: str,
                         retriever: WeaviateClient,
                         retrieve_limit: int=5,
                         chunk_size: int=256,
                         hnsw_config_keys: List[str] | None=None,
                         display_properties: List[str] | None=None,
                         dir_outpath: str='./eval_results',
                         include_miss_info: bool=False,
                         user_def_params: Dict[str, Any] | None=None
                         ) -> (Dict[str, str|int|float]
                               | Tuple[Dict[str, str|int|float], List[Dict[str, Any]]]):
    '''
    Given a dataset and a retriever, evaluate the performance of the retriever.
    Returns a dict of kw and vector hit rates and mrr scores. If include_miss_info
    is True, also returns a list of kw and vector responses and their associated
    queries that did not return a hit, for deeper analysis. A text file with the
    results is saved in the dir_outpath directory.

    Args:
    -----
    dataset: EmbeddingQAFinetuneDataset
        Dataset to be used for evaluation. Only the first entry of
        dataset.relevant_docs[query_id] is treated as the relevant doc.
    class_name: str
        Name of Class on Weaviate host to be used for retrieval
    retriever: WeaviateClient
        WeaviateClient object to be used for retrieval
    retrieve_limit: int=5
        Number of documents to retrieve from Weaviate host
    chunk_size: int=256
        Number of tokens used to chunk text. This value is purely for results
        recording purposes and does not affect results.
    hnsw_config_keys: List[str]=None
        HNSW index config keys passed to add_params for results recording.
        None means ['maxConnections', 'efConstruction', 'ef'].
    display_properties: List[str]=None
        List of properties to be returned from Weaviate host for display in
        response. None means ['doc_id', 'guest', 'content']. Results must
        contain 'doc_id' for scoring, so this presumably must include it.
    dir_outpath: str='./eval_results'
        Directory path for saving results. Directory will be created if it does not
        already exist.
    include_miss_info: bool=False
        Option to include queries and their associated kw and vector response values
        for queries that are "total misses"
    user_def_params: dict=None
        Option for user to pass in a dictionary of user-defined parameters and their values.

    Returns:
    --------
    results_dict, or (results_dict, miss_info) when include_miss_info is True.
    Queries whose retrieval raised an error are counted in 'total_questions'
    but not in hits or 'total_misses'.
    '''
    # Build fresh lists per call: list literals as signature defaults would be
    # shared across calls and could be mutated by add_params or the retriever.
    if hnsw_config_keys is None:
        hnsw_config_keys = ['maxConnections', 'efConstruction', 'ef']
    if display_properties is None:
        display_properties = ['doc_id', 'guest', 'content']

    results_dict = {'n': retrieve_limit,
                    'Retriever': retriever.model_name_or_path,
                    'chunk_size': chunk_size,
                    'kw_hit_rate': 0,
                    'kw_mrr': 0,
                    'vector_hit_rate': 0,
                    'vector_mrr': 0,
                    'total_misses': 0,
                    'total_questions': 0
                    }
    # add hnsw configs and user defined params (if any)
    results_dict = add_params(retriever, class_name, results_dict, user_def_params, hnsw_config_keys)

    start = time.perf_counter()
    miss_info = []
    for query_id, q in tqdm(dataset.queries.items(), 'Queries'):
        results_dict['total_questions'] += 1
        # Only the fallible work (remote calls, key lookups) is guarded, so a
        # failing query can never leave the counters partially updated.
        try:
            # make Keyword and Vector calls to Weaviate host
            kw_response = retriever.keyword_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
            vector_response = retriever.vector_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
            # doc_id -> 1-based rank of its first occurrence, for hit/MRR scoring
            kw_ranks = _first_rank_by_doc_id(kw_response)
            vector_ranks = _first_rank_by_doc_id(vector_response)
            # extract doc_id for scoring purposes
            doc_id = dataset.relevant_docs[query_id][0]
        except Exception as e:
            # best-effort evaluation: report which query failed and move on
            print(f'Error evaluating query {query_id}:', e)
            continue

        hit = False
        # increment hit_rate counters and mrr scores
        if doc_id in kw_ranks:
            results_dict['kw_hit_rate'] += 1
            results_dict['kw_mrr'] += 1/kw_ranks[doc_id]
            hit = True
        if doc_id in vector_ranks:
            results_dict['vector_hit_rate'] += 1
            results_dict['vector_mrr'] += 1/vector_ranks[doc_id]
            hit = True
        # if no hits, let's capture that
        if not hit:
            results_dict['total_misses'] += 1
            miss_info.append({'query': q, 'kw_response': kw_response, 'vector_response': vector_response})

    # use raw counts to calculate final scores (helpers presumably update
    # results_dict in place — their return values are not used)
    calc_hit_rate_scores(results_dict)
    calc_mrr_scores(results_dict)
    elapsed = time.perf_counter() - start
    print(f'Total Processing Time: {round(elapsed/60, 2)} minutes')
    record_results(results_dict, chunk_size, dir_outpath=dir_outpath, as_text=True)
    if include_miss_info:
        return results_dict, miss_info
    return results_dict