import logging
from collections import OrderedDict
from typing import Optional

import torch
from torch import nn

logger = logging.getLogger(__name__)


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16 inputs."""

    def forward(self, x: torch.Tensor):
        if self.weight.dtype != x.dtype:
            # Run the normalization in the parameter dtype, then cast the
            # result back to the input dtype.
            orig_type = x.dtype
            ret = super().forward(x.type(self.weight.dtype))
            return ret.type(orig_type)
        return super().forward(x)


class QuickGELU(nn.Module):
    """Sigmoid-based approximation of GELU: x * sigmoid(1.702 * x)."""

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_head: int,
        attn_mask: Optional[torch.Tensor] = None,
    ):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, d_model * 4)),
                    ("gelu", QuickGELU()),
                    ("c_proj", nn.Linear(d_model * 4, d_model)),
                ]
            )
        )
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        # Keep the mask on the same dtype/device as the activations.
        self.attn_mask = (
            self.attn_mask.to(dtype=x.dtype, device=x.device)
            if self.attn_mask is not None
            else None
        )
        return self.attn(
            x,
            x,
            x,
            need_weights=False,
            attn_mask=self.attn_mask,
        )[0]

    def forward(self, x: torch.Tensor):
        # Pre-norm residual connections around attention and the MLP.
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        attn_mask: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(
            *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
        )

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class ConditionalViT(nn.Module):
    def __init__(
        self,
        input_resolution: int,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        output_dim: int,
    ):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        scale = width**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        # Positional embedding for the optional conditioning token.
        self.c_pos_embedding = nn.Parameter(scale * torch.randn(1, width))
        # Positional embeddings for the class token and the image patches.
        self.positional_embedding = nn.Parameter(
            scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)
        )
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)
        self.ln_post = LayerNorm(width)
        # Learnable logit scale, initialised to ln(100) ~= 4.6052.
        self.logit_scale = torch.nn.Parameter(torch.ones([]) * 4.6052)

        self.proj = nn.Linear(width, output_dim, bias=False)

    def forward(self, imgs: torch.Tensor, c: Optional[torch.Tensor] = None):
        """
        imgs: batch of images, shape (batch, 3, input_resolution, input_resolution).
        c: optional text embedding of shape (batch, width), appended as an extra token.
        """
        # Patchify: (batch, width, grid, grid) -> (batch, grid**2, width).
        x = self.conv1(imgs)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)

        # Prepend the class token; optionally append the conditioning token.
        tokens = [self.class_embedding.tile(x.shape[0], 1, 1), x]
        pos_embed = [self.positional_embedding]

        if c is not None:
            pos_embed += [self.c_pos_embedding]
            tokens += [c.unsqueeze(1)]

        x = torch.cat(tokens, dim=1)
        pos_embed = torch.cat(pos_embed, dim=0).unsqueeze(0)

        x = x + pos_embed
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # (batch, seq, width) -> (seq, batch, width)
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # (seq, batch, width) -> (batch, seq, width)

        # Project the class-token representation to the output dimension.
        x = self.ln_post(x[:, 0, :])
        x = self.proj(x)

        return x
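
# Minimal usage sketch (illustrative only): the hyper-parameters and tensor
# shapes below are assumptions for demonstration, not values defined by this
# module. It runs the encoder once without and once with a conditioning
# embedding; both calls return a (batch, output_dim) tensor.
if __name__ == "__main__":
    model = ConditionalViT(
        input_resolution=224,
        patch_size=32,
        width=768,
        layers=12,
        heads=12,
        output_dim=512,
    )
    imgs = torch.randn(2, 3, 224, 224)  # dummy batch of two RGB images
    c = torch.randn(2, 768)  # dummy text embeddings, one per image, of size `width`
    out_unconditional = model(imgs)  # shape (2, 512)
    out_conditional = model(imgs, c)  # shape (2, 512)
    print(out_unconditional.shape, out_conditional.shape)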