import collections
import random
import statistics
from collections import Counter, defaultdict
from dataclasses import dataclass, field, fields
from pathlib import Path
from platform import python_version
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import datasets
import tokenizers
import torch
import transformers
from datasets import Dataset
from huggingface_hub import CardData, DatasetFilter, ModelCard, dataset_info, list_datasets, model_info
from huggingface_hub.repocard_data import EvalResult, eval_results_to_model_index
from huggingface_hub.utils import yaml_dump
from sentence_transformers import __version__ as sentence_transformers_version
from transformers import PretrainedConfig, TrainerCallback
from transformers.integrations import CodeCarbonCallback
from transformers.modelcard import make_markdown_table
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.training_args import TrainingArguments

from setfit import __version__ as setfit_version

from . import logging


logger = logging.get_logger(__name__)

if TYPE_CHECKING:
    from setfit.modeling import SetFitModel
    from setfit.trainer import Trainer


class ModelCardCallback(TrainerCallback):
    """Trainer callback that fills in a model's `SetFitModelCardData` while training runs.

    On init it records dataset information, when training begins it stores the training
    arguments as hyperparameters, and on every log/evaluation it records the training and
    validation losses for the model card's training-results table.
    """

    def __init__(self, trainer: "Trainer") -> None:
        super().__init__()
        self.trainer = trainer

        # Share an existing CodeCarbon callback with the card data so the emissions it
        # tracks can be reported in the model card.
        callbacks = [
            callback
            for callback in self.trainer.callback_handler.callbacks
            if isinstance(callback, CodeCarbonCallback)
        ]
        if callbacks:
            trainer.model.model_card_data.code_carbon_callback = callbacks[0]

    def on_init_end(
        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: "SetFitModel", **kwargs
    ):
        if not model.model_card_data.dataset_id:
            # Inferring is hacky - it may break in the future, so let's be safe
            try:
                model.model_card_data.infer_dataset_id(self.trainer.train_dataset)
            except Exception:
                pass

        dataset = self.trainer.eval_dataset or self.trainer.train_dataset
        if dataset is not None:
            if not model.model_card_data.widget:
                model.model_card_data.set_widget_examples(dataset)

        if self.trainer.train_dataset:
            model.model_card_data.set_train_set_metrics(self.trainer.train_dataset)
            # Does not work for multilabel: list-valued labels are unhashable, hence the best-effort try
            try:
                model.model_card_data.num_classes = len(set(self.trainer.train_dataset["label"]))
                model.model_card_data.set_label_examples(self.trainer.train_dataset)
            except Exception:
                pass

    def on_train_begin(
        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: "SetFitModel", **kwargs
    ) -> None:
        # Arguments that only affect logging/saving/evaluation bookkeeping, not the trained model.
        ignore_keys = {
            "output_dir",
            "logging_dir",
            "logging_strategy",
            "logging_first_step",
            "logging_steps",
            "evaluation_strategy",
            "eval_steps",
            "eval_delay",
            "save_strategy",
            "save_steps",
            "save_total_limit",
            "metric_for_best_model",
            "greater_is_better",
            "report_to",
            "samples_per_label",
            "show_progress_bar",
        }
        # These arguments hold callables/classes; show their name rather than their repr.
        get_name_keys = {"loss", "distance_metric"}
        args_dict = args.to_dict()
        model.model_card_data.hyperparameters = {
            # getattr fallback: e.g. a functools.partial has no __name__ and must not crash training
            key: getattr(value, "__name__", value) if key in get_name_keys else value
            for key, value in args_dict.items()
            if key not in ignore_keys and value is not None
        }

    @staticmethod
    def _record_loss(model: "SetFitModel", state: TrainerState, column: str, value: float) -> None:
        """Write `value` into `column` of the eval line for the current step, creating the line if needed."""
        eval_lines = model.model_card_data.eval_lines_list
        if eval_lines and eval_lines[-1]["Step"] == state.global_step:
            eval_lines[-1][column] = value
        else:
            line = {
                "Epoch": state.epoch,
                "Step": state.global_step,
                "Training Loss": "-",
                "Validation Loss": "-",
            }
            line[column] = value
            eval_lines.append(line)

    def on_evaluate(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model: "SetFitModel",
        metrics: Dict[str, float],
        **kwargs,
    ) -> None:
        if "eval_embedding_loss" in metrics:
            loss_key = "eval_embedding_loss"
        else:
            # Accept any other embedding-loss key rather than raising a KeyError mid-training.
            loss_key = next((key for key in metrics if key.endswith("embedding_loss")), None)
            if loss_key is None:
                return
        self._record_loss(model, state, "Validation Loss", metrics[loss_key])

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model: "SetFitModel",
        logs: Dict[str, float],
        **kwargs,
    ):
        keys = {"embedding_loss", "polarity_embedding_loss", "aspect_embedding_loss"} & set(logs)
        if keys:
            self._record_loss(model, state, "Training Loss", logs[keys.pop()])


# Keys of `SetFitModelCardData.to_dict()` that are written to the model card's YAML metadata.
YAML_FIELDS = [
    "language",
    "license",
    "library_name",
    "tags",
    "datasets",
    "metrics",
    "pipeline_tag",
    "widget",
    "model-index",
    "co2_eq_emissions",
    "base_model",
    "inference",
]
# Keys removed from `SetFitModelCardData.to_dict()` entirely.
IGNORED_FIELDS = ["model"]


@dataclass
class SetFitModelCardData(CardData):
    """A dataclass storing data used in the model card.

    Args:
        language (`Optional[Union[str, List[str]]]`): The model language, either a string or a list,
            e.g. "en" or ["en", "de", "nl"]
        license (`Optional[str]`): The license of the model, e.g. "apache-2.0", "mit",
            or "cc-by-nc-sa-4.0"
        tags (`Optional[List[str]]`): Tags for the model card metadata. Defaults to the SetFit tags.
        model_name (`Optional[str]`): The pretty name of the model, e.g. "SetFit with mBERT-base on SST2".
            If not defined, uses st_id and dataset_name/dataset_id to generate a model name.
        model_id (`Optional[str]`): The model ID when pushing the model to the Hub,
            e.g. "tomaarsen/setfit-paraphrase-mpnet-base-v2-sst2".
        dataset_name (`Optional[str]`): The pretty name of the dataset, e.g. "SST2".
        dataset_id (`Optional[str]`): The dataset ID of the dataset, e.g. "dair-ai/emotion".
        dataset_revision (`Optional[str]`): The dataset revision/commit that was used for training/evaluation.
        task_name (`Optional[str]`): The name of the task.
        st_id (`Optional[str]`): The Sentence Transformers model ID.

    <Tip>

    Install [``codecarbon``](https://github.com/mlco2/codecarbon) to automatically track carbon emission usage and
    include it in your model cards.

    </Tip>

    Example::

        >>> model = SetFitModel.from_pretrained(
        ...     "sentence-transformers/paraphrase-mpnet-base-v2",
        ...     labels=["negative", "positive"],
        ...     # Model card variables
        ...     model_card_data=SetFitModelCardData(
        ...         model_id="tomaarsen/setfit-paraphrase-mpnet-base-v2-sst2",
        ...         dataset_name="SST2",
        ...         dataset_id="sst2",
        ...         license="apache-2.0",
        ...         language="en",
        ...     ),
        ... )
    """

    # Potentially provided by the user
    language: Optional[Union[str, List[str]]] = None
    license: Optional[str] = None
    tags: Optional[List[str]] = field(
        default_factory=lambda: [
            "setfit",
            "sentence-transformers",
            "text-classification",
            "generated_from_setfit_trainer",
        ]
    )
    model_name: Optional[str] = None
    model_id: Optional[str] = None
    dataset_name: Optional[str] = None
    dataset_id: Optional[str] = None
    dataset_revision: Optional[str] = None
    task_name: Optional[str] = None
    st_id: Optional[str] = None

    # Automatically filled by `ModelCardCallback` and the Trainer directly
    hyperparameters: Dict[str, Any] = field(default_factory=dict, init=False)
    eval_results_dict: Optional[Dict[str, Any]] = field(default_factory=dict, init=False)
    eval_lines_list: List[Dict[str, float]] = field(default_factory=list, init=False)
    metric_lines: List[Dict[str, float]] = field(default_factory=list, init=False)
    widget: List[Dict[str, str]] = field(default_factory=list, init=False)
    predict_example: Optional[str] = field(default=None, init=False)
    label_example_list: List[Dict[str, str]] = field(default_factory=list, init=False)
    tokenizer_warning: bool = field(default=False, init=False)
    train_set_metrics_list: List[Dict[str, str]] = field(default_factory=list, init=False)
    train_set_sentences_per_label_list: List[Dict[str, str]] = field(default_factory=list, init=False)
    code_carbon_callback: Optional[CodeCarbonCallback] = field(default=None, init=False)
    num_classes: Optional[int] = field(default=None, init=False)
    best_model_step: Optional[int] = field(default=None, init=False)
    metrics: List[str] = field(default_factory=lambda: ["accuracy"], init=False)

    # Computed once, always unchanged
    pipeline_tag: str = field(default="text-classification", init=False)
    library_name: str = field(default="setfit", init=False)
    version: Dict[str, str] = field(
        default_factory=lambda: {
            "python": python_version(),
            "setfit": setfit_version,
            "sentence_transformers": sentence_transformers_version,
            "transformers": transformers.__version__,
            "torch": torch.__version__,
            "datasets": datasets.__version__,
            "tokenizers": tokenizers.__version__,
        },
        init=False,
    )

    # ABSA-related arguments; merged into the top level of `to_dict()` when set
    absa: Optional[Dict[str, Any]] = field(default=None, init=False, repr=False)

    # Passed via `register_model` only
    model: Optional["SetFitModel"] = field(default=None, init=False, repr=False)
    head_class: Optional[str] = field(default=None, init=False, repr=False)
    inference: Optional[bool] = field(default=True, init=False, repr=False)

    def __post_init__(self):
        # Validate the user-provided Hub IDs, and fill in the language from the dataset card if possible
        if self.dataset_id:
            if is_on_huggingface(self.dataset_id, is_model=False):
                if self.language is None:
                    # if languages are not set, try to determine the language from the dataset on the Hub
                    try:
                        info = dataset_info(self.dataset_id)
                    except Exception:
                        pass
                    else:
                        if info.cardData:
                            self.language = info.cardData.get("language", self.language)
            else:
                logger.warning(
                    f"The provided {self.dataset_id!r} dataset could not be found on the Hugging Face Hub."
                    " Setting `dataset_id` to None."
                )
                self.dataset_id = None

        if self.model_id and self.model_id.count("/") != 1:
            logger.warning(
                f"The provided {self.model_id!r} model ID should include the organization or user,"
                ' such as "tomaarsen/setfit-bge-small-v1.5-sst2-8-shot". Setting `model_id` to None.'
            )
            self.model_id = None

    def set_best_model_step(self, step: int) -> None:
        self.best_model_step = step

    def set_widget_examples(self, dataset: Dataset) -> None:
        """Use up to 5 random texts from `dataset` as widget examples; the shortest becomes the predict example."""
        samples = dataset.select(random.sample(range(len(dataset)), k=min(len(dataset), 5)))["text"]
        self.widget = [{"text": sample} for sample in samples]

        samples.sort(key=len)
        if samples:
            self.predict_example = samples[0]

    def set_train_set_metrics(self, dataset: Dataset) -> None:
        """Compute word-count statistics and, for single-label data, the number of samples per label."""
        if len(dataset) == 0:
            # min/max/median are undefined for an empty training set
            return

        # Naive word count: the number of single-space separated chunks in each text
        word_counts = [len(text.split(" ")) for text in dataset["text"]]
        self.train_set_metrics_list = [
            {
                "Training set": "Word count",
                "Min": min(word_counts),
                "Median": statistics.median(word_counts),
                "Max": max(word_counts),
            },
        ]

        # E.g. if unlabeled via DistillationTrainer
        if "label" not in dataset.column_names:
            return

        # Multilabel (sequence-valued) labels cannot be counted per label
        sample_label = dataset[0]["label"]
        if isinstance(sample_label, collections.abc.Sequence) and not isinstance(sample_label, str):
            return

        try:
            counter = Counter(dataset["label"])
            if self.model.labels:
                self.train_set_sentences_per_label_list = [
                    {
                        "Label": str_label,
                        "Training Sample Count": counter[
                            str_label if isinstance(sample_label, str) else self.model.label2id[str_label]
                        ],
                    }
                    for str_label in self.model.labels
                ]
            else:
                self.train_set_sentences_per_label_list = [
                    {
                        "Label": self.model.labels[label]
                        if self.model.labels and isinstance(label, int)
                        else str(label),
                        "Training Sample Count": count,
                    }
                    for label, count in sorted(counter.items())
                ]
        except Exception:
            # There are some tricky edge cases possible, e.g. if the user provided integer labels that do not fall
            # between 0 to num_classes-1, so we make sure we never cause errors.
            pass

    def set_label_examples(self, dataset: Dataset) -> None:
        """Collect up to 3 example texts per label, formatted as an HTML list for the model card table."""
        num_examples_per_label = 3
        examples = defaultdict(list)
        finished_labels = set()
        for sample in dataset:
            text = sample["text"]
            label = sample["label"]
            if label not in finished_labels:
                examples[label].append(f"<li>{repr(text)}</li>")
                if len(examples[label]) >= num_examples_per_label:
                    finished_labels.add(label)
            # Stop early once every class has enough examples
            if len(finished_labels) == self.num_classes:
                break
        self.label_example_list = [
            {
                "Label": self.model.labels[label] if self.model.labels and isinstance(label, int) else label,
                "Examples": "<ul>" + "".join(example_set) + "</ul>",
            }
            for label, example_set in examples.items()
        ]

    def infer_dataset_id(self, dataset: Dataset) -> None:
        """Best-effort guess of the Hub dataset ID from the location of the dataset's cache files."""

        def subtuple_finder(haystack: Tuple[str], subtuple: Tuple[str]) -> int:
            # Index of the first occurrence of `subtuple` inside `haystack`, or -1
            for i, element in enumerate(haystack):
                if element == subtuple[0] and haystack[i : i + len(subtuple)] == subtuple:
                    return i
            return -1

        def normalize(dataset_id: str) -> str:
            for token in "/\\_-":
                dataset_id = dataset_id.replace(token, "")
            return dataset_id.lower()

        cache_files = dataset.cache_files
        if cache_files and "filename" in cache_files[0]:
            cache_path_parts = Path(cache_files[0]["filename"]).parts
            # Check if the cachefile is under "huggingface/datasets"
            subtuple = ("huggingface", "datasets")
            index = subtuple_finder(cache_path_parts, subtuple)
            if index == -1:
                return

            # Get the folder after "huggingface/datasets"
            cache_dataset_name = cache_path_parts[index + len(subtuple)]
            # If the dataset has an author, the cache folder is "<author>___<dataset_name>"
            if "___" in cache_dataset_name:
                author, dataset_name = cache_dataset_name.split("___", maxsplit=1)
            else:
                author = None
                dataset_name = cache_dataset_name

            # Make sure the normalized dataset IDs match
            dataset_list = [
                candidate
                for candidate in list_datasets(filter=DatasetFilter(author=author, dataset_name=dataset_name))
                if normalize(candidate.id) == normalize(cache_dataset_name)
            ]
            # If there's only one match, get the ID from it
            if len(dataset_list) == 1:
                self.dataset_id = dataset_list[0].id

    def register_model(self, model: "SetFitModel") -> None:
        """Attach `model` to this card data and derive the head link, model name and inference flag from it."""
        self.model = model

        head_class = model.model_head.__class__.__name__
        self.head_class = {
            "LogisticRegression": "[LogisticRegression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html)",
            "SetFitHead": "[SetFitHead](https://huggingface.co/docs/setfit/reference/main#setfit.SetFitHead)",
        }.get(head_class, head_class)

        if not self.model_name:
            if self.st_id:
                self.model_name = f"SetFit with {self.st_id}"
                if self.dataset_name or self.dataset_id:
                    self.model_name += f" on {self.dataset_name or self.dataset_id}"
            else:
                self.model_name = "SetFit"

        # Only enable the hosted inference widget when no multi-target strategy is used
        self.inference = self.model.multi_target_strategy is None

    def infer_st_id(self, setfit_model_id: str) -> None:
        """Guess the Sentence Transformers Hub ID from the `_name_or_path` in the model's config."""
        config_dict, _ = PretrainedConfig.get_config_dict(setfit_model_id)
        st_id = config_dict.get("_name_or_path")
        if not st_id:
            # Nothing to infer from; Path(None) would raise a TypeError
            return
        st_id_path = Path(st_id)
        # Sometimes the name_or_path ends exactly with the model_id, e.g.
        # "C:\\Users\\tom/.cache\\torch\\sentence_transformers\\BAAI_bge-small-en-v1.5\\"
        candidate_model_ids = ["/".join(st_id_path.parts[-2:])]
        # Sometimes the final part of name_or_path contains the full model_id, with "/" replaced by "_", e.g.
        # "/root/.cache/torch/sentence_transformers/sentence-transformers_all-mpnet-base-v2/"
        # In that case, we take the last part, split on _, and try all combinations
        # e.g. "a_b_c_d" -> ['a/b_c_d', 'a_b/c_d', 'a_b_c/d']
        splits = st_id_path.name.split("_")
        candidate_model_ids += ["_".join(splits[:idx]) + "/" + "_".join(splits[idx:]) for idx in range(1, len(splits))]
        for model_id in candidate_model_ids:
            if is_on_huggingface(model_id):
                self.st_id = model_id
                break

    def set_st_id(self, model_id: str) -> None:
        if is_on_huggingface(model_id):
            self.st_id = model_id

    def post_training_eval_results(self, results: Dict[str, float]) -> None:
        """Store final evaluation results; keys are expected to look like "<split>_<metric>"."""

        def try_to_pure_python(value: Any) -> Any:
            """Try to convert a value from a Numpy or Torch scalar to pure Python, if not already pure Python"""
            try:
                if hasattr(value, "dtype"):
                    return value.item()
            except Exception:
                pass
            return value

        pure_python_results = {key: try_to_pure_python(value) for key, value in results.items()}
        results_without_split = {
            key.split("_", maxsplit=1)[1].title(): value for key, value in pure_python_results.items()
        }
        self.eval_results_dict = pure_python_results
        self.metric_lines = [{"Label": "**all**", **results_without_split}]

    def _maybe_round(self, v, decimals=4):
        # Round only floats that have more than `decimals` digits after the decimal point
        if isinstance(v, float) and len(str(v).split(".")) > 1 and len(str(v).split(".")[1]) > decimals:
            return f"{v:.{decimals}f}"
        return str(v)

    def to_dict(self) -> Dict[str, Any]:
        """Return all card fields plus the derived values (tables, model-index, emissions) used by the template."""
        super_dict = {data_field.name: getattr(self, data_field.name) for data_field in fields(self)}

        # Compute required formats from the raw data
        if self.eval_results_dict:
            dataset_split = list(self.eval_results_dict.keys())[0].split("_")[0]
            dataset_id = self.dataset_id or "unknown"
            dataset_name = self.dataset_name or self.dataset_id or "Unknown"
            eval_results = [
                EvalResult(
                    task_type="text-classification",
                    dataset_type=dataset_id,
                    dataset_name=dataset_name,
                    dataset_split=dataset_split,
                    dataset_revision=self.dataset_revision,
                    metric_type=metric_key.split("_", maxsplit=1)[1],
                    metric_value=metric_value,
                    task_name="Text Classification",
                    metric_name=metric_key.split("_", maxsplit=1)[1].title(),
                )
                for metric_key, metric_value in self.eval_results_dict.items()
            ]
            super_dict["metrics"] = [metric_key.split("_", maxsplit=1)[1] for metric_key in self.eval_results_dict]
            super_dict["model-index"] = eval_results_to_model_index(self.model_name, eval_results)

        # Highlight the line of the best model step in bold
        eval_lines_list = [
            {
                key: f"**{self._maybe_round(value)}**" if line["Step"] == self.best_model_step else value
                for key, value in line.items()
            }
            for line in self.eval_lines_list
        ]
        super_dict["eval_lines"] = make_markdown_table(eval_lines_list)
        super_dict["explain_bold_in_eval"] = "**" in super_dict["eval_lines"]
        # Replace |:---:| with |:---| for left alignment
        super_dict["label_examples"] = make_markdown_table(self.label_example_list).replace("-:|", "--|")
        super_dict["train_set_metrics"] = make_markdown_table(self.train_set_metrics_list).replace("-:|", "--|")
        super_dict["train_set_sentences_per_label_list"] = make_markdown_table(
            self.train_set_sentences_per_label_list
        ).replace("-:|", "--|")
        super_dict["metrics_table"] = make_markdown_table(self.metric_lines).replace("-:|", "--|")

        if self.code_carbon_callback and self.code_carbon_callback.tracker:
            emissions_data = self.code_carbon_callback.tracker._prepare_emissions_data()
            super_dict["co2_eq_emissions"] = {
                # * 1000 to convert kg to g
                "emissions": float(emissions_data.emissions) * 1000,
                "source": "codecarbon",
                "training_type": "fine-tuning",
                "on_cloud": emissions_data.on_cloud == "Y",
                "cpu_model": emissions_data.cpu_model,
                "ram_total_size": emissions_data.ram_total_size,
                "hours_used": round(emissions_data.duration / 3600, 3),
            }
            if emissions_data.gpu_model:
                super_dict["co2_eq_emissions"]["hardware_used"] = emissions_data.gpu_model

        if self.dataset_id:
            super_dict["datasets"] = [self.dataset_id]
        if self.st_id:
            super_dict["base_model"] = self.st_id

        # The model is only available once `register_model` has been called
        if self.model is not None:
            super_dict["model_max_length"] = self.model.model_body.get_max_seq_length()
            if super_dict["num_classes"] is None and self.model.labels:
                super_dict["num_classes"] = len(self.model.labels)

        if super_dict["absa"]:
            super_dict.update(super_dict.pop("absa"))

        for key in IGNORED_FIELDS:
            super_dict.pop(key, None)
        return super_dict

    def to_yaml(self, line_break=None) -> str:
        """Serialize the non-None `YAML_FIELDS` of `to_dict()` as the model card's YAML metadata."""
        return yaml_dump(
            {key: value for key, value in self.to_dict().items() if key in YAML_FIELDS and value is not None},
            sort_keys=False,
            line_break=line_break,
        ).strip()


def is_on_huggingface(repo_id: str, is_model: bool = True) -> bool:
    """Return True if `repo_id` can be fetched from the Hub as a model (or dataset if `is_model=False`)."""
    # IDs with more than two '/'-separated sections cannot be Hub repository IDs
    if len(repo_id.split("/")) > 2:
        return False
    try:
        if is_model:
            model_info(repo_id)
        else:
            dataset_info(repo_id)
        return True
    except Exception:
        # Fetching models can fail for many reasons: Repository not existing, no internet access, HF down, etc.
        return False


def generate_model_card(model: "SetFitModel") -> str:
    """Render the model card text for `model` from the template next to this file."""
    template_path = Path(__file__).parent / "model_card_template.md"
    model_card = ModelCard.from_template(card_data=model.model_card_data, template_path=template_path, hf_emoji="🤗")
    return model_card.content