import json

import pandas as pd
from transformers import pipeline


class Distilbert:
    """Text-classification wrapper around a DistilBERT checkpoint.

    Raw pipeline labels of the form ``'LABEL_<n>'`` are mapped back to
    human-readable names read from a JSON label file (``labels.json`` by
    default).
    """

    # Class-level defaults so instances built without ``__init__`` (for
    # example via ``__new__``, or a subclass that skips ``super().__init__``)
    # still resolve the label file and the cache attribute.
    labels_path = 'labels.json'
    _label_names_cache = None

    def __init__(self, model_path: str, labels_path: str = 'labels.json'):
        """Create the text-classification pipeline.

        Parameters:
            model_path (str): Model identifier or local path handed to
                ``transformers.pipeline``.
            labels_path (str): Path to the JSON label file. The default,
                ``'labels.json'``, is resolved against the current working
                directory, which matches the previous hard-coded behaviour.
        """
        self.pipe = pipeline("text-classification", model=model_path)
        self.labels_path = labels_path
        self._label_names_cache = None

    def predict_query_type(self, custom_query: str) -> str:
        """Predict the type of a given custom query using the classifier.

        Parameters:
            custom_query (str): The input query for which the type needs to
                be predicted.

        Returns:
            str: The label name (a key of the label file) for the
            highest-scoring class. On a score tie, the lowest label id wins.

        Raises:
            ValueError: If the classifier returns no predictions, or a
                pipeline label cannot be parsed as ``'LABEL_<int>'``.
            IndexError: If the winning label id has no entry in the label file.
        """
        # top_k=None asks the pipeline for a score for every class.
        preds = self.pipe(custom_query, top_k=None)
        label_names = self._label_names()

        scored = [(self._parse_label_id(pred['label']), pred['score'])
                  for pred in preds]
        if not scored:
            raise ValueError("Classifier returned no predictions.")

        # Highest score wins; on ties prefer the smallest label id, which is
        # what the previous sort-by-label-then-idxmax approach produced.
        best_id, _ = max(scored, key=lambda item: (item[1], -item[0]))

        # Index by the parsed label id itself rather than by sorted position,
        # so non-contiguous ids cannot silently select the wrong name.
        if not 0 <= best_id < len(label_names):
            raise IndexError(
                f"Predicted label id {best_id} has no entry in "
                f"{self.labels_path!r} ({len(label_names)} labels loaded)."
            )
        return label_names[best_id]

    @staticmethod
    def _parse_label_id(raw_label) -> int:
        """Convert a pipeline label such as ``'LABEL_3'`` into the int ``3``."""
        text = str(raw_label)
        if text.startswith('LABEL_'):
            text = text[len('LABEL_'):]
        try:
            return int(text)
        except ValueError:
            raise ValueError(
                f"Cannot map pipeline label {raw_label!r} to an integer id; "
                "expected the form 'LABEL_<n>'."
            ) from None

    def _label_names(self) -> list:
        """Return the label names in file order, reading the file once per instance."""
        if self._label_names_cache is None:
            # NOTE(review): assumes the label file lists its keys in id order,
            # i.e. the key at position i names LABEL_i. Confirm against
            # whatever generates labels.json.
            self._label_names_cache = list(self.load_lables().keys())
        return self._label_names_cache

    def load_lables(self) -> dict:
        """Read the JSON label file and return its parsed contents.

        The keys are used as label names. The values are presumably the
        integer class ids (e.g. ``{"greeting": 0, ...}``), but they are not
        read by this class; verify against the file's producer.

        Returns:
            dict: The parsed JSON object from ``self.labels_path``.
        """
        with open(self.labels_path, encoding='utf-8') as json_file:
            return json.load(json_file)

    # Correctly spelled alias; ``load_lables`` is kept for existing callers.
    load_labels = load_lables