"""
OpenAI's GPT-2 ported to PyTorch.
"""
import math

import attr
import torch
from torch import nn
from torch.nn import functional as F
import torch.utils.checkpoint


@attr.s(auto_attribs=True, frozen=True)
class HParams:
    n_vocab: int
    n_ctx: int
    n_embed: int
    n_hidden: int
    n_head: int
    n_layer: int
    gradient_checkpointing: bool = False
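    # For reference (not defined in this module; quoted from the public GPT-2
    # release), the 117M model corresponds roughly to
    #   HParams(n_vocab=50257, n_ctx=1024, n_embed=768, n_hidden=768,
    #           n_head=12, n_layer=12)
    # In this port n_hidden normally equals n_embed; when they differ, Model
    # adds in_proj/out_proj linear layers around the transformer blocks.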


class Model(nn.Module):
    def __init__(self, hparams: HParams):
        super().__init__()
        self.hparams = hparams
        self.wpe = nn.Embedding(hparams.n_ctx, hparams.n_embed)
        nn.init.normal_(self.wpe.weight, std=0.01)
        self.wte = nn.Embedding(hparams.n_vocab, hparams.n_embed)
        nn.init.normal_(self.wte.weight, std=0.02)
        self.blocks = nn.ModuleList(
            [Block(hparams) for _ in range(hparams.n_layer)])
        self.ln_f = Norm(self.hparams.n_hidden)
        if hparams.n_hidden != hparams.n_embed:
            self.in_proj = Conv1D(hparams.n_embed, hparams.n_hidden)
            self.out_proj = Conv1D(hparams.n_hidden, hparams.n_embed)
        else:
            self.in_proj = self.out_proj = None

    def forward(self, x, past=None):
        # x: token ids of shape [batch, sequence]. past (if given) holds
        # cached keys/values from earlier steps, with shape
        # [batch, n_layer, 2, n_head, past_sequence, n_hidden // n_head].
        past_length = 0 if past is None else past.shape[-2]
        batch_size, n_ctx = x.shape
        position = position_for(batch_size, n_ctx, past_length, x.device)
        h = self.wte(x) + self.wpe(position)
        assert h.shape == (batch_size, n_ctx, self.hparams.n_embed)
        if self.in_proj:
            h = self.in_proj(h)

        presents = []
        for i, block in enumerate(self.blocks):
            if self.hparams.gradient_checkpointing:
                h, present = torch.utils.checkpoint.checkpoint(
                    block, h, past[:, i] if past is not None else None)
            else:
                h, present = block(
                    h, past=past[:, i] if past is not None else None)
            presents.append(present)
        h = self.ln_f(h)
        if self.out_proj:
            h = self.out_proj(h)

        # Project back onto the vocabulary with the (tied) token embedding.
        h_flat = h.reshape([batch_size * n_ctx, self.hparams.n_embed])
        logits = torch.matmul(h_flat, self.wte.weight.t())
        logits = logits.reshape([batch_size, n_ctx, self.hparams.n_vocab])
        return {
            'presents': torch.stack(tuple(presents), dim=1),
            'logits': logits,
        }


class Block(nn.Module):
    def __init__(self, hparams: HParams):
        super().__init__()
        self.ln_1 = Norm(hparams.n_hidden)
        self.ln_2 = Norm(hparams.n_hidden)
        self.mlp = MLP(hparams.n_hidden, hparams.n_hidden * 4)
        self.attn = Attention(hparams)

    def forward(self, x, past):
        a, present = self.attn(self.ln_1(x), past=past)
        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m
        return x, present


class Norm(nn.Module):
    """ Normalize to mean = 0, std = 1, then do a diagonal affine transform.
    """
    def __init__(self, n_features, *, dim=-1, epsilon=1e-5):
        super().__init__()
        self.n_features = n_features
        self.dim = dim
        self.epsilon = epsilon
        self.g = nn.Parameter(torch.ones(n_features))
        self.b = nn.Parameter(torch.zeros(n_features))

    def forward(self, x):
        assert x.shape[-1] == self.n_features
        u = torch.mean(x, dim=self.dim, keepdim=True)
        xmu = x - u
        s = torch.mean(xmu * xmu, dim=self.dim, keepdim=True)
        return xmu * torch.rsqrt(s + self.epsilon) * self.g + self.b
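    # Note: with the default dim=-1 this is equivalent to
    # nn.LayerNorm(n_features, eps=epsilon).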


class MLP(nn.Module):
    def __init__(self, n_features, n_hidden):
        super().__init__()
        self.c_fc = Conv1D(n_features, n_hidden)
        self.c_proj = Conv1D(n_hidden, n_features)

    def forward(self, x):
        x = gelu(self.c_fc(x))
        x = self.c_proj(x)
        return x


class Attention(nn.Module):
    def __init__(self, hparams: HParams):
        super().__init__()
        assert hparams.n_hidden % hparams.n_head == 0
        self.hparams = hparams
        self.c_attn = Conv1D(hparams.n_hidden, hparams.n_hidden * 3)
        self.c_proj = Conv1D(hparams.n_hidden, hparams.n_hidden)

    def forward(self, x, past):
        assert len(x.shape) == 3  # [batch, sequence, features]
        assert x.shape[-1] == self.hparams.n_hidden
        if past is not None:
            # past is [batch, 2, heads, sequence, features per head],
            # where the second dimension stacks [k, v].
            assert len(past.shape) == 5
            assert past.shape[-1] == (
                self.hparams.n_hidden // self.hparams.n_head)
        c = self.c_attn(x)
        q, k, v = map(self.split_heads, torch.split(c, x.shape[-1], dim=2))
        present = torch.stack([k, v], dim=1)
        if past is not None:
            pk, pv = past[:, 0], past[:, 1]
            k = torch.cat([pk, k], dim=-2)
            v = torch.cat([pv, v], dim=-2)
        a = self.multihead_attn(q, k, v)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        return a, present

    def split_heads(self, x):
        """ From [batch, sequence, features] to
        [batch, heads, sequence, features].
        """
        return self.split_states(x, self.hparams.n_head).permute(0, 2, 1, 3)

    @staticmethod
    def split_states(x, n):
        """ Reshape the last dimension of x into [n, x.shape[-1]/n].
        """
        *start, m = x.shape
        return x.reshape(start + [n, m // n])

    def merge_heads(self, x):
        """ Reverse of split_heads.
        """
        return self.merge_states(x.permute(0, 2, 1, 3))

    @staticmethod
    def merge_states(x):
        """ Smash the last two dimensions of x into a single dimension.
        """
        *start, a, b = x.shape
        return x.reshape(start + [a * b])

    def mask_attn_weights(self, w):
        # w has shape [batch, heads, dst_sequence, src_sequence],
        # where information flows from src to dst.
        _, _, nd, ns = w.shape
        b = self.attention_mask(nd, ns, dtype=w.dtype, device=w.device)
        b = b.reshape((1, 1, nd, ns))
        w = w * b - 1e4 * (1 - b)
        return w

    @staticmethod
    def attention_mask(nd, ns, *, dtype, device=None):
        """ 1's in the lower triangle, counting from the lower right corner.
        Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd),
        but doesn't produce garbage on TPUs.
        """
        i = torch.arange(0, nd).unsqueeze(1)
        j = torch.arange(ns)
        return (i >= j - ns + nd).to(dtype=dtype, device=device)
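    # For example, attention_mask(2, 3) gives
    #   [[1, 1, 0],
    #    [1, 1, 1]]
    # i.e. the two newest queries of a length-3 sequence may attend to all
    # earlier positions but not to future ones.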

    def multihead_attn(self, q, k, v):
        # q, k, v have shape [batch, heads, sequence, features per head].
        w = torch.matmul(q, k.permute(0, 1, 3, 2))
        w = w / math.sqrt(v.shape[-1])
        w = self.mask_attn_weights(w)
        w = F.softmax(w, dim=-1)
        a = torch.matmul(w, v)
        return a


class Conv1D(nn.Linear):
    def reset_parameters(self):
        nn.init.normal_(self.weight, std=0.02)
        nn.init.zeros_(self.bias)
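# Despite the name, Conv1D is just nn.Linear with GPT-2-style initialization;
# the name presumably follows the position-wise "conv1d" of the original TF
# implementation that this port mirrors.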


def gelu(x, c=math.sqrt(2 / math.pi)):
    return 0.5 * x * (1 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3))))
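# gelu() above is the tanh approximation of GELU used by GPT-2; newer PyTorch
# releases expose the same curve as F.gelu(x, approximate='tanh').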


def position_for(batch_size, n_steps, past_length, device=None):
    return (torch.arange(past_length, n_steps + past_length, device=device)
            .unsqueeze(0).repeat(batch_size, 1))
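

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original module: build a tiny
    # model with arbitrary small hyperparameters and check output shapes,
    # including one step of incremental decoding via the 'presents' cache.
    hparams = HParams(
        n_vocab=100, n_ctx=16, n_embed=32, n_hidden=32, n_head=4, n_layer=2)
    model = Model(hparams)
    x = torch.randint(0, hparams.n_vocab, (2, 8))  # [batch, sequence]
    out = model(x)
    print(out['logits'].shape)    # torch.Size([2, 8, 100])
    print(out['presents'].shape)  # torch.Size([2, 2, 2, 4, 8, 8])
    # Feed one more token, reusing cached keys/values; a full sampler would
    # keep concatenating the returned 'presents' along the sequence dimension.
    nxt = model(x[:, -1:], past=out['presents'])
    print(nxt['logits'].shape)    # torch.Size([2, 1, 100])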