import math
from collections import OrderedDict
from typing import List, Optional, Tuple, cast

import attr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .attention import (
    AttentionInfo,
    DenseAttentionMask,
    DenseCausalAttentionMask,
    make_full_layout,
    to_attention_info,
)
from .utils import Affine, LayerNorm, zero_key_bias_grad

# Constants used in the original CLIP implementation.
image_channel_means = [122.77093945, 116.74601272, 104.09373519]
image_channel_stds = [68.50053285, 66.63215831, 70.32316309]

@attr.s(eq=False, repr=False)
class TextEmbedding(nn.Module):
    n_vocab: int = attr.ib()
    n_context: int = attr.ib()
    n_state: int = attr.ib()
    device: torch.device = attr.ib(default=torch.device("cuda"))

    def __attrs_post_init__(self) -> None:
        super().__init__()

        w_voc = torch.empty((self.n_vocab, self.n_state), dtype=torch.float32, device=self.device)
        w_pos = torch.empty((self.n_context, self.n_state), dtype=torch.float32, device=self.device)

        with torch.no_grad():
            w_voc.normal_(std=0.02)
            w_pos.normal_(std=0.01)

        self.w_voc = nn.Parameter(w_voc)
        self.w_pos = nn.Parameter(w_pos)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if len(x.shape) != 2:
            raise ValueError("expected token tensor to be 2d")

        return F.embedding(x, self.w_voc) + self.w_pos[None, :, :]

@attr.s(eq=False, repr=False)
class ImageEmbedding(nn.Module):
    image_size: int = attr.ib()
    patch_size: int = attr.ib()
    n_state: int = attr.ib()
    n_timestep: int = attr.ib(default=0)
    device: torch.device = attr.ib(default=torch.device("cuda"))

    def __attrs_post_init__(self) -> None:
        super().__init__()

        if self.image_size % self.patch_size != 0:
            raise ValueError("image_size must be divisible by patch_size")

        n_patch = self.image_size // self.patch_size
        patch_proj = torch.empty(
            (self.n_state, 3) + 2 * (self.patch_size,), dtype=torch.float32, device=self.device
        )
        w_pos = torch.empty(
            (1 + n_patch ** 2, self.n_state), dtype=torch.float32, device=self.device
        )

        with torch.no_grad():
            if self.n_timestep == 0:
                pred_state = torch.empty((self.n_state,), dtype=torch.float32, device=self.device)
                pred_state.normal_(std=1 / np.sqrt(self.n_state))
                self.pred_state = nn.Parameter(pred_state)
            else:
                w_t = torch.empty(
                    (self.n_timestep, self.n_state), dtype=torch.float32, device=self.device
                )
                w_t.normal_(std=1 / np.sqrt(self.n_state))
                self.w_t = nn.Parameter(w_t)

            patch_proj.normal_(std=np.sqrt(2 / (self.n_state * self.patch_size ** 2)))
            w_pos.normal_(std=1 / np.sqrt(self.n_state))

        self.patch_proj = nn.Parameter(patch_proj)
        self.w_pos = nn.Parameter(w_pos)

        self.channel_means = torch.tensor(
            image_channel_means, dtype=torch.float32, device=self.device
        )[None, :, None, None]
        self.channel_stds = torch.tensor(
            image_channel_stds, dtype=torch.float32, device=self.device
        )[None, :, None, None]
        self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device)

    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None) -> torch.Tensor:
        if len(x.shape) != 4:
            raise ValueError("input should be 4d")
        if x.shape[1] != 3:
            raise ValueError("input should have 3 channels")
        if not (x.shape[2] == self.image_size and x.shape[3] == self.image_size):
            raise ValueError(f"input is not {self.image_size} x {self.image_size}")

        if (self.n_timestep == 0 and t is not None) or (self.n_timestep != 0 and t is None):
            raise ValueError("timesteps must be given exactly when n_timestep != 0")
        if self.n_timestep != 0:
            assert t is not None
            if len(t.shape) != 1:
                raise ValueError("expected timesteps to be 1d")
            if t.shape[0] != x.shape[0]:
                raise ValueError("timesteps and images have inconsistent batch dimensions")

        x = (x - self.channel_means) / self.channel_stds
        x = F.conv2d(x, self.patch_proj, stride=self.patch_size)
        x = x.reshape(x.shape[0], self.n_state, (self.image_size // self.patch_size) ** 2).permute(
            0, 2, 1
        )

        # Prepend either a learned start token (pred_state) or a timestep embedding.
        sot = (
            self.pred_state[None, None].expand(x.shape[0], -1, -1)
            if self.n_timestep == 0
            else F.embedding(cast(torch.Tensor, t), self.w_t)[:, None]
        )
        x = torch.cat((sot, x), dim=1) + self.w_pos[None]
        return self.ln(x)

@attr.s(eq=False, repr=False)
class AttentionResblock(nn.Module):
    n_state: int = attr.ib()
    n_resblocks: int = attr.ib()
    attn_fn: AttentionInfo = attr.ib()
    device: torch.device = attr.ib(default=torch.device("cuda"))

    def __attrs_post_init__(self) -> None:
        super().__init__()

        self.n_head_state = self.n_state // self.attn_fn.n_heads
        self.qk_scale = 1 / np.sqrt(self.n_head_state)

        self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device)
        self.f_q = Affine(
            self.n_state,
            self.n_state,
            std=1 / math.sqrt(self.n_state),
            use_bias=True,
            bias_filter_fn=zero_key_bias_grad,
            device=self.device,
        )
        self.f_k = Affine(
            self.n_state,
            self.n_state,
            std=1 / math.sqrt(self.n_state),
            use_bias=False,
            bias_filter_fn=zero_key_bias_grad,
            device=self.device,
        )
        self.f_v = Affine(
            self.n_state,
            self.n_state,
            std=1 / math.sqrt(self.n_state),
            use_bias=True,
            bias_filter_fn=zero_key_bias_grad,
            device=self.device,
        )
        self.f_c = Affine(
            self.n_state,
            self.n_state,
            use_bias=True,
            std=1 / np.sqrt(self.n_state * self.n_resblocks ** 2),
            device=self.device,
        )  # XXX

    def forward(self, m: torch.Tensor) -> torch.Tensor:
        n_context = m.shape[1]

        # Pad the sequence up to a multiple of the attention block size.
        n_query_pad = self.attn_fn.ctx_blks_q * self.attn_fn.block_size - n_context
        n_key_pad = self.attn_fn.ctx_blks_k * self.attn_fn.block_size - n_context
        assert n_query_pad >= 0
        assert n_key_pad >= 0

        r = m
        r = self.ln(r)
        q, k, v = self.f_q(r), self.f_k(r), self.f_v(r)

        if n_query_pad != 0:
            q = F.pad(q, (0, 0, 0, n_query_pad))
        if n_key_pad != 0:
            k = F.pad(k, (0, 0, 0, n_key_pad))
            v = F.pad(v, (0, 0, 0, n_key_pad))

        q = q.view([q.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3))
        k = k.view([k.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3))
        v = v.view([v.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3))

        # Apply the 1/sqrt(n_head_state) scaling, split evenly between q and k.
        w = torch.einsum(
            "bhcd,bhkd->bhck", q * math.sqrt(self.qk_scale), k * math.sqrt(self.qk_scale)
        )

        if hasattr(self.attn_fn, "pytorch_attn_bias"):
            bias = self.attn_fn.pytorch_attn_bias
            assert len(bias.shape) in {2, 3}

            if len(bias.shape) == 2:
                w = torch.softmax(w + self.attn_fn.pytorch_attn_bias[None, None], dim=-1)
            elif len(bias.shape) == 3:
                w = torch.softmax(w + self.attn_fn.pytorch_attn_bias[None], dim=-1)
        else:
            w = torch.softmax(w, dim=-1)

        r = torch.einsum("bhck,bhkd->bhcd", w, v)
        r = r.permute((0, 2, 1, 3)).reshape((r.shape[0], -1, self.n_state))

        # Drop the padded query positions again before the output projection.
        if n_query_pad != 0:
            r = r[:, :-n_query_pad]

        assert r.shape[1] == n_context

        r = self.f_c(r)
        return m + r

@attr.s(eq=False, repr=False)
class FullyConnectedResblock(nn.Module):
    """
    Not imported from other files because we retain Alec's original inits.
    """

    n_state: int = attr.ib()
    n_resblocks: int = attr.ib()
    device: torch.device = attr.ib(default=torch.device("cuda"))

    def __attrs_post_init__(self) -> None:
        super().__init__()

        self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device)
        self.f_1 = Affine(
            self.n_state,
            4 * self.n_state,
            use_bias=True,
            std=np.sqrt(2 / (4 * self.n_state)),
            device=self.device,
        )
        self.f_2 = Affine(
            4 * self.n_state,
            self.n_state,
            use_bias=True,
            std=1 / np.sqrt(self.n_state * self.n_resblocks ** 2),
            device=self.device,
        )  # XXX

    def forward(self, m: torch.Tensor) -> torch.Tensor:
        r = m
        r = self.ln(r)
        r = self.f_2(F.gelu(self.f_1(r)))
        return m + r

@attr.s(eq=False, repr=False)
class TransformerBlock(nn.Module):
    n_state: int = attr.ib()
    n_resblocks: int = attr.ib()
    attn_fn: AttentionInfo = attr.ib()
    device: torch.device = attr.ib(default=torch.device("cuda"))

    def __attrs_post_init__(self) -> None:
        super().__init__()

        self.f_attn = AttentionResblock(
            self.n_state,
            self.n_resblocks,
            self.attn_fn,
            self.device,
        )
        self.f_mlp = FullyConnectedResblock(self.n_state, self.n_resblocks, self.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.f_mlp(self.f_attn(x))

@attr.s(eq=False, repr=False)
class TextFeatureExtractor(nn.Module):
    n_state: int = attr.ib()
    n_embd: int = attr.ib()
    device: torch.device = attr.ib(default=torch.device("cuda"))

    def __attrs_post_init__(self) -> None:
        super().__init__()

        self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device)
        self.f = Affine(self.n_state, self.n_embd, use_bias=False, device=self.device)

    def forward(
        self, text: torch.Tensor, text_len: torch.Tensor, return_probe_features: bool = False
    ) -> torch.Tensor:
        if len(text.shape) != 3:
            raise ValueError("expected text to be 3d")
        if len(text_len.shape) != 1:
            raise ValueError("expected text length to be 1d")
        if text.shape[0] != text_len.shape[0]:
            raise ValueError("text and text_len have inconsistent batch dimensions")

        # Select the hidden state at the last token of each sequence.
        index = (text_len - 1)[:, None, None].expand(-1, 1, text.shape[2])
        x = torch.gather(text, dim=1, index=index)
        assert list(x.shape) == [text.shape[0], 1, text.shape[2]]

        if return_probe_features:
            return x[:, 0]

        x = self.ln(x)
        return self.f(x[:, 0])

@attr.s(eq=False, repr=False)
class ImageFeatureExtractor(nn.Module):
    n_state: int = attr.ib()
    n_embd: int = attr.ib()
    device: torch.device = attr.ib(default=torch.device("cuda"))

    def __attrs_post_init__(self) -> None:
        super().__init__()

        self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device)
        self.f = Affine(self.n_state, self.n_embd, use_bias=False, device=self.device)

    def forward(self, x: torch.Tensor, return_probe_features: bool = False) -> torch.Tensor:
        if return_probe_features:
            return x[:, 0]

        x = self.ln(x[:, :1])
        return self.f(x[:, 0])

@attr.s(eq=False, repr=False)
class TextEncoder(nn.Module):
    n_bpe_vocab: int = attr.ib()
    max_text_len: int = attr.ib()
    n_embd: int = attr.ib()
    n_head: int = attr.ib()
    n_xf_blocks: int = attr.ib()
    n_head_state: int = attr.ib(default=64)
    device: torch.device = attr.ib(default=torch.device("cuda"))
    block_size: int = attr.ib(init=False, default=32)

    def __attrs_post_init__(self) -> None:
        super().__init__()

        self.n_state = self.n_head * self.n_head_state
        n_rounded_context = self.block_size * int(math.ceil(self.max_text_len / self.block_size))
        n_pad = n_rounded_context - self.max_text_len

        args = (
            n_rounded_context,
            n_rounded_context,
            self.block_size,
            self.n_head,
            False,
            n_pad,
            n_pad,
        )
        mask = DenseCausalAttentionMask(*args)
        attn_fn = to_attention_info(mask)

        # Convert the boolean mask layout into an additive bias: masked positions get -1e10.
        m = 1 - make_full_layout(mask).astype(np.float32)
        m[m == 1] = -1e10
        attn_fn.pytorch_attn_bias = torch.from_numpy(m).to(self.device)

        blocks: List[Tuple[str, nn.Module]] = [
            (
                "input",
                TextEmbedding(
                    self.n_bpe_vocab, self.max_text_len, self.n_state, device=self.device
                ),
            )
        ]

        for i in range(self.n_xf_blocks):
            blocks.append(
                (
                    f"block_{i}",
                    TransformerBlock(self.n_state, 2 * self.n_xf_blocks, attn_fn, self.device),
                )
            )

        blocks.append(
            ("output", TextFeatureExtractor(self.n_state, self.n_embd, device=self.device))
        )

        self.blocks = nn.ModuleDict(OrderedDict(blocks))

    def forward(
        self,
        text: torch.Tensor,
        text_len: torch.Tensor,
        return_probe_features: bool = False,
    ) -> torch.Tensor:
        n_batch = text.shape[0]
        h = self.blocks["input"](text)

        for i in range(self.n_xf_blocks):
            h = self.blocks[f"block_{i}"](h)

        h = self.blocks["output"](h, text_len, return_probe_features=return_probe_features)

        assert list(h.shape) == [
            n_batch,
            self.n_embd if not return_probe_features else self.n_state,
        ]
        return h

@attr.s(eq=False, repr=False)
class ImageEncoder(nn.Module):
    image_size: int = attr.ib()
    patch_size: int = attr.ib()
    n_embd: int = attr.ib()
    n_head: int = attr.ib()
    n_xf_blocks: int = attr.ib()
    n_head_state: int = attr.ib(default=64)
    n_timestep: int = attr.ib(default=0)
    device: torch.device = attr.ib(default=torch.device("cuda"))
    block_size: int = attr.ib(init=False, default=32)

    def __attrs_post_init__(self) -> None:
        super().__init__()

        self.n_state = self.n_head * self.n_head_state
        self.n_context = 1 + (self.image_size // self.patch_size) ** 2
        n_rounded_context = self.block_size * int(math.ceil(self.n_context / self.block_size))
        n_pad = n_rounded_context - self.n_context

        args = (
            n_rounded_context,
            n_rounded_context,
            self.block_size,
            self.n_head,
            False,
            n_pad,
            n_pad,
        )
        mask = DenseAttentionMask(*args)
        attn_fn = to_attention_info(mask)

        # Convert the boolean mask layout into an additive bias: masked positions get -1e10.
        m = 1 - make_full_layout(mask).astype(np.float32)
        m[m == 1] = -1e10
        attn_fn.pytorch_attn_bias = torch.from_numpy(m).to(self.device)

        blocks: List[Tuple[str, nn.Module]] = [
            (
                "input",
                ImageEmbedding(
                    self.image_size,
                    self.patch_size,
                    self.n_state,
                    n_timestep=self.n_timestep,
                    device=self.device,
                ),
            )
        ]

        for i in range(self.n_xf_blocks):
            blocks.append(
                (
                    f"block_{i}",
                    TransformerBlock(self.n_state, 2 * self.n_xf_blocks, attn_fn, self.device),
                )
            )

        blocks.append(("output", ImageFeatureExtractor(self.n_state, self.n_embd, self.device)))

        self.blocks = nn.ModuleDict(OrderedDict(blocks))

    def forward(
        self,
        image: torch.Tensor,
        timesteps: Optional[torch.Tensor] = None,
        return_probe_features: bool = False,
    ) -> torch.Tensor:
        n_batch = image.shape[0]
        h = self.blocks["input"](image, t=timesteps)

        for i in range(self.n_xf_blocks):
            h = self.blocks[f"block_{i}"](h)

        h = self.blocks["output"](h, return_probe_features=return_probe_features)

        assert list(h.shape) == [
            n_batch,
            self.n_embd if not return_probe_features else self.n_state,
        ]
        return h
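

# Illustrative usage sketch (not part of the original module): builds small text
# and image encoders on CPU and runs one forward pass. The hyperparameters below
# are arbitrary toy values, not the configuration of the released CLIP models,
# and the sketch assumes this file is run as part of its package (for example
# via `python -m <package>.<module>`) so the relative imports resolve.
if __name__ == "__main__":
    device = torch.device("cpu")

    text_enc = TextEncoder(
        n_bpe_vocab=1000, max_text_len=76, n_embd=64, n_head=4, n_xf_blocks=2, device=device
    )
    image_enc = ImageEncoder(
        image_size=64, patch_size=8, n_embd=64, n_head=4, n_xf_blocks=2, device=device
    )

    tokens = torch.randint(0, 1000, (2, 76), device=device)   # (batch, max_text_len)
    token_lens = torch.tensor([5, 9], device=device)           # number of valid tokens per sequence
    images = 255.0 * torch.rand(2, 3, 64, 64, device=device)   # pixel values in [0, 255]

    with torch.no_grad():
        text_features = text_enc(tokens, token_lens)  # (2, n_embd)
        image_features = image_enc(images)            # (2, n_embd)

    print(text_features.shape, image_features.shape)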