import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, List, NamedTuple, Optional

import torch
import torch.nn as nn

try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

model_urls = {
    "vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth",
    "vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
    "vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
    "vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth",
}
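
# Illustrative sketch, not part of the original module: the checkpoints listed in
# `model_urls` can be fetched with the `load_state_dict_from_url` imported above.
# The helper name below is hypothetical; the call returns a plain state_dict whose
# keys the caller must map onto whatever model is being initialized.
def _example_download_vit_checkpoint(arch: str = "vit_b_16"):
    """Download (and cache) one of the reference ViT checkpoints."""
    return load_state_dict_from_url(model_urls[arch], progress=True)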

class MLPBlock(nn.Sequential):
    """Transformer MLP block."""

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__()
        self.linear_1 = nn.Linear(in_dim, mlp_dim)
        self.act = nn.GELU()
        self.dropout_1 = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(mlp_dim, in_dim)
        self.dropout_2 = nn.Dropout(dropout)

        nn.init.xavier_uniform_(self.linear_1.weight)
        nn.init.xavier_uniform_(self.linear_2.weight)
        nn.init.normal_(self.linear_1.bias, std=1e-6)
        nn.init.normal_(self.linear_2.bias, std=1e-6)

class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        # batch_first=True, so the expected layout is (batch_size, seq_length, hidden_dim)
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y

class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        # Note that batch_size is on the first dim because
        # we use batch_first=True in nn.MultiheadAttention()
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT
        self.dropout = nn.Dropout(dropout)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        for i in range(num_layers):
            layers[f"encoder_layer_{i}"] = EncoderBlock(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout,
                attention_dropout,
                norm_layer,
            )
        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        input = input + self.pos_embedding
        return self.ln(self.layers(self.dropout(input)))

class FeatureTransformer(nn.Module):
    """
    Feature Transformer: a Transformer encoder (with class token and
    classification head) applied directly to a sequence of input feature vectors.
    """

    def __init__(
        self,
        seq_length: int = 16,
        num_layers: int = 2,
        num_heads: int = 4,
        hidden_dim: int = 768,
        mlp_dim: int = 768,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ) -> None:
        super().__init__()
        # _log_api_usage_once(self)
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer

        # Add a class token
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        seq_length += 1

        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
        )
        self.seq_length = seq_length

        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)
        self.heads = nn.Sequential(heads_layers)

        if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
            fan_in = self.heads.pre_logits.in_features
            nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(self.heads.pre_logits.bias)

        if isinstance(self.heads.head, nn.Linear):
            nn.init.zeros_(self.heads.head.weight)
            nn.init.zeros_(self.heads.head.bias)

    def forward(self, x: torch.Tensor):
        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(x.shape[0], -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]
        x = self.heads(x)
        return x
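
# Minimal usage sketch (added for illustration; the batch size and shapes below are
# an assumption based on the constructor defaults, not part of the original module).
if __name__ == "__main__":
    model = FeatureTransformer()  # seq_length=16, hidden_dim=768, num_classes=1
    features = torch.randn(8, 16, 768)  # (batch_size, seq_length, hidden_dim)
    logits = model(features)
    print(logits.shape)  # torch.Size([8, 1])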