import types
from contextlib import contextmanager
from dataclasses import dataclass, field
from time import monotonic_ns
from typing import Any, Dict, List, NamedTuple, Optional, Tuple

from datasets import Dataset, DatasetDict, load_dataset
from sentence_transformers import losses
from transformers.utils import copy_func

from .data import create_fewshot_splits, create_fewshot_splits_multilabel
from .losses import SupConLoss


SEC_TO_NS_SCALE = 1000000000


DEV_DATASET_TO_METRIC = {
    "sst2": "accuracy",
    "imdb": "accuracy",
    "subj": "accuracy",
    "bbc-news": "accuracy",
    "enron_spam": "accuracy",
    "student-question-categories": "accuracy",
    "TREC-QC": "accuracy",
    "toxic_conversations": "matthews_correlation",
}

TEST_DATASET_TO_METRIC = {
    "emotion": "accuracy",
    "SentEval-CR": "accuracy",
    "sst5": "accuracy",
    "ag_news": "accuracy",
    "enron_spam": "accuracy",
    "amazon_counterfactual_en": "matthews_correlation",
}

MULTILINGUAL_DATASET_TO_METRIC = {
    f"amazon_reviews_multi_{lang}": "mae" for lang in ["en", "de", "es", "fr", "ja", "zh"]
}

LOSS_NAME_TO_CLASS = {
    "CosineSimilarityLoss": losses.CosineSimilarityLoss,
    "ContrastiveLoss": losses.ContrastiveLoss,
    "OnlineContrastiveLoss": losses.OnlineContrastiveLoss,
    "BatchSemiHardTripletLoss": losses.BatchSemiHardTripletLoss,
    "BatchAllTripletLoss": losses.BatchAllTripletLoss,
    "BatchHardTripletLoss": losses.BatchHardTripletLoss,
    "BatchHardSoftMarginTripletLoss": losses.BatchHardSoftMarginTripletLoss,
    "SupConLoss": SupConLoss,
}


def default_hp_space_optuna(trial) -> Dict[str, Any]:
    """Returns the default Optuna search space over the SetFit hyperparameters."""
    from transformers.integrations import is_optuna_available

    assert is_optuna_available(), "This function needs Optuna installed: `pip install optuna`"
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "num_epochs": trial.suggest_int("num_epochs", 1, 5),
        "num_iterations": trial.suggest_categorical("num_iterations", [5, 10, 20]),
        "seed": trial.suggest_int("seed", 1, 40),
        "batch_size": trial.suggest_categorical("batch_size", [4, 8, 16, 32, 64]),
    }


def load_data_splits(
    dataset: str, sample_sizes: List[int], add_data_augmentation: bool = False
) -> Tuple[DatasetDict, Dataset]:
    """Loads a dataset from the Hugging Face Hub and returns the few-shot training splits and the test split."""
    print(f"\n\n\n============== {dataset} ============")
    # Load one of the SetFit training sets from the Hugging Face Hub
    train_split = load_dataset(f"SetFit/{dataset}", split="train")
    train_splits = create_fewshot_splits(train_split, sample_sizes, add_data_augmentation, f"SetFit/{dataset}")
    test_split = load_dataset(f"SetFit/{dataset}", split="test")
    print(f"Test set: {len(test_split)}")
    return train_splits, test_split


def load_data_splits_multilabel(dataset: str, sample_sizes: List[int]) -> Tuple[DatasetDict, Dataset]:
    """Loads a multilabel dataset from the Hugging Face Hub and returns the few-shot training splits and the test split."""
    print(f"\n\n\n============== {dataset} ============")
    # Load one of the SetFit training sets from the Hugging Face Hub
    train_split = load_dataset(f"SetFit/{dataset}", "multilabel", split="train")
    train_splits = create_fewshot_splits_multilabel(train_split, sample_sizes)
    test_split = load_dataset(f"SetFit/{dataset}", "multilabel", split="test")
    print(f"Test set: {len(test_split)}")
    return train_splits, test_split


@dataclass
class Benchmark:
    """
    Performs simple benchmarks of code portions (measures elapsed time).

    Typical usage example:

        bench = Benchmark()
        with bench.track("Foo function"):
            foo()
        with bench.track("Bar function"):
            bar()
        bench.summary()
    """

    out_path: Optional[str] = None
    summary_msg: str = field(default_factory=str)

    def print(self, msg: str) -> None:
        """Prints to system out and optionally to the specified `out_path`."""
        print(msg)

        if self.out_path is not None:
            with open(self.out_path, "a+") as f:
                f.write(msg + "\n")

    @contextmanager
    def track(self, step):
        """Computes the elapsed time for the given code context."""
        start = monotonic_ns()
        yield
        ns = monotonic_ns() - start
        msg = f"\n{'*' * 70}\n'{step}' took {ns / SEC_TO_NS_SCALE:.3f}s ({ns:,}ns)\n{'*' * 70}\n"
        print(msg)
        self.summary_msg += msg + "\n"

    def summary(self) -> None:
        """Prints a summary of all benchmarks performed."""
        self.print(f"\n{'#' * 30}\nBenchmark Summary:\n{'#' * 30}\n\n{self.summary_msg}")


class BestRun(NamedTuple):
    """
    The best run found by a hyperparameter search (see [`~Trainer.hyperparameter_search`]).

    Parameters:
        run_id (`str`):
            The id of the best run.
        objective (`float`):
            The objective that was obtained for this run.
        hyperparameters (`Dict[str, Any]`):
            The hyperparameters picked to get this run.
        backend (`Any`):
            The relevant internal object used for optimization. For optuna this is the `study` object.
    """

    run_id: str
    objective: float
    hyperparameters: Dict[str, Any]
    backend: Any = None


def set_docstring(method, docstring, cls=None):
    """Copies a bound method, overrides its docstring, and rebinds it to `cls` or the method's original instance."""
    copied_function = copy_func(method)
    copied_function.__doc__ = docstring
    return types.MethodType(copied_function, cls or method.__self__)
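

# ---------------------------------------------------------------------------
# Minimal usage sketch, illustrative only and not part of the module's API:
# the dataset name ("sst2", one of the keys in DEV_DATASET_TO_METRIC) and the
# sample sizes below are assumptions. Because this module uses relative
# imports, the demo must be run with its package context (e.g. via
# `python -m <package>.utils`), not as a standalone script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    bench = Benchmark()
    with bench.track("Loading sst2 few-shot splits"):
        # Returns a DatasetDict of few-shot training splits and the full test split.
        train_splits, test_split = load_data_splits("sst2", sample_sizes=[8, 64])
    bench.summary()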