nz committed
Commit b6368ec
1 Parent(s): c7f01cd

Create rita_modeling.py

Files changed (1)
  1. rita_modeling.py +281 -0
rita_modeling.py ADDED
@@ -0,0 +1,281 @@
import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    CausalLMOutputWithPast,
)

from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from .rita_configuration import RITAConfig
import torch.nn.functional as F

logger = logging.get_logger(__name__)

@torch.jit.script
def RITA_gelu(hidden_states):
    return hidden_states * 0.5 * (1.0 + torch.tanh(0.79788456 * hidden_states * (1 + 0.044715 * hidden_states * hidden_states)))


class RITAGELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, hidden_states):
        return RITA_gelu(hidden_states)

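
def _example_gelu_equivalence():
    """Illustrative sketch: RITA_gelu above is the standard tanh approximation of GELU
    (0.79788456 ~= sqrt(2/pi)), so it should agree with torch's built-in tanh variant,
    assuming a PyTorch version (>= 1.12) that supports F.gelu(..., approximate="tanh")."""
    x = torch.randn(4, 8)
    # Both formulas compute 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))).
    assert torch.allclose(RITA_gelu(x), F.gelu(x, approximate="tanh"), atol=1e-5)
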
def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)


class RotaryEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.d_model % config.num_heads == 0

        self.d_model = config.d_model
        self.num_heads = config.num_heads
        self.max_seq_len = config.max_seq_len

        head_dim = self.d_model // self.num_heads
        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer('inv_freq', inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x: torch.FloatTensor, seq_dim=1) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        seq_len = x.shape[seq_dim]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[None, None, :, :]
            self.sin_cached = emb.sin()[None, None, :, :]
        return self.cos_cached, self.sin_cached

    def apply_rotary_pos_emb(self, q, k, cos, sin):
        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

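
def _example_rotary_usage():
    """Illustrative sketch of how RotaryEmbedding is consumed by the attention layer
    below. SimpleNamespace stands in for RITAConfig; d_model, num_heads and max_seq_len
    are the only fields RotaryEmbedding reads, and the values are arbitrary."""
    from types import SimpleNamespace

    cfg = SimpleNamespace(d_model=64, num_heads=4, max_seq_len=128)
    rotary = RotaryEmbedding(cfg)

    N, L = 2, 10
    head_dim = cfg.d_model // cfg.num_heads
    x = torch.randn(N, L, cfg.d_model)              # (N, L, D) token representations
    q = torch.randn(N, cfg.num_heads, L, head_dim)  # queries, (N, nh, L, hs)
    k = torch.randn(N, cfg.num_heads, L, head_dim)  # keys, (N, nh, L, hs)

    cos, sin = rotary(x)                            # cached cos/sin tables, (1, 1, L, hs)
    q_rot, k_rot = rotary.apply_rotary_pos_emb(q, k, cos, sin)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
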
class SelfAttention(nn.Module):
    """Implementation of MultiHeadAttention following `Karpathy's MinGPT <https://github.com/karpathy/minGPT>`_,
    modified to use rotary embeddings.

    Parameters
    ----------
    d_model: int,
        total dimension of the model.
    num_heads: int,
        number of parallel attention heads.
    num_layers: int,
        number of layers in the model, used for the Megatron-like init.
    rotary_embedding: Optional[Block], default None,
        a RotaryEmbedding Block to add positional information in Queries and Keys
    dropout: float, default 0.1,
        amount of dropout on the attention weights.
    sigma: float, default 0.02,
        standard deviation used for the init.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_layers: int,
        rotary_embedding=None,
        dropout: float = 0.1,
        sigma=0.02,
        use_cache: bool = False,
        bias=True,
    ):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = self.d_model // self.num_heads
        self.num_layers = num_layers
        self.dropout = dropout
        self.sigma = sigma
        self.bias = bias

        # key, query, value projections for all heads
        self.key = nn.Linear(d_model, d_model, bias=bias)
        self.query = nn.Linear(d_model, d_model, bias=bias)
        self.value = nn.Linear(d_model, d_model, bias=bias)
        # regularization
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        # output projection
        self.proj = nn.Linear(d_model, d_model, bias=bias)

        self.rotary_embedding = rotary_embedding
        self.layer_id = None  # will be set by the Transformer itself
        self.use_cache = use_cache
        self.qkv = None

    def forward(
        self,
        x,
        attn_mask: Optional[torch.BoolTensor] = None,
        padding_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.FloatTensor:

        N, L, D = x.size()  # Batch_size, Context_size, d_model

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = (
            self.key(x).view(N, L, self.num_heads, D // self.num_heads).transpose(1, 2)
        )  # (N, nh, L, hs)
        q = (
            self.query(x).view(N, L, self.num_heads, D // self.num_heads).transpose(1, 2)
        )  # (N, nh, L, hs)
        v = (
            self.value(x).view(N, L, self.num_heads, D // self.num_heads).transpose(1, 2)
        )  # (N, nh, L, hs)

        if self.rotary_embedding is not None:
            cos, sin = self.rotary_embedding(x)
            q, k = self.rotary_embedding.apply_rotary_pos_emb(q, k, cos, sin)

        # causal self-attention; Self-attend: (N, nh, L, hs) x (N, nh, hs, L) -> (N, nh, L, L)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

        # attn_mask (L, L) and padding_mask (N, L) are boolean masks where True marks
        # positions that must NOT be attended to.
        if attn_mask is not None:
            att[:, :, -L:, -L:].masked_fill_(attn_mask.view(1, 1, L, L), float("-inf"))

        att = (
            att.transpose(0, 2)
            .masked_fill(padding_mask.view(1, 1, N, L), float("-inf"))
            .transpose(0, 2)
            if padding_mask is not None
            else att
        )

        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (N, nh, L, L) x (N, nh, L, hs) -> (N, nh, L, hs)
        y = (
            y.transpose(1, 2).contiguous().view(N, L, D)
        )  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y

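
def _example_attention_masks():
    """Illustrative sketch of the mask conventions accepted by SelfAttention:
    attn_mask is a bool (L, L) tensor with True above the diagonal (future positions
    a token may not attend to), and padding_mask is a bool (N, L) tensor with True at
    padding positions. The SimpleNamespace config stand-in and the sizes are arbitrary."""
    from types import SimpleNamespace

    d_model, num_heads, N, L = 64, 4, 2, 6
    cfg = SimpleNamespace(d_model=d_model, num_heads=num_heads, max_seq_len=128)
    attn = SelfAttention(
        d_model,
        num_heads,
        num_layers=2,
        rotary_embedding=RotaryEmbedding(cfg),
        dropout=0.0,
    )

    x = torch.randn(N, L, d_model)
    causal_mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
    padding_mask = torch.zeros(N, L, dtype=torch.bool)
    padding_mask[1, -2:] = True  # last two tokens of the second sequence are padding

    y = attn(x, attn_mask=causal_mask, padding_mask=padding_mask)
    assert y.shape == (N, L, d_model)
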
class DecoderLayer(nn.Module):
    """Transformer block containing the self-attention module and the feedforward module."""

    def __init__(self, config):
        super().__init__()
        self.self_attention = SelfAttention(
            config.d_model,
            config.num_heads,
            config.num_layers,
            rotary_embedding=RotaryEmbedding(config),
            dropout=config.dropout,
        )
        self.attn_norm = nn.LayerNorm(config.d_model)
        self.attn_dropout = nn.Dropout(config.dropout)

        self.mlp = nn.Sequential(
            nn.Linear(config.d_model, config.d_feedforward, bias=True),
            RITAGELU(),
            nn.Linear(config.d_feedforward, config.d_model, bias=True),
        )
        self.mlp_norm = nn.LayerNorm(config.d_model)
        self.mlp_dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        x: torch.FloatTensor,
        attn_mask: torch.BoolTensor,
        padding_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.FloatTensor:
        # Pre-LayerNorm residual block: attention, then feedforward.
        y = self.attn_norm(x)
        y = self.self_attention(y, attn_mask=attn_mask, padding_mask=padding_mask)
        x = x + self.attn_dropout(y)

        y = self.mlp_norm(x)
        y = self.mlp(y)
        x = x + self.mlp_dropout(y)
        return x

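
def _example_decoder_layer():
    """Illustrative sketch of the config fields DecoderLayer reads: d_model, num_heads,
    num_layers, dropout, d_feedforward and (through RotaryEmbedding) max_seq_len.
    SimpleNamespace stands in for RITAConfig and the values are arbitrary."""
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        d_model=64,
        num_heads=4,
        num_layers=2,
        dropout=0.0,
        d_feedforward=256,
        max_seq_len=128,
    )
    layer = DecoderLayer(cfg)

    N, L = 2, 6
    x = torch.randn(N, L, cfg.d_model)
    causal_mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
    y = layer(x, attn_mask=causal_mask)
    assert y.shape == x.shape
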
class RITAModel(PreTrainedModel):
    config_class = RITAConfig

    def __init__(self, config):
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_layers)])
        self.final_norm = nn.LayerNorm(config.d_model)
        self.projector = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def forward(self, input_ids, attn_mask=None, padding_mask=None, return_hidden=False) -> torch.FloatTensor:
        x = self.embedding(input_ids)  # N x L x D
        if attn_mask is None:
            # Default to a causal mask: True above the diagonal marks future positions.
            attn_mask = (
                (torch.triu(torch.ones(input_ids.size(1), input_ids.size(1))) == 0)
                .transpose(0, 1)
                .contiguous()
                .to(input_ids.device)
            )
        for layer in self.layers:
            x = layer(x, attn_mask=attn_mask, padding_mask=padding_mask)
        x = self.final_norm(x)  # N x L x D

        if return_hidden:
            return x
        else:
            return self.projector(x)

    # Some common HF functions.
    def get_input_embeddings(self):
        return self.embedding

    def set_input_embeddings(self, new_embeddings):
        self.embedding = new_embeddings

    def get_output_embeddings(self):
        return self.projector

    def set_output_embeddings(self, new_projector):
        self.projector = new_projector
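

def _example_model_forward():
    """Illustrative sketch of an end-to-end forward pass. A minimal PretrainedConfig
    stand-in is defined locally so the example stays self-contained; real usage would
    build a RITAConfig (from rita_configuration.py, not shown here) or load a
    pretrained checkpoint. The field values below are arbitrary and far smaller than
    a real RITA model."""
    from transformers import PretrainedConfig

    class _TinyConfig(PretrainedConfig):
        # Only the fields read by the modules above.
        def __init__(self, **kwargs):
            self.vocab_size = 32
            self.d_model = 64
            self.num_heads = 4
            self.num_layers = 2
            self.d_feedforward = 256
            self.dropout = 0.0
            self.max_seq_len = 128
            super().__init__(**kwargs)

    model = RITAModel(_TinyConfig())
    model.eval()

    input_ids = torch.randint(0, 32, (2, 10))          # (N, L) token ids
    with torch.no_grad():
        logits = model(input_ids)                      # (N, L, vocab_size)
        hidden = model(input_ids, return_hidden=True)  # (N, L, d_model)
    assert logits.shape == (2, 10, 32)
    assert hidden.shape == (2, 10, 64)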