import os

import torch
import torch.nn as nn
from torch.nn import functional as F

from .mobilefacenet import MobileFaceNet
from .ir50 import Backbone
from .vit_model import VisionTransformer, PatchEmbed
from timm.layers import trunc_normal_, DropPath
from thop import profile


def load_pretrained_weights(model, checkpoint):
    """Load a checkpoint into ``model``, keeping only entries whose name and
    tensor shape match the model's own state dict."""
    import collections

    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint
    model_dict = model.state_dict()
    new_state_dict = collections.OrderedDict()
    matched_layers, discarded_layers = [], []
    for k, v in state_dict.items():
        # Strip the "module." prefix left over from DataParallel checkpoints.
        if k.startswith("module."):
            k = k[7:]
        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    model_dict.update(new_state_dict)

    model.load_state_dict(model_dict)
    print("load_weight", len(matched_layers))
    return model
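

# Illustrative usage sketch (an addition, not part of the original module).
# The checkpoint path below is hypothetical; load_pretrained_weights silently
# discards any entry whose name or shape does not match the target model.
def _demo_load_pretrained_weights():
    backbone = Backbone(50, 0.0, "ir")
    checkpoint = torch.load(
        "path/to/ir50.pth",  # hypothetical path to an IR-50 checkpoint
        map_location="cpu",
        weights_only=False,
    )
    return load_pretrained_weights(backbone, checkpoint)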


def window_partition(x, window_size, h_w, w_w):
    """
    Args:
        x: (B, H, W, C)
        window_size: window size
        h_w: number of windows along the height (H // window_size)
        w_w: number of windows along the width (W // window_size)

    Returns:
        local window features (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, h_w, window_size, w_w, window_size, C)
    windows = (
        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    )
    return windows
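

# Illustrative sketch (an addition, not part of the original module): split a
# (B, H, W, C) feature map into non-overlapping 7x7 windows. All sizes are
# example values.
def _demo_window_partition():
    x = torch.randn(2, 14, 14, 64)                  # (B, H, W, C)
    windows = window_partition(x, 7, h_w=2, w_w=2)  # h_w = H // 7, w_w = W // 7
    assert windows.shape == (2 * 2 * 2, 7, 7, 64)   # (num_windows * B, 7, 7, C)
    return windows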


class window(nn.Module):
    """Partition a channel-first feature map into non-overlapping windows."""

    def __init__(self, window_size, dim):
        super(window, self).__init__()
        self.window_size = window_size
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C)
        B, H, W, C = x.shape
        x = self.norm(x)
        shortcut = x
        h_w = int(torch.div(H, self.window_size).item())
        w_w = int(torch.div(W, self.window_size).item())
        x_windows = window_partition(x, self.window_size, h_w, w_w)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        return x_windows, shortcut


class WindowAttentionGlobal(nn.Module):
    """
    Global window attention based on: "Hatamizadeh et al.,
    Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
    """

    def __init__(
        self,
        dim,
        num_heads,
        window_size,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        """
        Args:
            dim: feature size dimension.
            num_heads: number of attention heads.
            window_size: window size.
            qkv_bias: bool argument for query, key, value learnable bias.
            qk_scale: optional override for the default query-key scaling factor.
            attn_drop: attention dropout rate.
            proj_drop: output dropout rate.
        """

        super().__init__()
        window_size = (window_size, window_size)
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = torch.div(dim, num_heads)
        self.scale = qk_scale or head_dim**-0.5
        # Relative position bias table and index, as in Swin Transformer.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        # Only keys and values are computed from the window tokens; the query
        # comes from the global features passed to forward().
        self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, q_global):
        B_, N, C = x.shape
        B = q_global.shape[0]
        head_dim = int(torch.div(C, self.num_heads).item())
        B_dim = int(torch.div(B_, B).item())
        kv = (
            self.qkv(x)
            .reshape(B_, N, 2, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv[0], kv[1]
        # Broadcast the global query to every window of each image.
        q_global = q_global.repeat(1, B_dim, 1, 1, 1)
        q = q_global.reshape(B_, self.num_heads, N, head_dim)
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            -1,
        )
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
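

# Illustrative sketch (an addition, not part of the original module): window
# tokens attend to an externally supplied global query. Shapes mirror how
# pyramid_trans_expr2 calls this module; the numbers are example values.
def _demo_window_attention_global():
    attn = WindowAttentionGlobal(dim=64, num_heads=2, window_size=7)
    x = torch.randn(8, 49, 64)               # (num_windows * B, window_size**2, dim)
    q_global = torch.randn(2, 1, 2, 49, 32)  # (B, 1, num_heads, window_size**2, head_dim)
    out = attn(x, q_global)
    assert out.shape == x.shape
    return out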


def _to_channel_last(x):
    """
    Args:
        x: (B, C, H, W)

    Returns:
        x: (B, H, W, C)
    """
    return x.permute(0, 2, 3, 1)


def _to_channel_first(x):
    """
    Args:
        x: (B, H, W, C)

    Returns:
        x: (B, C, H, W)
    """
    return x.permute(0, 3, 1, 2)


def _to_query(x, N, num_heads, dim_head):
    """Reshape a (B, H, W, C) feature map into a global query of shape
    (B, 1, num_heads, N, dim_head), where N = H * W."""
    B = x.shape[0]
    x = x.reshape(B, 1, N, num_heads, dim_head).permute(0, 1, 3, 2, 4)
    return x


class Mlp(nn.Module):
    """
    Multi-Layer Perceptron (MLP) block
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        """
        Args:
            in_features: input features dimension.
            hidden_features: hidden features dimension.
            out_features: output features dimension.
            act_layer: activation function.
            drop: dropout rate.
        """

        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_reverse(windows, window_size, H, W, h_w, w_w):
    """
    Args:
        windows: local window features (num_windows*B, window_size, window_size, C)
        window_size: window size
        H: height of image
        W: width of image
        h_w: number of windows along the height (H // window_size)
        w_w: number of windows along the width (W // window_size)

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, h_w, w_w, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
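

# Illustrative sketch (an addition, not part of the original module):
# window_reverse undoes window_partition, so a partition/reverse round trip
# reproduces the input exactly. Sizes are example values.
def _demo_window_roundtrip():
    x = torch.randn(2, 14, 14, 64)
    windows = window_partition(x, 7, h_w=2, w_w=2)
    restored = window_reverse(windows, 7, H=14, W=14, h_w=2, w_w=2)
    assert torch.equal(restored, x)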


class feedforward(nn.Module):
    """Merge attended windows back into a feature map and apply a residual MLP."""

    def __init__(
        self,
        dim,
        window_size,
        mlp_ratio=4.0,
        act_layer=nn.GELU,
        drop=0.0,
        drop_path=0.0,
        layer_scale=None,
    ):
        super(feedforward, self).__init__()
        if layer_scale is not None and type(layer_scale) in [int, float]:
            self.layer_scale = True
            self.gamma1 = nn.Parameter(
                layer_scale * torch.ones(dim), requires_grad=True
            )
            self.gamma2 = nn.Parameter(
                layer_scale * torch.ones(dim), requires_grad=True
            )
        else:
            self.gamma1 = 1.0
            self.gamma2 = 1.0
        self.window_size = window_size
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.norm = nn.LayerNorm(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, attn_windows, shortcut):
        B, H, W, C = shortcut.shape
        h_w = int(torch.div(H, self.window_size).item())
        w_w = int(torch.div(W, self.window_size).item())
        x = window_reverse(attn_windows, self.window_size, H, W, h_w, w_w)
        x = shortcut + self.drop_path(self.gamma1 * x)
        x = x + self.drop_path(self.gamma2 * self.mlp(self.norm(x)))
        return x
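

# Illustrative sketch (an addition, not part of the original module) of one
# fusion stage as wired up in pyramid_trans_expr2: a feature map is split into
# windows, attended with a global query, then merged back through the
# feedforward block. All sizes are example values (dim=64, 2 heads, 7x7
# windows on a 14x14 map).
def _demo_fusion_stage():
    win = window(window_size=7, dim=64)
    attn = WindowAttentionGlobal(dim=64, num_heads=2, window_size=7)
    ffn = feedforward(dim=64, window_size=7, layer_scale=1e-5)
    feat = torch.randn(2, 64, 14, 14)  # (B, C, H, W) backbone feature
    q = _to_query(torch.randn(2, 7, 7, 64), N=49, num_heads=2, dim_head=32)
    x_windows, shortcut = win(feat)    # (8, 49, 64) and (2, 14, 14, 64)
    out = ffn(attn(x_windows, q), shortcut)
    assert out.shape == (2, 14, 14, 64)
    return out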


class pyramid_trans_expr2(nn.Module):
    def __init__(
        self,
        img_size=224,
        num_classes=7,
        window_size=[28, 14, 7],
        num_heads=[2, 4, 8],
        dims=[64, 128, 256],
        embed_dim=768,
    ):
        super().__init__()

        self.img_size = img_size
        self.num_heads = num_heads
        self.dim_head = []
        for num_head, dim in zip(num_heads, dims):
            self.dim_head.append(int(torch.div(dim, num_head).item()))
        self.num_classes = num_classes
        self.window_size = window_size
        self.N = [win * win for win in window_size]

        # Landmark branch: a MobileFaceNet pretrained for 136-point facial
        # landmark detection, kept frozen during training.
        self.face_landback = MobileFaceNet([112, 112], 136)

        script_dir = os.path.dirname(os.path.abspath(__file__))
        mobilefacenet_path = os.path.join(
            script_dir, "pretrain", "mobilefacenet_model_best.pth.tar"
        )
        ir50_path = os.path.join(script_dir, "pretrain", "ir50.pth")

        print(mobilefacenet_path)
        face_landback_checkpoint = torch.load(
            mobilefacenet_path,
            map_location=lambda storage, loc: storage,
            weights_only=False,
        )
        self.face_landback.load_state_dict(face_landback_checkpoint["state_dict"])

        for param in self.face_landback.parameters():
            param.requires_grad = False

        self.VIT = VisionTransformer(depth=2, embed_dim=embed_dim)

        # Image branch: an IR-50 backbone initialized from pretrained weights.
        self.ir_back = Backbone(50, 0.0, "ir")
        ir_checkpoint = torch.load(
            ir50_path, map_location=lambda storage, loc: storage, weights_only=False
        )
        self.ir_back = load_pretrained_weights(self.ir_back, ir_checkpoint)

        # One window-attention / feedforward fusion stage per pyramid level.
        self.attn1 = WindowAttentionGlobal(
            dim=dims[0], num_heads=num_heads[0], window_size=window_size[0]
        )
        self.attn2 = WindowAttentionGlobal(
            dim=dims[1], num_heads=num_heads[1], window_size=window_size[1]
        )
        self.attn3 = WindowAttentionGlobal(
            dim=dims[2], num_heads=num_heads[2], window_size=window_size[2]
        )
        self.window1 = window(window_size=window_size[0], dim=dims[0])
        self.window2 = window(window_size=window_size[1], dim=dims[1])
        self.window3 = window(window_size=window_size[2], dim=dims[2])
        self.conv1 = nn.Conv2d(
            in_channels=dims[0],
            out_channels=dims[0],
            kernel_size=3,
            stride=2,
            padding=1,
        )
        self.conv2 = nn.Conv2d(
            in_channels=dims[1],
            out_channels=dims[1],
            kernel_size=3,
            stride=2,
            padding=1,
        )
        self.conv3 = nn.Conv2d(
            in_channels=dims[2],
            out_channels=dims[2],
            kernel_size=3,
            stride=2,
            padding=1,
        )

        # Stochastic depth (drop path) schedule; the first three rates are used.
        dpr = [x.item() for x in torch.linspace(0, 0.5, 5)]
        self.ffn1 = feedforward(
            dim=dims[0], window_size=window_size[0], layer_scale=1e-5, drop_path=dpr[0]
        )
        self.ffn2 = feedforward(
            dim=dims[1], window_size=window_size[1], layer_scale=1e-5, drop_path=dpr[1]
        )
        self.ffn3 = feedforward(
            dim=dims[2], window_size=window_size[2], layer_scale=1e-5, drop_path=dpr[2]
        )

        self.last_face_conv = nn.Conv2d(
            in_channels=512, out_channels=256, kernel_size=3, padding=1
        )

        # Project each fused level to the ViT embedding dimension.
        self.embed_q = nn.Sequential(
            nn.Conv2d(dims[0], 768, kernel_size=3, stride=2, padding=1),
            nn.Conv2d(768, 768, kernel_size=3, stride=2, padding=1),
        )
        self.embed_k = nn.Sequential(
            nn.Conv2d(dims[1], 768, kernel_size=3, stride=2, padding=1)
        )
        self.embed_v = PatchEmbed(img_size=14, patch_size=14, in_c=256, embed_dim=768)

    def forward(self, x):
        # Landmark branch: run the frozen MobileFaceNet on a 112x112 resize.
        x_face = F.interpolate(x, size=112)
        x_face1, x_face2, x_face3 = self.face_landback(x_face)
        x_face3 = self.last_face_conv(x_face3)
        x_face1, x_face2, x_face3 = (
            _to_channel_last(x_face1),
            _to_channel_last(x_face2),
            _to_channel_last(x_face3),
        )

        # Turn the landmark features into one global query per pyramid level.
        q1, q2, q3 = (
            _to_query(x_face1, self.N[0], self.num_heads[0], self.dim_head[0]),
            _to_query(x_face2, self.N[1], self.num_heads[1], self.dim_head[1]),
            _to_query(x_face3, self.N[2], self.num_heads[2], self.dim_head[2]),
        )

        # Image branch: multi-scale IR-50 features.
        x_ir1, x_ir2, x_ir3 = self.ir_back(x)
        x_ir1, x_ir2, x_ir3 = self.conv1(x_ir1), self.conv2(x_ir2), self.conv3(x_ir3)
        x_window1, shortcut1 = self.window1(x_ir1)
        x_window2, shortcut2 = self.window2(x_ir2)
        x_window3, shortcut3 = self.window3(x_ir3)

        # Fuse image windows with landmark queries at each level.
        o1, o2, o3 = (
            self.attn1(x_window1, q1),
            self.attn2(x_window2, q2),
            self.attn3(x_window3, q3),
        )

        o1, o2, o3 = (
            self.ffn1(o1, shortcut1),
            self.ffn2(o2, shortcut2),
            self.ffn3(o3, shortcut3),
        )

        o1, o2, o3 = _to_channel_first(o1), _to_channel_first(o2), _to_channel_first(o3)

        # Project each level to token sequences and classify with the ViT head.
        o1, o2, o3 = (
            self.embed_q(o1).flatten(2).transpose(1, 2),
            self.embed_k(o2).flatten(2).transpose(1, 2),
            self.embed_v(o3),
        )

        o = torch.cat([o1, o2, o3], dim=1)

        out = self.VIT(o)
        return out
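

# Illustrative usage sketch (an addition, not part of the original module).
# It assumes the pretrained MobileFaceNet and IR-50 checkpoints exist under a
# pretrain/ directory next to this file, since the constructor loads them.
def _demo_pyramid_trans_expr2():
    model = pyramid_trans_expr2(img_size=224, num_classes=7)
    model.eval()
    images = torch.rand(2, 3, 224, 224)
    with torch.no_grad():
        logits = model(images)
    return logits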


def compute_param_flop():
    model = pyramid_trans_expr2()
    img = torch.rand(size=(1, 3, 224, 224))
    flops, params = profile(model, inputs=(img,))
    print(f"flops: {flops / 1000 ** 3} G, params: {params / 1000 ** 2} M")