# hubert-ecg-small / modeling_hubert_ecg.py
import torch
import torch.nn as nn
from transformers import HubertModel
from typing import List, Optional, Tuple, Union
from transformers.modeling_outputs import BaseModelOutput
from .configuration_hubert_ecg import HuBERTECGConfig


class HuBERTECG(HubertModel):
    config_class = HuBERTECGConfig

    def __init__(self, config: HuBERTECGConfig):
        super().__init__(config)
        self.config = config
        self.pretraining_vocab_sizes = config.vocab_sizes

        assert config.ensemble_length > 0 and config.ensemble_length == len(config.vocab_sizes), f"ensemble_length {config.ensemble_length} must equal len(vocab_sizes) {len(config.vocab_sizes)}"

        # final projection layers mapping encoder outputs into the codebook space (one per ensemble member)
        self.final_proj = nn.ModuleList([nn.Linear(config.hidden_size, config.classifier_proj_size) for _ in range(config.ensemble_length)])

        # embedding matrices for the codebooks (one per ensemble member)
        self.label_embedding = nn.ModuleList([nn.Embedding(vocab_size, config.classifier_proj_size) for vocab_size in config.vocab_sizes])

        assert len(self.final_proj) == len(self.label_embedding), "final_proj and label_embedding must have the same length"

    def logits(self, transformer_output: torch.Tensor) -> List[torch.Tensor]:
        # transformer_output: (B, T, D)
        # compute one projected output per ensemble member
        projected_outputs = [final_projection(transformer_output) for final_projection in self.final_proj]
        # cosine similarity against every codebook embedding, scaled by a temperature of 0.1
        ensemble_logits = [torch.cosine_similarity(
            projected_output.unsqueeze(2),
            label_emb.weight.unsqueeze(0).unsqueeze(0),
            dim=-1,
        ) / 0.1 for projected_output, label_emb in zip(projected_outputs, self.label_embedding)]
        return ensemble_logits  # [(B, T, V)] * ensemble_length


class ActivationFunction(nn.Module):
    def __init__(self, activation: str):
        super().__init__()
        self.activation = activation
        if activation == 'tanh':
            self.act = nn.Tanh()
        elif activation == 'relu':
            self.act = nn.ReLU()
        elif activation == 'gelu':
            self.act = nn.GELU()
        elif activation == 'sigmoid':
            self.act = nn.Sigmoid()
        else:
            raise ValueError(f'Activation function {activation} not supported')

    def forward(self, x):
        return self.act(x)


class HuBERTForECGClassification(nn.Module):
    config_class = HuBERTECGConfig

    def __init__(
        self,
        hubert_ecg: HuBERTECG,
        num_labels: int,
        classifier_hidden_size: Optional[int] = None,
        activation: str = 'tanh',
        use_label_embedding: bool = False,
        classifier_dropout_prob: float = 0.1,
    ):
        super().__init__()
        self.hubert_ecg = hubert_ecg
        self.hubert_ecg.config.mask_time_prob = 0.0     # prevents time masking
        self.hubert_ecg.config.mask_feature_prob = 0.0  # prevents feature masking

        # num_labels may differ from vocab_size when fine-tuning a pretrained hubert_ecg
        # (in that case all modules are reused except the embeddings, which could be frozen);
        # otherwise it should (though not necessarily) equal vocab_size for consistency
        self.num_labels = num_labels
        self.config = self.hubert_ecg.config
        self.classifier_hidden_size = classifier_hidden_size
        self.activation = ActivationFunction(activation)
        self.use_label_embedding = use_label_embedding
        self.classifier_dropout = nn.Dropout(classifier_dropout_prob)

        # pretraining heads are not needed for classification
        del self.hubert_ecg.label_embedding
        del self.hubert_ecg.final_proj

        if use_label_embedding:  # classification only
            self.label_embedding = nn.Embedding(num_labels, self.config.hidden_size)
        else:
            # if num_labels == 1 the task is assumed to be a regression
            if classifier_hidden_size is None:  # no hidden layer
                self.classifier = nn.Linear(self.config.hidden_size, num_labels)
            else:
                self.classifier = nn.Sequential(
                    nn.Linear(self.config.hidden_size, classifier_hidden_size),
                    self.activation,
                    nn.Linear(classifier_hidden_size, num_labels),
                )

    def set_feature_extractor_trainable(self, trainable: bool):
        '''Sets the convolutional feature extractor of HuBERT-ECG as (un)trainable.'''
        self.hubert_ecg.feature_extractor.requires_grad_(trainable)

    def set_transformer_blocks_trainable(self, n_blocks: int):
        '''Makes only the last `n_blocks` of the HuBERT-ECG transformer encoder trainable.'''
        assert n_blocks >= 0, f"n_blocks ({n_blocks}) should be >= 0"
        assert n_blocks <= self.hubert_ecg.config.num_hidden_layers, f"n_blocks ({n_blocks}) should be <= {self.hubert_ecg.config.num_hidden_layers}"
        self.hubert_ecg.encoder.requires_grad_(False)
        for i in range(1, n_blocks + 1):
            self.hubert_ecg.encoder.layers[-i].requires_grad_(True)

    def get_logits(self, pooled_output: torch.Tensor):
        '''Computes the cosine similarity between the transformer pooled output (the input representation)
        and the label embedding look-up matrix (a dense representation of the labels).

        In:  pooled_output: (B, C) tensor
        Out: (B, num_labels) tensor of similarities/logits to be sigmoided and used in a BCE loss
        '''
        logits = torch.cosine_similarity(pooled_output.unsqueeze(1), self.label_embedding.weight.unsqueeze(0), dim=-1)
        return logits

    def forward(
        self,
        x: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, Union[Tuple, BaseModelOutput]]:
        return_dict = return_dict if return_dict is not None else self.hubert_ecg.config.use_return_dict
        output_hidden_states = True if self.hubert_ecg.config.use_weighted_layer_sum else output_hidden_states

        encodings = self.hubert_ecg(
            x,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )  # last_hidden_state: (B, T, D)

        x = encodings.last_hidden_state

        # mean-pool over time, ignoring padded frames when an attention mask is provided
        if attention_mask is None:
            x = x.mean(dim=1)  # (B, C)
        else:
            padding_mask = self.hubert_ecg._get_feature_vector_attention_mask(x.shape[1], attention_mask)
            x[~padding_mask] = 0.0
            x = x.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)

        x = self.classifier_dropout(x)

        # returns (logits, hubert_output): (B, num_labels) logits plus hidden states/attentions if requested
        output = (
            self.get_logits(x) if self.use_label_embedding else self.classifier(x),
            encodings,
        )
        return output
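

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It shows how the two
# classes above are wired together for a classification forward pass. The config
# values, the 5-label head, and the (batch, samples) input shape are illustrative
# assumptions, not values prescribed by the released checkpoints; in practice the
# backbone would be loaded from a pretrained checkpoint rather than random-initialised.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # assumed: HuBERTECGConfig subclasses HubertConfig, so unspecified fields fall
    # back to HubertConfig defaults (hidden_size=768, classifier_proj_size=256, ...)
    config = HuBERTECGConfig(vocab_sizes=[100], ensemble_length=1)

    backbone = HuBERTECG(config)
    model = HuBERTForECGClassification(backbone, num_labels=5)
    model.eval()

    dummy_ecg = torch.randn(2, 2500)  # (batch, samples): assumed flattened 1-D ECG input
    with torch.no_grad():
        logits, encodings = model(dummy_ecg)
    print(logits.shape)  # expected: torch.Size([2, 5])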