from typing import List, Dict

import importlib
import inspect
import os
import shutil
from abc import ABC, abstractmethod

import numpy as np
from datasets import load_metric as hf_load_metric
from huggingface_hub import hf_hub_download

from utils.decoding import decode

class Metric(ABC):
    """Base class for all metrics. Subclasses implement `_compute_metrics`."""

    def __init__(self, **kwargs) -> None:
        super().__init__()
        self._kwargs = kwargs
        # The metric's prefix defaults to the name of the file it is defined in (without extension).
        self.prefix = os.path.splitext(os.path.basename(inspect.getfile(self.__class__)))[0]
        self.requires_decoded = False

    def __call__(self, id_to_pred, id_to_labels, is_decoded=False):
        if self.requires_decoded and not is_decoded:
            id_to_pred = self._decode(id_to_pred)
            id_to_labels = self._decode(id_to_labels)
        return self._compute_metrics(id_to_pred, id_to_labels)

    @abstractmethod
    def _compute_metrics(self, id_to_pred, id_to_labels) -> Dict[str, float]:
        ...

    def _decode(self, id_to_something):
        tokenizer = self._kwargs.get("tokenizer")
        data_args = self._kwargs.get("data_args")
        return decode(id_to_something, tokenizer, data_args)

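# Illustrative sketch (not part of the original pipeline): a project-local metric would live in its
# own file, e.g. a hypothetical src/metrics/exact_match.py, and subclass Metric as below. The defining
# file's name becomes `prefix`, and `requires_decoded = True` makes the collection pass decoded strings
# instead of token ids.
#
# class ExactMatch(Metric):
#     def __init__(self, **kwargs) -> None:
#         super().__init__(**kwargs)
#         self.requires_decoded = True
#
#     def _compute_metrics(self, id_to_pred, id_to_labels) -> Dict[str, float]:
#         scores = []
#         for id_, pred in id_to_pred.items():
#             labels = id_to_labels[id_]
#             labels = labels if isinstance(labels, list) else [labels]
#             scores.append(float(any(pred.strip() == label.strip() for label in labels)))
#         return {"exact_match": 100.0 * float(np.mean(scores))}
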
class MetricCollection(Metric):
    """Runs several metrics and merges their results under per-metric prefixes."""

    def __init__(self, metrics: List[Metric], **kwargs):
        super().__init__(**kwargs)
        self._metrics = metrics

    def __call__(self, id_to_pred, id_to_labels):
        return self._compute_metrics(id_to_pred, id_to_labels)

    def _compute_metrics(self, id_to_pred, id_to_labels):
        results = {}
        id_to_pred_decoded = None
        id_to_labels_decoded = None
        for metric in self._metrics:
            metric_prefix = f"{metric.prefix}/" if metric.prefix else ""
            if metric.requires_decoded:
                # Decode predictions/labels at most once and share them across metrics.
                if id_to_pred_decoded is None:
                    id_to_pred_decoded = self._decode(id_to_pred)
                if id_to_labels_decoded is None:
                    id_to_labels_decoded = self._decode(id_to_labels)
                result = metric(id_to_pred_decoded, id_to_labels_decoded, is_decoded=True)
            else:
                result = metric(id_to_pred, id_to_labels)
            results.update(
                {
                    f"{metric_prefix}{k}": np.mean(v) if isinstance(v, list) else v
                    for k, v in result.items()
                    if not isinstance(v, str)
                }
            )

        results["num_predicted"] = len(id_to_pred)
        if id_to_pred_decoded is None:
            # No metric required decoding, so decode here for the length statistics.
            id_to_pred_decoded = self._decode(id_to_pred)
        results["mean_prediction_length_characters"] = np.mean([len(pred) for pred in id_to_pred_decoded.values()])

        elem = next(iter(id_to_pred.values()))
        if not ((isinstance(elem, list) and isinstance(elem[0], str)) or isinstance(elem, str)):
            # Predictions are token ids: also report the mean length in non-padding tokens (includes BOS/EOS).
            tokenizer = self._kwargs["tokenizer"]
            results["mean_prediction_length_tokens"] = np.mean(
                [np.count_nonzero(np.array(pred) != tokenizer.pad_token_id) for pred in id_to_pred.values()]
            )

        results = {key: round(value, 4) for key, value in results.items()}
        return results

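# Example of the merged output (hypothetical values), assuming a collection holding a custom "exact_match"
# metric and an HFMetricWrapper("bertscore"): keys are namespaced by each metric's prefix, and the
# collection adds its own length statistics.
#
#     {"exact_match/exact_match": 41.2, "bertscore/f1": 0.873,
#      "num_predicted": 500, "mean_prediction_length_characters": 212.74}
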
def load_metric(paths: List[str], **kwargs):
    if paths is None or len(paths) == 0:
        return None
    if isinstance(paths, str):
        paths = [paths]
    else:
        paths = list(paths)

    metric_cls_list = []

    # Metric names that are not local files are treated as SCROLLS custom metrics
    # and handled together by a single ScrollsWrapper.
    scrolls_custom_metrics = []
    to_remove = []
    for i, path in enumerate(paths):
        if not os.path.isfile(path):
            scrolls_custom_metrics.append(path)
            to_remove.append(i)
    for i in sorted(to_remove, reverse=True):
        del paths[i]
    if len(scrolls_custom_metrics) > 0:
        scrolls_custom_metrics.insert(0, "")  # In order to have an identifying comma in the beginning
        metric_cls_list.append(ScrollsWrapper(",".join(scrolls_custom_metrics), **kwargs))

    # Remaining entries are local metric scripts; import each one and instantiate its Metric subclass.
    for path in paths:
        path = path.strip()
        if len(path) == 0:
            continue
        if not os.path.isfile(path):
            path = os.path.join("src", "metrics", f"{path}.py")
        module = path[:-3].replace(os.sep, ".")
        metric_cls = import_main_class(module)
        metric_cls_list.append(metric_cls(**kwargs))

    return MetricCollection(metric_cls_list, **kwargs)

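# Hypothetical usage sketch (names are placeholders): entries that are not existing files are forwarded
# to the SCROLLS metric wrapper as a comma-separated list, while entries that are paths to local scripts
# are imported via import_main_class. `tokenizer` and `data_args` are passed through to every metric so
# it can decode token ids when needed.
#
#     metric = load_metric(["some_scrolls_metric", "src/metrics/exact_match.py"],
#                          tokenizer=tokenizer, data_args=data_args)
#     results = metric(id_to_pred, id_to_labels)
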
# Modified from datasets.load
def import_main_class(module_path):
    """Import a module at module_path and return its main class."""
    module = importlib.import_module(module_path)
    main_cls_type = Metric

    # Find the main class in our imported module
    module_main_cls = None
    for name, obj in module.__dict__.items():
        if isinstance(obj, type) and issubclass(obj, main_cls_type):
            if inspect.isabstract(obj):
                continue
            module_main_cls = obj
            break

    return module_main_cls

class ScrollsWrapper(Metric):
    """Wraps the SCROLLS metric script, which receives its metric names as a comma-separated config string."""

    def __init__(self, comma_separated_metric_names, **kwargs) -> None:
        super().__init__(**kwargs)
        self.prefix = None
        self._metric = hf_load_metric(download_metric(), comma_separated_metric_names, keep_in_memory=True)
        self.requires_decoded = True

    def _compute_metrics(self, id_to_pred, id_to_labels) -> Dict[str, float]:
        return self._metric.compute(**self._metric.convert_from_map_format(id_to_pred, id_to_labels))

class HFMetricWrapper(Metric):
    """Wraps a metric loaded by name from the Hugging Face `datasets` hub."""

    metric_specific_kwargs = {
        "bertscore": {
            # 'model_type': 'microsoft/deberta-large-mnli' or the larger 'microsoft/deberta-xlarge-mnli'
            "model_type": "facebook/bart-large-mnli",  # has a context window of 1024
            "num_layers": 11,  # according to: https://docs.google.com/spreadsheets/d/1RKOVpselB98Nnh_EOC4A2BYn8_201tmPODpNWu4w7xI/edit#gid=0
        }
    }

    def __init__(self, metric_name, **kwargs) -> None:
        super().__init__(**kwargs)
        self._metric = hf_load_metric(metric_name)
        self.kwargs = HFMetricWrapper.metric_specific_kwargs.get(metric_name, {})
        self.prefix = metric_name
        self.requires_decoded = True

    def _compute_metrics(self, id_to_pred, id_to_labels) -> Dict[str, float]:
        return self._metric.compute(**self.convert_from_map_format(id_to_pred, id_to_labels), **self.kwargs)

    def convert_from_map_format(self, id_to_pred, id_to_labels):
        index_to_id = list(id_to_pred.keys())
        predictions = [id_to_pred[id_] for id_ in index_to_id]
        references = [id_to_labels[id_] for id_ in index_to_id]
        return {"predictions": predictions, "references": references}

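# Illustration of convert_from_map_format with hypothetical inputs: the id-keyed dictionaries are
# turned into the aligned `predictions` / `references` lists expected by `datasets` metrics.
#
#     id_to_pred   = {"ex1": "a cat", "ex2": "a dog"}
#     id_to_labels = {"ex1": ["a cat"], "ex2": ["a puppy"]}
#     # -> {"predictions": ["a cat", "a dog"], "references": [["a cat"], ["a puppy"]]}
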
def download_metric():
    # Download the SCROLLS metric script and copy it under a unique file name,
    # so that `datasets.load_metric` can load it as a local metric script.
    scrolls_metric_path = hf_hub_download(repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type="dataset")
    updated_scrolls_metric_path = (
        os.path.dirname(scrolls_metric_path) + os.path.basename(scrolls_metric_path).replace(".", "_") + ".py"
    )
    shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
    return updated_scrolls_metric_path