import math

import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import PreTrainedModel
from transformers.activations import ACT2FN

from language_config import BigBrainLanguageConfig
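
# A compact decoder-only transformer language model: RMSNorm normalization,
# a gated feed-forward block, rotary positional embeddings applied to queries
# and keys, and pre-norm residual connections, wrapped as a Hugging Face
# PreTrainedModel.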


def _make_causal_mask(size: int) -> torch.Tensor:
    # Lower-triangular matrix of ones: position i may only attend to positions <= i.
    return torch.tril(torch.ones(size, size))


class RootMeanSquareNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_eps = eps

    def forward(self, x: torch.Tensor):
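        # Scale by the reciprocal root mean square of the features (no mean-centering),
        # then apply the learned per-feature weight.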
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_eps)
        return self.weight * x


class MultiLayerPerceptron(nn.Module):
    def __init__(self, config: BigBrainLanguageConfig):
        super().__init__()
        self.config = config
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
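        # Gated feed-forward: the activated gate projection multiplies the up projection elementwise.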
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, dim: int, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.base = base
        self.cos = None
        self.sin = None

    def _build_cache(self, x: torch.Tensor):
        # x has shape (batch, num_heads, seq_len, head_dim); cache cos/sin up to seq_len.
        seq_len = x.shape[-2]
        if self.cos is not None and seq_len <= self.cos.shape[-2] and self.cos.device == x.device:
            return

        theta = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim))
        seq_idx = torch.arange(seq_len, device=x.device).float()
        idx_theta = torch.einsum('a,b->ab', seq_idx, theta)
        idx_theta = torch.cat([idx_theta, idx_theta], dim=1)

        # Shaped (1, 1, seq_len, dim) so the cache broadcasts over batch and heads.
        self.cos = idx_theta.cos()[None, None, :, :]
        self.sin = idx_theta.sin()[None, None, :, :]

    def _neg_half(self, x: torch.Tensor):
        # Rotate the two halves of the feature dimension: (x1, x2) -> (-x2, x1).
        d_2 = self.dim // 2
        return torch.cat([-x[..., d_2:], x[..., :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        self._build_cache(x)
        seq_len = x.shape[-2]
        x_rope, x_pass = x[..., :self.dim], x[..., self.dim:]
        neg_half_x = self._neg_half(x_rope)
        x_rope = (x_rope * self.cos[:, :, :seq_len]) + (neg_half_x * self.sin[:, :, :seq_len])
        return torch.cat((x_rope, x_pass), dim=-1)


class RotaryMultiHeadAttention(nn.Module):
    def __init__(self, config: BigBrainLanguageConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads

        if (self.head_dim * config.num_attention_heads) != config.hidden_size:
            raise ValueError('hidden_size must be evenly divisible by num_attention_heads')

        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.rope_e = RotaryPositionalEmbedding(self.head_dim, config.rope_theta)

    def _shape(self, tensor: torch.Tensor, batch_size: int, seq_len: int):
        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _reshape(self, tensor: torch.Tensor, batch_size: int, seq_len: int):
        return tensor.transpose(1, 2).contiguous().reshape(batch_size, seq_len, self.hidden_size)

    def forward(self, states: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
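        # Project to per-head queries, keys and values, apply rotary embeddings to Q and K,
        # then compute masked scaled dot-product attention and merge the heads back.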
        batch_size, seq_len, _ = states.size()

        q_states = self.rope_e(self._shape(self.q_proj(states), batch_size, seq_len))
        k_states = self.rope_e(self._shape(self.k_proj(states), batch_size, seq_len))
        v_states = self._shape(self.v_proj(states), batch_size, seq_len)

        attn_weights = torch.matmul(q_states, k_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        attn_weights = torch.clamp(attn_weights, min=-1024.0, max=1024.0)

        if mask is not None:
            attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_outputs = torch.matmul(attn_weights, v_states)
        return self._reshape(attn_outputs, batch_size, seq_len)


class BigBrainDecoderLayer(nn.Module):
    def __init__(self, config: BigBrainLanguageConfig):
        super().__init__()
        self.config = config
        self.self_attn = RotaryMultiHeadAttention(config)
        self.feed_forward = MultiLayerPerceptron(config)
        self.input_norm = RootMeanSquareNorm(config.hidden_size, config.layer_norm_eps)
        self.attn_norm = RootMeanSquareNorm(config.hidden_size, config.layer_norm_eps)
        self.register_buffer('attn_mask', _make_causal_mask(config.max_position_embeddings))

    def forward(self, x: torch.Tensor):
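        # Pre-norm residual blocks: normalize, transform (attention then MLP), and add back the input.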
        batch_size, seq_len, _ = x.size()
        mask = self.attn_mask[:seq_len, :seq_len]
        x = x + self.self_attn(self.input_norm(x), mask)
        x = x + self.feed_forward(self.attn_norm(x))
        return x


class BigBrainLanguageModel(PreTrainedModel):
    config_class = BigBrainLanguageConfig
    base_model_prefix = 'big-brain-lm'

    def __init__(self, config: BigBrainLanguageConfig):
        super().__init__(config)
        self.config = config
        self.tok_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.layers = nn.ModuleList([BigBrainDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = RootMeanSquareNorm(config.hidden_size, config.layer_norm_eps)
        self.linear = nn.Linear(config.hidden_size, config.vocab_size)
        self.post_init()

    def _init_weights(self, module):
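        # Draw linear and embedding weights from a normal distribution; zero biases and the padding embedding row.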
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(self, input_ids: torch.Tensor, target_ids: torch.Tensor = None):
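        # Embed tokens, run the decoder stack, apply the final norm, and project to vocabulary logits.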
        hidden_states = self.tok_embed(input_ids)
        for decoder_layer in self.layers:
            hidden_states = decoder_layer(hidden_states)
        hidden_states = self.norm(hidden_states)
        hidden_states = self.linear(hidden_states)

        if target_ids is None:
            return hidden_states, None

        b, t, c = hidden_states.size()
        loss = F.cross_entropy(hidden_states.view(b * t, c), target_ids.view(b * t))
        return hidden_states, loss
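

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the model code): it assumes
    # BigBrainLanguageConfig accepts the hyperparameters referenced above as
    # keyword arguments; adjust the values to match the real config signature.
    config = BigBrainLanguageConfig(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
        hidden_act='silu',
        layer_norm_eps=1e-6,
        max_position_embeddings=128,
        rope_theta=10000,
        initializer_range=0.02,
        pad_token_id=0,
    )
    model = BigBrainLanguageModel(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    target_ids = torch.randint(0, config.vocab_size, (2, 16))
    logits, loss = model(input_ids, target_ids)
    print(logits.shape, loss.item())  # expected: torch.Size([2, 16, 1000]) and a scalar loss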