"""HuggingFace-compatible classification and regression models including pytorch-lightning models. """ __all__ = ("BypassNet", "ClassificationHead", "ClassifierLitModel", "GPT2ForSequenceClassification", "RegressorLitModel", "SequenceClassifierOutput") from dataclasses import dataclass from typing import List, Optional import pytorch_lightning as pl import torch import torch.nn as nn import torch.nn.functional as F from torchmetrics import AUROC, AveragePrecision from transformers import AdamW, GPT2Model, GPT2PreTrainedModel from transformers.modeling_outputs import SequenceClassifierOutputWithPast from transformers.adapters.model_mixin import ModelWithHeadsAdaptersMixin @dataclass class SequenceClassifierOutput(SequenceClassifierOutputWithPast): target: Optional[torch.LongTensor] = None class GPT2ForSequenceClassification(ModelWithHeadsAdaptersMixin, GPT2PreTrainedModel): """HuggingFace-compatible single- and multi-output (-task) classification model. `config` must be a `GPT2Config` instance with additional `num_tasks` and `num_labels` properties. For multi-task classification, the output is Bypass network with the reduction factor = `config.n_embd // config.n_head`. """ _keys_to_ignore_on_load_missing = [ r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight", r"output\..*"] def __init__(self, config): super().__init__(config) self.num_tasks = config.num_tasks self.num_labels = config.num_labels self.transformer = GPT2Model(config) if self.num_tasks > 1: self.output = BypassNet( config.n_embd, config.n_embd // config.n_head, config.num_tasks, config.num_labels, config.embd_pdrop) else: self.output = ClassificationHead( config.n_embd, config.n_embd // config.n_head, config.num_labels, config.embd_pdrop) self.init_weights() def forward(self, input_ids=None, past_key_values=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, labels=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, adapter_names=None, label_mask=None): return_dict = return_dict or self.config.use_return_dict transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, adapter_names=adapter_names) hidden_states = transformer_outputs[0] if input_ids is not None: batch_size, sequence_length = input_ids.shape[:2] else: batch_size, sequence_length = inputs_embeds.shape[:2] assert self.config.pad_token_id is not None or batch_size == 1, \ "Cannot handle batch sizes > 1 if no padding token is defined." 
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = torch.ne(
                    input_ids, self.config.pad_token_id).sum(-1) - 1
            else:
                sequence_lengths = -1

        if self.num_tasks == 1:
            logits = self.output(hidden_states)[range(batch_size),
                                                sequence_lengths]
        else:
            logits = self.output(hidden_states, batch_size, sequence_lengths)

        loss = None
        if labels is not None:
            if self.num_labels == 2:
                if label_mask is not None:
                    nonempty_tasks = (label_mask == 1).view(-1)
                    nonempty_logits = logits.view(
                        -1, self.num_labels)[nonempty_tasks, :]
                    nonempty_labels = labels.view(-1)[nonempty_tasks]
                else:
                    nonempty_logits = logits.view(-1, self.num_labels)
                    nonempty_labels = labels.view(-1)
                if len(labels.size()) == 1:
                    labels = labels.reshape(1, -1)
                loss = F.cross_entropy(nonempty_logits, nonempty_labels)
            elif self.num_labels == 1:
                loss = F.mse_loss(logits.view(-1), labels.view(-1))
            else:
                raise NotImplementedError(
                    "Only binary classification and regression supported.")

        if self.num_tasks > 1:
            logits = logits.transpose(1, 2)
        if labels is not None and self.num_labels == 2 and self.num_tasks == 1:
            if label_mask is not None:
                labels = labels.view(-1)
            else:
                labels = nonempty_labels

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            target=labels,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions)


class BypassNet(nn.Module):
    """Bypass multi-task network from MoleculeNet project [Wu et al., 2018].
    """

    def __init__(self, hidden_size: int, intermediate_size: int,
                 num_tasks: int, num_labels: int = 2, dropout: float = 0.2,
                 use_bias: bool = False):
        super().__init__()
        self.independent = nn.ModuleList([
            ClassificationHead(hidden_size, intermediate_size, num_labels,
                               dropout, use_bias)
            for _ in range(num_tasks)])
        self.shared = ClassificationHead(hidden_size, intermediate_size,
                                         num_labels, dropout, use_bias)

    def forward(self, hidden_states, batch_size, sequence_lengths):
        logits_list: List[torch.Tensor] = []
        for layer in self.independent:
            logits_list.append(layer(hidden_states))
        shared_logits: torch.Tensor = self.shared(hidden_states)
        for i in range(len(logits_list)):
            logits_list[i] = (logits_list[i]
                              + shared_logits)[range(batch_size),
                                               sequence_lengths]
        return torch.stack(logits_list, dim=1)


class ClassificationHead(nn.Module):
    """Two-layer feed-forward network with GELU activation and intermediate
    dropout.
    """

    def __init__(self, hidden_size: int, intermediate_size: int,
                 num_labels: int, dropout: float = 0.0,
                 use_bias: bool = False):
        super().__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size, bias=use_bias)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(intermediate_size, num_labels, bias=use_bias)

    def forward(self, x, *args, **kwargs):
        x = self.dense(x)
        x = self.act(x)
        x = self.dropout(x)
        return self.out_proj(x)
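
# Shape sketch for the two heads above (the sizes are illustrative, not taken
# from the original code): given hidden_states of shape
# (batch, seq, hidden_size), ``ClassificationHead`` maps every position to
# ``num_labels`` logits, while ``BypassNet`` additionally pools at
# ``sequence_lengths`` and stacks the per-task outputs:
#
#     head = ClassificationHead(hidden_size=768, intermediate_size=64,
#                               num_labels=2, dropout=0.1)
#     bypass = BypassNet(hidden_size=768, intermediate_size=64, num_tasks=12)
#     h = torch.randn(4, 16, 768)
#     head(h).shape                                     # (4, 16, 2)
#     bypass(h, 4, torch.tensor([15, 15, 15, 15])).shape  # (4, 12, 2)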


class ClassifierLitModel(pl.LightningModule):
    """Pytorch-lightning module for single- or multi-task classification.

    Trains a GPT2 model using the `AdamW` optimizer with an exponential LR
    scheduler. Evaluates validation and test data on AUC-ROC and AUC-PRC.

    Args:
        transformer (`GPT2Model`): (Pretrained) HuggingFace GPT2 model.
        num_tasks (int): The number of classification tasks.
        has_empty_labels (bool): Whether the multi-task label matrix can
            contain missing entries (encoded as -1).
        batch_size (int): Batch size (stored as a hyperparameter).
        learning_rate (float): AdamW learning rate.
        scheduler_lambda (float): Decay factor (gamma) of the exponential LR
            scheduler.
        scheduler_step (int): Number of optimizer steps between scheduler
            updates.
        weight_decay (float): AdamW weight decay.
    """

    def __init__(self, transformer: GPT2Model, num_tasks: int,
                 has_empty_labels: bool, batch_size: int,
                 learning_rate: float, scheduler_lambda: float,
                 scheduler_step: int, weight_decay: float, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters(ignore=("transformer", "num_tasks",
                                          "has_empty_labels"))
        self.transformer = transformer
        self.num_tasks = num_tasks

        def get_metrics(metric_cls):
            return [metric_cls(task="multiclass", num_classes=2)
                    for _ in range(num_tasks)]

        if has_empty_labels:
            self.train_roc = get_metrics(AUROC)
            self.val_roc = get_metrics(AUROC)
            self.test_roc = get_metrics(AUROC)
            self.train_prc = get_metrics(AveragePrecision)
            self.val_prc = get_metrics(AveragePrecision)
            self.test_prc = get_metrics(AveragePrecision)
            self.step = self._step_empty
            self.epoch_end = self._epoch_end_empty
        else:
            self.train_roc = AUROC(task="multiclass", num_classes=2)
            self.val_roc = AUROC(task="multiclass", num_classes=2)
            self.test_roc = AUROC(task="multiclass", num_classes=2)
            self.train_prc = AveragePrecision(task="multiclass", num_classes=2)
            self.val_prc = AveragePrecision(task="multiclass", num_classes=2)
            self.test_prc = AveragePrecision(task="multiclass", num_classes=2)
            self.step = self._step_nonempty
            self.epoch_end = self._epoch_end_nonempty

    def forward(self, *args, **kwargs):
        return self.transformer(*args, **kwargs)

    def _step_empty(self, batch, batch_idx, roc, prc):
        outputs = self(**batch)
        if self.num_tasks == 1:
            outputs["target"] = outputs["target"][:, None]
            outputs["logits"] = outputs["logits"][:, :, None]
        for task_id in range(self.num_tasks):
            target = outputs["target"][:, task_id]
            nonempty_entries = target != -1
            target = target[nonempty_entries]
            if target.unique().size(0) > 1:
                logits = outputs["logits"][:, :, task_id][nonempty_entries]
                roc[task_id](logits, target)
                prc[task_id](logits, target)
        return {"loss": outputs["loss"]}

    def _step_nonempty(self, batch, batch_idx, roc, prc):
        outputs = self(**batch)
        logits, target = outputs["logits"], outputs["target"]
        if target.unique().size(0) > 1:
            roc(logits, target)
            prc(logits, target)
        return {"loss": outputs["loss"]}

    def _epoch_end_empty(self, outputs_ignored, roc, prc, prefix):
        mean_roc = sum(a.compute() for a in roc) / self.num_tasks
        self.log(f"{prefix}_roc", mean_roc,
                 on_step=False, on_epoch=True, prog_bar=True)
        mean_prc = sum(p.compute() for p in prc) / self.num_tasks
        self.log(f"{prefix}_prc", mean_prc,
                 on_step=False, on_epoch=True, prog_bar=True)

    def _epoch_end_nonempty(self, outputs, roc, prc, prefix):
        self.log(f"{prefix}_roc", roc.compute(),
                 on_step=False, on_epoch=True, prog_bar=True)
        self.log(f"{prefix}_prc", prc.compute(),
                 on_step=False, on_epoch=True, prog_bar=True)
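
    # The hooks below delegate to the `step`/`epoch_end` callables selected in
    # `__init__` (empty-label vs. fully-labelled variants). Note that the
    # `*_epoch_end` hooks follow the pytorch-lightning < 2.0 API.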

    def training_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, self.train_roc, self.train_prc)

    def training_epoch_end(self, outputs):
        self.epoch_end(outputs, self.train_roc, self.train_prc, "train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, self.val_roc, self.val_prc)

    def validation_epoch_end(self, outputs):
        self.epoch_end(outputs, self.val_roc, self.val_prc, "val")

    def test_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, self.test_roc, self.test_prc)

    def test_epoch_end(self, outputs):
        self.epoch_end(outputs, self.test_roc, self.test_prc, "test")

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(),
                          lr=self.hparams.learning_rate,
                          weight_decay=self.hparams.weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, self.hparams.scheduler_lambda)
        return {"optimizer": optimizer,
                "lr_scheduler": {"scheduler": lr_scheduler,
                                 "interval": "step",
                                 "frequency": self.hparams.scheduler_step}}


class RegressorLitModel(pl.LightningModule):
    """Pytorch-lightning module for single-task regression.

    Trains a GPT2 model using the `AdamW` optimizer with an exponential LR
    scheduler and logs the RMSE averaged over batches.
    """

    def __init__(self, transformer: GPT2Model, batch_size: int,
                 learning_rate: float, scheduler_lambda: float,
                 scheduler_step: int, weight_decay: float, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters(ignore="transformer")
        self.transformer = transformer

    def forward(self, *args, **kwargs):
        return self.transformer(*args, **kwargs)

    def step(self, batch, batch_idx):
        outputs = self(**batch)
        rmse_loss = torch.sqrt(outputs["loss"])
        return {"loss": rmse_loss}

    def epoch_end(self, outputs, prefix):
        mean_rmse = torch.mean(torch.tensor([out["loss"] for out in outputs]))
        self.log(f"{prefix}_rmse", mean_rmse,
                 on_step=False, on_epoch=True, prog_bar=True)

    def training_step(self, batch, batch_idx):
        return self.step(batch, batch_idx)

    def training_epoch_end(self, outputs):
        self.epoch_end(outputs, "train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, batch_idx)

    def validation_epoch_end(self, outputs):
        self.epoch_end(outputs, "val")

    def test_step(self, batch, batch_idx):
        return self.step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        self.epoch_end(outputs, "test")

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(),
                          lr=self.hparams.learning_rate,
                          weight_decay=self.hparams.weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, self.hparams.scheduler_lambda)
        return {"optimizer": optimizer,
                "lr_scheduler": {"scheduler": lr_scheduler,
                                 "interval": "step",
                                 "frequency": self.hparams.scheduler_step}}
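

# Minimal smoke-test sketch (not part of the original module): builds a tiny
# GPT-2 classifier and runs one forward pass. The config sizes, pad token id,
# and dummy labels are illustrative assumptions, and the snippet expects the
# same (adapter-)transformers version targeted by the imports above.
if __name__ == "__main__":
    from transformers import GPT2Config

    config = GPT2Config(vocab_size=100, n_positions=32, n_embd=64, n_layer=2,
                        n_head=2, num_labels=2, pad_token_id=0, num_tasks=1)
    model = GPT2ForSequenceClassification(config)
    model.eval()

    input_ids = torch.randint(1, 100, (4, 16))  # batch of 4, no pad tokens
    labels = torch.randint(0, 2, (4,))          # one binary label per sample
    with torch.no_grad():
        outputs = model(input_ids=input_ids, labels=labels)
    print(outputs.loss, outputs.logits.shape)   # logits: (4, num_labels)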