import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, List, NamedTuple, Optional

import torch
import torch.nn as nn

try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

try:
    from torchvision.utils import _log_api_usage_once
except ImportError:
    def _log_api_usage_once(obj: Any) -> None:
        # Fallback no-op when torchvision is not available.
        pass


model_urls = {
    "vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth",
    "vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
    "vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
    "vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth",
}


class MLPBlock(nn.Sequential):
    """Transformer MLP block."""

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__()
        self.linear_1 = nn.Linear(in_dim, mlp_dim)
        self.act = nn.GELU()
        self.dropout_1 = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(mlp_dim, in_dim)
        self.dropout_2 = nn.Dropout(dropout)

        nn.init.xavier_uniform_(self.linear_1.weight)
        nn.init.xavier_uniform_(self.linear_2.weight)
        nn.init.normal_(self.linear_1.bias, std=1e-6)
        nn.init.normal_(self.linear_2.bias, std=1e-6)


class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        # batch_first=True, so the expected layout is (batch_size, seq_length, hidden_dim)
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y


class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        # Note that batch_size is on the first dim because
        # we use batch_first=True in nn.MultiheadAttention()
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT
        self.dropout = nn.Dropout(dropout)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        for i in range(num_layers):
            layers[f"encoder_layer_{i}"] = EncoderBlock(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout,
                attention_dropout,
                norm_layer,
            )
        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        input = input + self.pos_embedding
        return self.ln(self.layers(self.dropout(input)))


class FeatureTransformer(nn.Module):
    """Feature Transformer: a ViT-style Transformer encoder over a sequence of feature vectors."""

    def __init__(
        self,
        seq_length: int = 16,
        num_layers: int = 2,
        num_heads: int = 4,
        hidden_dim: int = 768,
        mlp_dim: int = 768,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer

        # Add a class token
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        seq_length += 1

        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
        )
        self.seq_length = seq_length

        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)

        self.heads = nn.Sequential(heads_layers)

        if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
            fan_in = self.heads.pre_logits.in_features
            nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(self.heads.pre_logits.bias)

        if isinstance(self.heads.head, nn.Linear):
            nn.init.zeros_(self.heads.head.weight)
            nn.init.zeros_(self.heads.head.bias)

    def forward(self, x: torch.Tensor):
        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(x.shape[0], -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]
        x = self.heads(x)

        return x
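

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): builds a small
# FeatureTransformer with the default hyperparameters and runs a forward pass
# on random feature vectors. The batch size of 8 is an arbitrary illustration.
if __name__ == "__main__":
    model = FeatureTransformer(
        seq_length=16,  # number of input feature tokens; the class token is added internally
        num_layers=2,
        num_heads=4,
        hidden_dim=768,
        mlp_dim=768,
        num_classes=1,
    )
    # Input is (batch_size, seq_length, hidden_dim). Unlike a full ViT there is
    # no patch embedding here, so features must already have size hidden_dim.
    features = torch.randn(8, 16, 768)
    logits = model(features)
    print(logits.shape)  # torch.Size([8, 1])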