Spaces:
Runtime error
Runtime error
File size: 3,382 Bytes
891b88f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
"""
Transformer implementation adapted from CLIP ViT:
https://github.com/openai/CLIP/blob/4c0275784d6d9da97ca1f47eaaee31de1867da91/clip/model.py
"""
import math
import torch as th
import torch.nn as nn
def convert_module_to_f16(l):
    """
    Cast the parameters of a primitive layer to float16.

    Only nn.Linear, nn.Conv2d and nn.ConvTranspose2d instances are affected:
    their weight, and their bias when one exists, are replaced by
    half-precision copies. Any other object is left untouched. Nothing is
    returned.
    """
    primitive_types = (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)
    if not isinstance(l, primitive_types):
        return
    l.weight.data = l.weight.data.half()
    bias = l.bias
    if bias is not None:
        # Layers built with bias=False carry no bias tensor to convert.
        bias.data = bias.data.half()
class LayerNorm(nn.LayerNorm):
    """
    Layer normalization that accepts fp16 inputs while computing in fp32.

    The input is upcast to float32 before the parent normalization runs, and
    the result is cast back to the dtype the input arrived in.
    """

    def forward(self, x: th.Tensor):
        incoming_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.to(incoming_dtype)
class MultiheadAttention(nn.Module):
    """
    Self-attention layer: fused QKV projection, attention, output projection.

    The input's last dimension (`width`) is projected to `width * 3` by a
    single linear layer, handed to QKVMultiheadAttention, and the result is
    projected by a second `width`-to-`width` linear layer.
    """

    def __init__(self, n_ctx, width, heads):
        super().__init__()
        self.n_ctx, self.width, self.heads = n_ctx, width, heads
        # One linear layer produces queries, keys and values together.
        self.c_qkv = nn.Linear(width, 3 * width)
        self.c_proj = nn.Linear(width, width)
        self.attention = QKVMultiheadAttention(heads, n_ctx)

    def forward(self, x):
        """
        Apply the QKV projection, the attention, then the output projection.
        """
        fused_qkv = self.c_qkv(x)
        attended = self.attention(fused_qkv)
        return self.c_proj(attended)
class MLP(nn.Module):
    """
    Two-layer feed-forward network with a GELU activation in between.

    The hidden layer is four times as wide as the input; the output has the
    same width as the input.
    """

    def __init__(self, width):
        super().__init__()
        hidden_width = width * 4
        self.width = width
        self.c_fc = nn.Linear(width, hidden_width)
        self.c_proj = nn.Linear(hidden_width, width)
        self.gelu = nn.GELU()

    def forward(self, x):
        """
        Expand to the hidden width, apply GELU, and project back.
        """
        expanded = self.c_fc(x)
        activated = self.gelu(expanded)
        return self.c_proj(activated)
class QKVMultiheadAttention(nn.Module):
    """
    Scaled dot-product attention over a fused query/key/value tensor.

    This module holds no parameters; it only stores the head count and the
    context length it was constructed with.
    """

    def __init__(self, n_heads: int, n_ctx: int):
        super().__init__()
        self.n_heads = n_heads
        self.n_ctx = n_ctx

    def forward(self, qkv):
        """
        Attend over a 3-D tensor of shape (batch, sequence, width).

        The last dimension is split evenly across heads, and each head's
        slice is split into query, key and value parts of
        `width // n_heads // 3` channels each. The result has shape
        (batch, sequence, n_heads * channels_per_head).
        """
        batch, seq_len, fused_width = qkv.shape
        head_dim = fused_width // self.n_heads // 3
        # Scaling q and k by the fourth root each is equivalent to dividing
        # their product by sqrt(head_dim).
        scale = 1 / math.sqrt(math.sqrt(head_dim))

        per_head = qkv.view(batch, seq_len, self.n_heads, -1)
        query, key, value = th.split(per_head, head_dim, dim=-1)

        # More stable with f16 than dividing afterwards
        scores = th.einsum("bthc,bshc->bhts", query * scale, key * scale)

        # Softmax runs in float32 and is cast back to the scores' dtype.
        score_dtype = scores.dtype
        probs = th.softmax(scores.float(), dim=-1).type(score_dtype)

        mixed = th.einsum("bhts,bshc->bthc", probs, value)
        return mixed.reshape(batch, seq_len, -1)
class ResidualAttentionBlock(nn.Module):
    """
    Pre-norm transformer block: attention and MLP, each with a residual add.

    Each sub-layer receives a layer-normalized copy of its input, and its
    output is added back onto the un-normalized input.
    """

    def __init__(self, n_ctx: int, width: int, heads: int):
        super().__init__()
        self.attn = MultiheadAttention(n_ctx, width, heads)
        self.ln_1 = LayerNorm(width)
        self.mlp = MLP(width)
        self.ln_2 = LayerNorm(width)

    def forward(self, x: th.Tensor):
        """
        Apply the attention sub-layer, then the MLP sub-layer.
        """
        after_attention = x + self.attn(self.ln_1(x))
        return after_attention + self.mlp(self.ln_2(after_attention))
class Transformer(nn.Module):
    """
    A stack of `layers` ResidualAttentionBlock modules applied in sequence.

    Every block is constructed with the same `n_ctx`, `width` and `heads`.
    """

    def __init__(self, n_ctx: int, width: int, layers: int, heads: int):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList()
        for _ in range(layers):
            self.resblocks.append(ResidualAttentionBlock(n_ctx, width, heads))

    def forward(self, x: th.Tensor):
        """
        Feed `x` through each residual block in order and return the result.
        """
        hidden = x
        for resblock in self.resblocks:
            hidden = resblock(hidden)
        return hidden
|