import collections.abc
import random
import statistics
from collections import Counter, defaultdict
from dataclasses import dataclass, field, fields
from pathlib import Path
from platform import python_version
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import datasets
import tokenizers
import torch
import transformers
from datasets import Dataset
from huggingface_hub import CardData, DatasetFilter, ModelCard, dataset_info, list_datasets, model_info
from huggingface_hub.repocard_data import EvalResult, eval_results_to_model_index
from huggingface_hub.utils import yaml_dump
from sentence_transformers import __version__ as sentence_transformers_version
from transformers import PretrainedConfig, TrainerCallback
from transformers.integrations import CodeCarbonCallback
from transformers.modelcard import make_markdown_table
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from setfit import __version__ as setfit_version
from . import logging
logger = logging.get_logger(__name__)
if TYPE_CHECKING:
from setfit.modeling import SetFitModel
from setfit.trainer import Trainer
class ModelCardCallback(TrainerCallback):
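    """A `TrainerCallback` that collects dataset details, training hyperparameters, and
    training/evaluation losses on the model's `SetFitModelCardData` during training."""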
def __init__(self, trainer: "Trainer") -> None:
super().__init__()
self.trainer = trainer
callbacks = [
callback
for callback in self.trainer.callback_handler.callbacks
if isinstance(callback, CodeCarbonCallback)
]
if callbacks:
trainer.model.model_card_data.code_carbon_callback = callbacks[0]
def on_init_end(
self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: "SetFitModel", **kwargs
):
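        """Infer the dataset ID, widget examples, training set metrics, label examples and number of classes."""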
if not model.model_card_data.dataset_id:
# Inferring is hacky - it may break in the future, so let's be safe
try:
model.model_card_data.infer_dataset_id(self.trainer.train_dataset)
except Exception:
pass
dataset = self.trainer.eval_dataset or self.trainer.train_dataset
if dataset is not None:
if not model.model_card_data.widget:
model.model_card_data.set_widget_examples(dataset)
if self.trainer.train_dataset:
model.model_card_data.set_train_set_metrics(self.trainer.train_dataset)
# Does not work for multilabel
try:
model.model_card_data.num_classes = len(set(self.trainer.train_dataset["label"]))
model.model_card_data.set_label_examples(self.trainer.train_dataset)
except Exception:
pass
def on_train_begin(
self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: "SetFitModel", **kwargs
) -> None:
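        """Store the training hyperparameters on the model card data, skipping arguments that are not
        informative for the model card (output, logging, evaluation and saving settings, etc.)."""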
# model.model_card_data.hyperparameters = extract_hyperparameters_from_trainer(self.trainer)
ignore_keys = {
"output_dir",
"logging_dir",
"logging_strategy",
"logging_first_step",
"logging_steps",
"evaluation_strategy",
"eval_steps",
"eval_delay",
"save_strategy",
"save_steps",
"save_total_limit",
"metric_for_best_model",
"greater_is_better",
"report_to",
"samples_per_label",
"show_progress_bar",
}
get_name_keys = {"loss", "distance_metric"}
args_dict = args.to_dict()
model.model_card_data.hyperparameters = {
key: value.__name__ if key in get_name_keys else value
for key, value in args_dict.items()
if key not in ignore_keys and value is not None
}
def on_evaluate(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
model: "SetFitModel",
metrics: Dict[str, float],
**kwargs,
) -> None:
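        """Record the validation embedding loss for the current step in the model card's training log."""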
if (
model.model_card_data.eval_lines_list
and model.model_card_data.eval_lines_list[-1]["Step"] == state.global_step
):
model.model_card_data.eval_lines_list[-1]["Validation Loss"] = metrics["eval_embedding_loss"]
else:
model.model_card_data.eval_lines_list.append(
{
# "Training Loss": self.state.log_history[-1]["loss"] if "loss" in self.state.log_history[-1] else "-",
"Epoch": state.epoch,
"Step": state.global_step,
"Training Loss": "-",
"Validation Loss": metrics["eval_embedding_loss"],
}
)
def on_log(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
model: "SetFitModel",
logs: Dict[str, float],
**kwargs,
):
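        """Record the training embedding loss for the current step in the model card's training log."""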
keys = {"embedding_loss", "polarity_embedding_loss", "aspect_embedding_loss"} & set(logs)
if keys:
if (
model.model_card_data.eval_lines_list
and model.model_card_data.eval_lines_list[-1]["Step"] == state.global_step
):
model.model_card_data.eval_lines_list[-1]["Training Loss"] = logs[keys.pop()]
else:
model.model_card_data.eval_lines_list.append(
{
"Epoch": state.epoch,
"Step": state.global_step,
"Training Loss": logs[keys.pop()],
"Validation Loss": "-",
}
)
YAML_FIELDS = [
"language",
"license",
"library_name",
"tags",
"datasets",
"metrics",
"pipeline_tag",
"widget",
"model-index",
"co2_eq_emissions",
"base_model",
"inference",
]
IGNORED_FIELDS = ["model"]
@dataclass
class SetFitModelCardData(CardData):
"""A dataclass storing data used in the model card.
Args:
language (`Optional[Union[str, List[str]]]`): The model language, either a string or a list,
e.g. "en" or ["en", "de", "nl"]
license (`Optional[str]`): The license of the model, e.g. "apache-2.0", "mit",
or "cc-by-nc-sa-4.0"
model_name (`Optional[str]`): The pretty name of the model, e.g. "SetFit with mBERT-base on SST2".
            If not defined, uses st_id and dataset_name/dataset_id to generate a model name.
model_id (`Optional[str]`): The model ID when pushing the model to the Hub,
e.g. "tomaarsen/span-marker-mbert-base-multinerd".
dataset_name (`Optional[str]`): The pretty name of the dataset, e.g. "SST2".
dataset_id (`Optional[str]`): The dataset ID of the dataset, e.g. "dair-ai/emotion".
        dataset_revision (`Optional[str]`): The dataset revision/commit that was used for training/evaluation.
st_id (`Optional[str]`): The Sentence Transformers model ID.
Install [``codecarbon``](https://github.com/mlco2/codecarbon) to automatically track carbon emission usage and
include it in your model cards.
Example::
>>> model = SetFitModel.from_pretrained(
... "sentence-transformers/paraphrase-mpnet-base-v2",
... labels=["negative", "positive"],
... # Model card variables
... model_card_data=SetFitModelCardData(
... model_id="tomaarsen/setfit-paraphrase-mpnet-base-v2-sst2",
... dataset_name="SST2",
... dataset_id="sst2",
... license="apache-2.0",
... language="en",
... ),
... )
"""
# Potentially provided by the user
language: Optional[Union[str, List[str]]] = None
license: Optional[str] = None
tags: Optional[List[str]] = field(
default_factory=lambda: [
"setfit",
"sentence-transformers",
"text-classification",
"generated_from_setfit_trainer",
]
)
model_name: Optional[str] = None
model_id: Optional[str] = None
dataset_name: Optional[str] = None
dataset_id: Optional[str] = None
dataset_revision: Optional[str] = None
task_name: Optional[str] = None
st_id: Optional[str] = None
# Automatically filled by `ModelCardCallback` and the Trainer directly
hyperparameters: Dict[str, Any] = field(default_factory=dict, init=False)
eval_results_dict: Optional[Dict[str, Any]] = field(default_factory=dict, init=False)
eval_lines_list: List[Dict[str, float]] = field(default_factory=list, init=False)
metric_lines: List[Dict[str, float]] = field(default_factory=list, init=False)
widget: List[Dict[str, str]] = field(default_factory=list, init=False)
predict_example: Optional[str] = field(default=None, init=False)
label_example_list: List[Dict[str, str]] = field(default_factory=list, init=False)
tokenizer_warning: bool = field(default=False, init=False)
train_set_metrics_list: List[Dict[str, str]] = field(default_factory=list, init=False)
train_set_sentences_per_label_list: List[Dict[str, str]] = field(default_factory=list, init=False)
code_carbon_callback: Optional[CodeCarbonCallback] = field(default=None, init=False)
num_classes: Optional[int] = field(default=None, init=False)
best_model_step: Optional[int] = field(default=None, init=False)
metrics: List[str] = field(default_factory=lambda: ["accuracy"], init=False)
# Computed once, always unchanged
pipeline_tag: str = field(default="text-classification", init=False)
library_name: str = field(default="setfit", init=False)
version: Dict[str, str] = field(
default_factory=lambda: {
"python": python_version(),
"setfit": setfit_version,
"sentence_transformers": sentence_transformers_version,
"transformers": transformers.__version__,
"torch": torch.__version__,
"datasets": datasets.__version__,
"tokenizers": tokenizers.__version__,
},
init=False,
)
# ABSA-related arguments
absa: Dict[str, Any] = field(default=None, init=False, repr=False)
# Passed via `register_model` only
model: Optional["SetFitModel"] = field(default=None, init=False, repr=False)
head_class: Optional[str] = field(default=None, init=False, repr=False)
inference: Optional[bool] = field(default=True, init=False, repr=False)
def __post_init__(self):
# We don't want to save "ignore_metadata_errors" in our Model Card
if self.dataset_id:
if is_on_huggingface(self.dataset_id, is_model=False):
if self.language is None:
# if languages are not set, try to determine the language from the dataset on the Hub
try:
info = dataset_info(self.dataset_id)
except Exception:
pass
else:
if info.cardData:
self.language = info.cardData.get("language", self.language)
else:
logger.warning(
f"The provided {self.dataset_id!r} dataset could not be found on the Hugging Face Hub."
" Setting `dataset_id` to None."
)
self.dataset_id = None
if self.model_id and self.model_id.count("/") != 1:
logger.warning(
f"The provided {self.model_id!r} model ID should include the organization or user,"
' such as "tomaarsen/setfit-bge-small-v1.5-sst2-8-shot". Setting `model_id` to None.'
)
self.model_id = None
def set_best_model_step(self, step: int) -> None:
self.best_model_step = step
def set_widget_examples(self, dataset: Dataset) -> None:
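        """Select up to 5 random samples for the Hub inference widget and use the shortest one as the predict example."""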
samples = dataset.select(random.sample(range(len(dataset)), k=min(len(dataset), 5)))["text"]
self.widget = [{"text": sample} for sample in samples]
samples.sort(key=len)
if samples:
self.predict_example = samples[0]
def set_train_set_metrics(self, dataset: Dataset) -> None:
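        """Compute naive word count statistics and per-label sample counts over the training set."""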
def add_naive_word_count(sample: Dict[str, Any]) -> Dict[str, Any]:
sample["word_count"] = len(sample["text"].split(" "))
return sample
dataset = dataset.map(add_naive_word_count)
self.train_set_metrics_list = [
{
"Training set": "Word count",
"Min": min(dataset["word_count"]),
"Median": sum(dataset["word_count"]) / len(dataset),
"Max": max(dataset["word_count"]),
},
]
# E.g. if unlabeled via DistillationTrainer
if "label" not in dataset.column_names:
return
sample_label = dataset[0]["label"]
if isinstance(sample_label, collections.abc.Sequence) and not isinstance(sample_label, str):
return
try:
counter = Counter(dataset["label"])
if self.model.labels:
self.train_set_sentences_per_label_list = [
{
"Label": str_label,
"Training Sample Count": counter[
str_label if isinstance(sample_label, str) else self.model.label2id[str_label]
],
}
for str_label in self.model.labels
]
else:
self.train_set_sentences_per_label_list = [
{
"Label": self.model.labels[label]
if self.model.labels and isinstance(label, int)
else str(label),
"Training Sample Count": count,
}
for label, count in sorted(counter.items())
]
except Exception:
            # There are some tricky edge cases possible, e.g. if the user provided integer labels that do not fall
            # between 0 and num_classes - 1, so we make sure we never cause errors.
pass
def set_label_examples(self, dataset: Dataset) -> None:
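        """Collect up to 3 example sentences per label to display in the model card."""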
num_examples_per_label = 3
examples = defaultdict(list)
finished_labels = set()
for sample in dataset:
text = sample["text"]
label = sample["label"]
if label not in finished_labels:
examples[label].append(f"
{repr(text)}")
if len(examples[label]) >= num_examples_per_label:
finished_labels.add(label)
if len(finished_labels) == self.num_classes:
break
self.label_example_list = [
{
"Label": self.model.labels[label] if self.model.labels and isinstance(label, int) else label,
"Examples": "" + "".join(example_set) + "
",
}
for label, example_set in examples.items()
]
def infer_dataset_id(self, dataset: Dataset) -> None:
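        """Attempt to infer the Hub dataset ID from the dataset's cache file paths."""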
        def subtuple_finder(full_tuple: Tuple[str, ...], subtuple: Tuple[str, ...]) -> int:
            for i, element in enumerate(full_tuple):
                if element == subtuple[0] and full_tuple[i : i + len(subtuple)] == subtuple:
return i
return -1
def normalize(dataset_id: str) -> str:
for token in "/\\_-":
dataset_id = dataset_id.replace(token, "")
return dataset_id.lower()
cache_files = dataset.cache_files
if cache_files and "filename" in cache_files[0]:
cache_path_parts = Path(cache_files[0]["filename"]).parts
            # Check if the cache file is under "huggingface/datasets"
subtuple = ("huggingface", "datasets")
index = subtuple_finder(cache_path_parts, subtuple)
if index == -1:
return
# Get the folder after "huggingface/datasets"
cache_dataset_name = cache_path_parts[index + len(subtuple)]
# If the dataset has an author:
if "___" in cache_dataset_name:
author, dataset_name = cache_dataset_name.split("___")
else:
author = None
dataset_name = cache_dataset_name
# Make sure the normalized dataset IDs match
dataset_list = [
dataset
for dataset in list_datasets(filter=DatasetFilter(author=author, dataset_name=dataset_name))
if normalize(dataset.id) == normalize(cache_dataset_name)
]
# If there's only one match, get the ID from it
if len(dataset_list) == 1:
self.dataset_id = dataset_list[0].id
def register_model(self, model: "SetFitModel") -> None:
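        """Link the model to this card data and derive the head class, the model name, and whether hosted inference is supported."""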
self.model = model
head_class = model.model_head.__class__.__name__
self.head_class = {
"LogisticRegression": "[LogisticRegression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html)",
"SetFitHead": "[SetFitHead](huggingface.co/docs/setfit/reference/main#setfit.SetFitHead)",
}.get(head_class, head_class)
if not self.model_name:
if self.st_id:
self.model_name = f"SetFit with {self.st_id}"
if self.dataset_name or self.dataset_id:
self.model_name += f" on {self.dataset_name or self.dataset_id}"
else:
self.model_name = "SetFit"
self.inference = self.model.multi_target_strategy is None
def infer_st_id(self, setfit_model_id: str) -> None:
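        """Attempt to infer the Sentence Transformers model ID from the `_name_or_path` stored in the model config."""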
config_dict, _ = PretrainedConfig.get_config_dict(setfit_model_id)
st_id = config_dict.get("_name_or_path")
st_id_path = Path(st_id)
# Sometimes the name_or_path ends exactly with the model_id, e.g.
# "C:\\Users\\tom/.cache\\torch\\sentence_transformers\\BAAI_bge-small-en-v1.5\\"
candidate_model_ids = ["/".join(st_id_path.parts[-2:])]
        # Sometimes the final part of the name_or_path contains the full model_id, with "/" replaced by "_", e.g.
# "/root/.cache/torch/sentence_transformers/sentence-transformers_all-mpnet-base-v2/"
# In that case, we take the last part, split on _, and try all combinations
# e.g. "a_b_c_d" -> ['a/b_c_d', 'a_b/c_d', 'a_b_c/d']
splits = st_id_path.name.split("_")
candidate_model_ids += ["_".join(splits[:idx]) + "/" + "_".join(splits[idx:]) for idx in range(1, len(splits))]
for model_id in candidate_model_ids:
if is_on_huggingface(model_id):
self.st_id = model_id
break
def set_st_id(self, model_id: str) -> None:
if is_on_huggingface(model_id):
self.st_id = model_id
def post_training_eval_results(self, results: Dict[str, float]) -> None:
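        """Store the final evaluation results, converted to pure Python scalars, for the model card's metrics table."""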
def try_to_pure_python(value: Any) -> Any:
"""Try to convert a value from a Numpy or Torch scalar to pure Python, if not already pure Python"""
try:
if hasattr(value, "dtype"):
return value.item()
except Exception:
pass
return value
pure_python_results = {key: try_to_pure_python(value) for key, value in results.items()}
results_without_split = {
key.split("_", maxsplit=1)[1].title(): value for key, value in pure_python_results.items()
}
self.eval_results_dict = pure_python_results
self.metric_lines = [{"Label": "**all**", **results_without_split}]
def _maybe_round(self, v, decimals=4):
if isinstance(v, float) and len(str(v).split(".")) > 1 and len(str(v).split(".")[1]) > decimals:
return f"{v:.{decimals}f}"
return str(v)
def to_dict(self) -> Dict[str, Any]:
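        """Convert the card data to a dictionary, computing derived fields such as the model index,
        markdown tables, CO2 emissions and base model on the fly."""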
super_dict = {field.name: getattr(self, field.name) for field in fields(self)}
# Compute required formats from the raw data
if self.eval_results_dict:
dataset_split = list(self.eval_results_dict.keys())[0].split("_")[0]
dataset_id = self.dataset_id or "unknown"
dataset_name = self.dataset_name or self.dataset_id or "Unknown"
eval_results = [
EvalResult(
task_type="text-classification",
dataset_type=dataset_id,
dataset_name=dataset_name,
dataset_split=dataset_split,
dataset_revision=self.dataset_revision,
metric_type=metric_key.split("_", maxsplit=1)[1],
metric_value=metric_value,
task_name="Text Classification",
metric_name=metric_key.split("_", maxsplit=1)[1].title(),
)
for metric_key, metric_value in self.eval_results_dict.items()
]
super_dict["metrics"] = [metric_key.split("_", maxsplit=1)[1] for metric_key in self.eval_results_dict]
super_dict["model-index"] = eval_results_to_model_index(self.model_name, eval_results)
eval_lines_list = [
{
key: f"**{self._maybe_round(value)}**" if line["Step"] == self.best_model_step else value
for key, value in line.items()
}
for line in self.eval_lines_list
]
super_dict["eval_lines"] = make_markdown_table(eval_lines_list)
super_dict["explain_bold_in_eval"] = "**" in super_dict["eval_lines"]
# Replace |:---:| with |:---| for left alignment
super_dict["label_examples"] = make_markdown_table(self.label_example_list).replace("-:|", "--|")
super_dict["train_set_metrics"] = make_markdown_table(self.train_set_metrics_list).replace("-:|", "--|")
super_dict["train_set_sentences_per_label_list"] = make_markdown_table(
self.train_set_sentences_per_label_list
).replace("-:|", "--|")
super_dict["metrics_table"] = make_markdown_table(self.metric_lines).replace("-:|", "--|")
if self.code_carbon_callback and self.code_carbon_callback.tracker:
emissions_data = self.code_carbon_callback.tracker._prepare_emissions_data()
super_dict["co2_eq_emissions"] = {
# * 1000 to convert kg to g
"emissions": float(emissions_data.emissions) * 1000,
"source": "codecarbon",
"training_type": "fine-tuning",
"on_cloud": emissions_data.on_cloud == "Y",
"cpu_model": emissions_data.cpu_model,
"ram_total_size": emissions_data.ram_total_size,
"hours_used": round(emissions_data.duration / 3600, 3),
}
if emissions_data.gpu_model:
super_dict["co2_eq_emissions"]["hardware_used"] = emissions_data.gpu_model
if self.dataset_id:
super_dict["datasets"] = [self.dataset_id]
if self.st_id:
super_dict["base_model"] = self.st_id
super_dict["model_max_length"] = self.model.model_body.get_max_seq_length()
if super_dict["num_classes"] is None:
if self.model.labels:
super_dict["num_classes"] = len(self.model.labels)
if super_dict["absa"]:
super_dict.update(super_dict.pop("absa"))
for key in IGNORED_FIELDS:
super_dict.pop(key, None)
return super_dict
def to_yaml(self, line_break=None) -> str:
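        """Dump only the fields listed in `YAML_FIELDS` as YAML for the model card metadata header."""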
return yaml_dump(
{key: value for key, value in self.to_dict().items() if key in YAML_FIELDS and value is not None},
sort_keys=False,
line_break=line_break,
).strip()
def is_on_huggingface(repo_id: str, is_model: bool = True) -> bool:
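    """Return True if `repo_id` exists on the Hugging Face Hub as a model (or as a dataset if `is_model` is False)."""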
    # Repo IDs with more than two "/"-separated parts (e.g. local paths) cannot be repositories on the Hub
if len(repo_id.split("/")) > 2:
return False
try:
if is_model:
model_info(repo_id)
else:
dataset_info(repo_id)
return True
except Exception:
# Fetching models can fail for many reasons: Repository not existing, no internet access, HF down, etc.
return False
def generate_model_card(model: "SetFitModel") -> str:
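    """Render the model card for `model` from the `model_card_template.md` template in this package."""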
template_path = Path(__file__).parent / "model_card_template.md"
model_card = ModelCard.from_template(card_data=model.model_card_data, template_path=template_path, hf_emoji="🤗")
return model_card.content