from typing import List, Tuple, Dict, Any

import time
from tqdm.notebook import tqdm
from rich import print

from retrieval_evaluation import calc_hit_rate_scores, calc_mrr_scores, record_results, add_params
from llama_index.finetuning import EmbeddingQAFinetuneDataset
from weaviate_interface import WeaviateClient


def retrieval_evaluation(dataset: EmbeddingQAFinetuneDataset,
                         class_name: str,
                         retriever: WeaviateClient,
                         retrieve_limit: int=5,
                         chunk_size: int=256,
                         hnsw_config_keys: List[str]=['maxConnections', 'efConstruction', 'ef'],
                         display_properties: List[str]=['doc_id', 'guest', 'content'],
                         dir_outpath: str='./eval_results',
                         include_miss_info: bool=False,
                         user_def_params: Dict[str, Any]=None
                         ) -> Dict[str, str | int | float]:
    '''
    Given a dataset and a retriever, evaluate the performance of the retriever.
    Returns a dict of keyword and vector hit rates and MRR scores. If include_miss_info
    is True, also returns a list of keyword and vector responses and their associated
    queries that did not return a hit, for deeper analysis. A text file with the results
    is automatically saved in the dir_outpath directory.

    Args:
    -----
    dataset: EmbeddingQAFinetuneDataset
        Dataset to be used for evaluation
    class_name: str
        Name of the Class on the Weaviate host to be used for retrieval
    retriever: WeaviateClient
        WeaviateClient object to be used for retrieval
    retrieve_limit: int=5
        Number of documents to retrieve from the Weaviate host
    chunk_size: int=256
        Number of tokens used to chunk text. This value is purely for results
        recording purposes and does not affect results.
    hnsw_config_keys: List[str]=['maxConnections', 'efConstruction', 'ef']
        HNSW index configuration keys whose values are recorded alongside the results
    display_properties: List[str]=['doc_id', 'guest', 'content']
        List of properties to be returned from the Weaviate host for display in the response
    dir_outpath: str='./eval_results'
        Directory path for saving results. Directory will be created if it does not already exist.
    include_miss_info: bool=False
        Option to include queries and their associated keyword and vector response values
        for queries that are "total misses"
    user_def_params: Dict[str, Any]=None
        Option for the user to pass in a dictionary of user-defined parameters and their values.
    '''
    results_dict = {'n': retrieve_limit,
                    'Retriever': retriever.model_name_or_path,
                    'chunk_size': chunk_size,
                    'kw_hit_rate': 0,
                    'kw_mrr': 0,
                    'vector_hit_rate': 0,
                    'vector_mrr': 0,
                    'total_misses': 0,
                    'total_questions': 0}

    # add hnsw configs and user-defined params (if any)
    results_dict = add_params(retriever, class_name, results_dict, user_def_params, hnsw_config_keys)

    start = time.perf_counter()
    miss_info = []
    for query_id, q in tqdm(dataset.queries.items(), 'Queries'):
        results_dict['total_questions'] += 1
        hit = False
        # make Keyword and Vector calls to the Weaviate host
        try:
            kw_response = retriever.keyword_search(request=q,
                                                   class_name=class_name,
                                                   limit=retrieve_limit,
                                                   display_properties=display_properties)
            vector_response = retriever.vector_search(request=q,
                                                      class_name=class_name,
                                                      limit=retrieve_limit,
                                                      display_properties=display_properties)

            # collect doc_ids and their (1-indexed) positions to check for document matches
            kw_doc_ids = {result['doc_id']: i for i, result in enumerate(kw_response, 1)}
            vector_doc_ids = {result['doc_id']: i for i, result in enumerate(vector_response, 1)}

            # extract the ground-truth doc_id for scoring purposes
            doc_id = dataset.relevant_docs[query_id][0]

            # increment hit_rate counters and mrr scores
            if doc_id in kw_doc_ids:
                results_dict['kw_hit_rate'] += 1
                results_dict['kw_mrr'] += 1 / kw_doc_ids[doc_id]
                hit = True
            if doc_id in vector_doc_ids:
                results_dict['vector_hit_rate'] += 1
                results_dict['vector_mrr'] += 1 / vector_doc_ids[doc_id]
                hit = True

            # if neither search returned the relevant doc, capture the miss for later analysis
            if not hit:
                results_dict['total_misses'] += 1
                miss_info.append({'query': q,
                                  'kw_response': kw_response,
                                  'vector_response': vector_response})
        except Exception as e:
            print(e)
            continue

    # use raw counts to calculate final scores
    calc_hit_rate_scores(results_dict)
    calc_mrr_scores(results_dict)

    end = time.perf_counter() - start
    print(f'Total Processing Time: {round(end/60, 2)} minutes')
    record_results(results_dict, chunk_size, dir_outpath=dir_outpath, as_text=True)

    if include_miss_info:
        return results_dict, miss_info
    return results_dict
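

# A minimal usage sketch of the evaluation loop above. The WeaviateClient constructor
# arguments, the dataset path, and the class name below are assumptions for illustration
# only; swap in your own host credentials, embedding model, QA dataset, and Weaviate class.
if __name__ == '__main__':
    import os

    # hypothetical connection details (adjust to your Weaviate host and embedding model)
    client = WeaviateClient(api_key=os.environ['WEAVIATE_API_KEY'],
                            endpoint=os.environ['WEAVIATE_ENDPOINT'],
                            model_name_or_path='sentence-transformers/all-MiniLM-L6-v2')

    # load a LlamaIndex-style QA dataset (queries + relevant_docs keyed by query_id)
    dataset = EmbeddingQAFinetuneDataset.from_json('./data/golden_dataset.json')  # assumed path

    # run keyword + vector evaluation against a single class, keeping miss details
    results, misses = retrieval_evaluation(dataset=dataset,
                                           class_name='PodcastChunks256',  # assumed class name
                                           retriever=client,
                                           retrieve_limit=5,
                                           include_miss_info=True)
    print(results)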