aapot committed
Commit 0394e28
Parent(s): f7552ca
Update EasyLM

Browse files:
- EasyLM/bpt.py +228 -0
- EasyLM/data.py +26 -7
- EasyLM/jax_utils.py +1 -0
- EasyLM/models/gptj/gptj_serve.py +3 -3
- EasyLM/models/gptj/gptj_train.py +3 -5
- EasyLM/models/llama/convert_easylm_to_hf.py +18 -10
- EasyLM/models/llama/convert_hf_to_easylm.py +196 -0
- EasyLM/models/llama/llama_model.py +170 -69
- EasyLM/models/llama/llama_serve.py +3 -3
- EasyLM/models/llama/llama_train.py +5 -8
- EasyLM/models/roberta/roberta_train.py +3 -5
- EasyLM/optimizers.py +1 -2
- EasyLM/scripts/benchmark_attention.py +150 -0
- EasyLM/scripts/lm_eval_harness.py +5 -1
EasyLM/bpt.py
ADDED
@@ -0,0 +1,228 @@
+"""
+An implementation of Blockwise parallel transformer https://arxiv.org/abs/2305.19370
+Also include a reference implementation of memory-efficient transformer https://arxiv.org/abs/2112.05682
+"""
+
+import functools
+from typing import NamedTuple
+
+import flax.linen as nn
+import jax
+import jax.lax as lax
+import jax.numpy as jnp
+from einops import rearrange
+
+"""
+Computing ffn blockwise without materializing the large hidden tensor, training
+4x longer sequences than the memory-efficient transformer.
+Blockwise parallel transformer https://arxiv.org/abs/2305.19370 Liu et al. 2023
+"""
+def blockwise_ffn(remat_ffn, inputs, chunk_size=2048, deterministic=True):
+    # remat_ffn: a rematerialized ffn with policy jax.checkpoint_policies.nothing_saveable()
+    # inputs: (batch, seq_len, dim)
+    # chunk_size: the chunk size to split the sequence
+    inputs = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size)
+    def scan_ffn(remat_ffn, carry, hidden_states):
+        outputs = remat_ffn(hidden_states, deterministic=deterministic)
+        return carry, outputs
+    scan_axis = inputs.ndim - 2
+    _, res = nn.scan(
+        scan_ffn,
+        variable_broadcast="params",
+        split_rngs={"params": False, "dropout": True},
+        in_axes=scan_axis,
+        out_axes=scan_axis,
+    )(remat_ffn, None, inputs)
+    res = rearrange(res, 'b c n d -> b (c n) d')
+    return res
+
+
+"""
+Compute attention blockwise without materializing the full attention matrix,
+initially proposed in memory-efficient transformer https://arxiv.org/abs/2112.05682 Rabe et al. 2021;
+flash attention https://arxiv.org/abs/2205.14135 Dao et al. 2022 proposes a CUDA
+efficient implementation; blockwise parallel transformer https://arxiv.org/abs/2305.19370
+Liu et al. 2023 proposes blockwise computing both attention and FFN, enabling 4x
+longer sequences than memory-efficient/flash-attention and fusion of attention and FFN.
+"""
+def blockwise_attn(
+        query, key, value,
+        bias=None,
+        deterministic=True,
+        dropout_rng=None,
+        attn_pdrop=0.0,
+        causal=True,
+        query_chunk_size=2048,
+        key_chunk_size=2048,
+        dtype=jnp.float32,
+        policy=jax.checkpoint_policies.nothing_saveable(),
+        precision=None,
+        float32_logits=True,
+        prevent_cse=True,
+    ):
+    # query, key, value: (batch, seq_len, num_heads, dim_per_head)
+    # bias: (batch, seq_len) can be used to mask out attention (e.g. padding)
+    # causal: whether to use causal mask
+    # policy: one of jax.checkpoint_policies
+    query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
+    if float32_logits:
+        query = query.astype(jnp.float32)
+        key = key.astype(jnp.float32)
+
+    batch, q_len, num_heads, dim_per_head = query.shape
+    batch, kv_len, num_heads, dim_per_head = key.shape
+    batch, kv_len, num_heads, dim_per_head = value.shape
+
+    num_q = q_len // query_chunk_size
+    num_kv = kv_len // key_chunk_size
+    query = query.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
+    key = key.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
+    value = value.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
+
+    query = jnp.moveaxis(query, 1, 0)
+    key = jnp.moveaxis(key, 1, 0)
+    value = jnp.moveaxis(value, 1, 0)
+
+    if bias is not None:
+        for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)):
+            assert bias_dim == 1 or bias_dim == broadcast_dim
+    if not deterministic and attn_pdrop > 0.0:
+        attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
+        attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len))
+    else:
+        attn_dropout = None
+
+    _chunk_bias_fn = functools.partial(
+        _chunk_attention_bias,
+        query_chunk_size, key_chunk_size, bias, deterministic,
+        attn_dropout, attn_pdrop, causal, dtype)
+
+    def scan_attention(args):
+        query_chunk, query_chunk_idx = args
+
+        @functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
+        def scan_kv_block(carry, args):
+            key_chunk, value_chunk, key_chunk_idx = args
+            (numerator, denominator, prev_max_score) = carry
+            attn_weights = jnp.einsum('bqhd,bkhd->bqhk', query_chunk, key_chunk, precision=precision)
+            bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx)
+            bias_chunk = jnp.moveaxis(bias_chunk, 1, 2)
+            attn_weights = attn_weights + bias_chunk
+
+            max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
+            max_score = jnp.maximum(prev_max_score, max_score)
+            max_score = jax.lax.stop_gradient(max_score)
+            exp_weights = jnp.exp(attn_weights - max_score)
+            exp_values = jnp.einsum(
+                'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision
+            )
+            correction = jnp.exp(prev_max_score - max_score)
+            numerator = numerator * correction + exp_values
+            denominator = denominator * correction + exp_weights.sum(axis=-1, keepdims=True)
+            return Carry(numerator, denominator, max_score), None
+
+        def skip_upper_half(carry, args):
+            key_chunk, value_chunk, key_chunk_idx = args
+            skip_block = jnp.array(False)
+            if causal:
+                skip_block = query_chunk_idx < key_chunk_idx
+            return jax.lax.cond(
+                skip_block,
+                lambda carry, args: (carry, None),
+                scan_kv_block,
+                carry,
+                args,
+            )
+
+        init_carry = Carry(
+            jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
+            jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
+            (-jnp.inf) * jnp.ones((batch, query_chunk_size, num_heads, 1), dtype=query.dtype),
+        )
+        (numerator, denominator, max_score), _ = lax.scan(
+            skip_upper_half, init_carry, xs=(key, value, jnp.arange(0, num_kv))
+        )
+        outputs = (numerator / denominator).astype(dtype)
+        return outputs
+
+    _, res = lax.scan(
+        lambda _, x: ((), scan_attention(x)),
+        (), xs=(query, jnp.arange(0, num_q))
+    )
+    res = rearrange(res, 'n b c h d -> b (n c) h d')
+    return res
+
+
+class Carry(NamedTuple):
+    numerator: jax.Array
+    denominator: jax.Array
+    max_so_far: jax.Array
+
+
+def _chunk_attention_bias(query_chunk_size, key_chunk_size,
+            bias, deterministic, attn_dropout, attn_pdrop, causal,
+            dtype, query_chunk_idx, key_chunk_idx):
+    query_offset = query_chunk_idx * query_chunk_size
+    key_offset = key_chunk_idx * key_chunk_size
+    chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
+    if bias is not None:
+        chunk_bias = lax.dynamic_slice(
+            bias,
+            start_indices=(0, 0, query_offset, key_offset),
+            slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)),
+        )
+
+    if causal:
+        query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0)
+        key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1)
+        offset = query_offset - key_offset
+        query_idx += offset
+        causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min
+        chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape)
+
+    if not deterministic and attn_pdrop > 0.0:
+        attn_dropout_slice = lax.dynamic_slice(
+            attn_dropout,
+            start_indices=(0, 0, query_offset, key_offset),
+            slice_sizes=(
+                *attn_dropout.shape[:2],
+                min(attn_dropout.shape[-2], query_chunk_size),
+                min(attn_dropout.shape[-1], key_chunk_size),
+            ),
+        )
+        chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min
+    return chunk_bias.astype(dtype)
+
+
+if __name__ == '__main__':
+    # test
+    def reference_attn(query, key, value, causal, dtype):
+        query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
+        logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
+        if causal:
+            mask_value = jnp.finfo(logits.dtype).min
+            _, q_seq_len, _, _ = query.shape
+            _, kv_seq_len, _, _ = key.shape
+            mask_shape = (q_seq_len, kv_seq_len)
+            row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
+            col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
+            causal_mask = (row_ids < col_ids)[None, None, :, :]
+            logits = logits + jnp.where(causal_mask, mask_value, 0.0)
+        weights = jax.nn.softmax(logits, axis=-1)
+        out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
+        return out
+
+    # random inputs
+    shape = (1, 32, 8, 64)
+    query = jax.random.normal(jax.random.PRNGKey(0), shape)
+    key = jax.random.normal(jax.random.PRNGKey(1), shape)
+    value = jax.random.normal(jax.random.PRNGKey(2), shape)
+
+    causal = True
+    chunk_size = 4
+    policy = jax.checkpoint_policies.nothing_saveable()
+
+    blockwise = blockwise_attn(query, key, value, None, False, None, 0.0, causal, chunk_size, chunk_size, jnp.float32, policy, 'float32', True, False)
+    reference = reference_attn(query, key, value, causal, 'float32')
+
+    assert jnp.allclose(reference, blockwise, atol=1e-6)
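For context, blockwise_ffn expects its first argument to be an already-rematerialized Flax module, as its inline comment notes. The sketch below is not part of the commit: TinyMLP and BlockwiseMLPWrapper are hypothetical stand-ins for illustration, and the only EasyLM symbol used is blockwise_ffn itself; it mirrors how the LLaMA block in this commit wires a remat'd MLP into blockwise_ffn.

import flax.linen as nn
import jax
import jax.numpy as jnp

from EasyLM.bpt import blockwise_ffn


class TinyMLP(nn.Module):
    # Hypothetical two-layer FFN standing in for a real transformer MLP.
    dim: int
    hidden: int

    @nn.compact
    def __call__(self, x, deterministic=True):
        h = nn.relu(nn.Dense(self.hidden)(x))
        return nn.Dense(self.dim)(h)


class BlockwiseMLPWrapper(nn.Module):
    dim: int
    hidden: int
    chunk_size: int = 512

    def setup(self):
        # Rematerialize the FFN so its large hidden activations are recomputed
        # in the backward pass, following the comment on blockwise_ffn.
        self.ffn = nn.remat(
            TinyMLP, policy=jax.checkpoint_policies.nothing_saveable()
        )(self.dim, self.hidden)

    def __call__(self, x, deterministic=True):
        # seq_len must be divisible by chunk_size for the rearrange inside blockwise_ffn.
        return blockwise_ffn(self.ffn, x, self.chunk_size, deterministic)


x = jnp.ones((2, 2048, 64))
model = BlockwiseMLPWrapper(dim=64, hidden=256)
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, x)
print(y.shape)  # (2, 2048, 64)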
EasyLM/data.py
CHANGED
@@ -3,6 +3,7 @@ import pprint
 import time
 from functools import partial
 import json
+import base64
 from multiprocessing import Pool
 
 import h5py
@@ -59,6 +60,7 @@ class TextProcessor(object):
         config.add_bos_token = True
         config.add_eos_token = True
         config.prepend_text = ''
+        config.base64_token_dtype = 'i4'
         if updates is not None:
             config.update(ConfigDict(updates).copy_and_resolve_references())
         return config
@@ -95,12 +97,26 @@ class TextProcessor(object):
             else:
                 mask = 1.0
 
+            if field.startswith('<|') and field.endswith('|>'):
+                # Special tokens.
+                field = field[2:-2]
+                if field == 'bos':
+                    token_buffer.append(self.tokenizer.bos_token_id)
+                elif field == 'eos':
+                    token_buffer.append(self.tokenizer.eos_token_id)
+                else:
+                    # Token ID specified directly.
+                    token_buffer.append(int(field))
                 loss_mask_buffer.append(mask)
+            elif field.startswith('{') and field.endswith('}'):
+                field = field[1:-1]
+                # Base64 encoded raw tokens.
+                tokens = np.frombuffer(
+                    base64.b64decode(example[field]),
+                    dtype=self.config.base64_token_dtype
+                ).tolist()
+                token_buffer.extend(tokens)
+                loss_mask_buffer.extend([mask for _ in range(len(tokens))])
             else:
                 subfields = field.split('+')
                 text = self.config.subfield_separator.join(
@@ -136,6 +152,7 @@ class HuggingfaceDataset(object):
         config.always_start_with_bos = False
         config.start_seek_loc = 0
         config.tokens_count_at_start = 0
+        config.batch_token_dtype = 'i4'
 
         if updates is not None:
             config.update(ConfigDict(updates).copy_and_resolve_references())
@@ -163,6 +180,8 @@ class HuggingfaceDataset(object):
         while True:
             token_buffer = []
             loss_mask_buffer = []
+            if not self._eval_dataset:
+                self._shuffle()
             for index, example in enumerate(self._dataset):
                 self._index = index
                 if not self._eval_dataset and self._dataset_loc > index:
@@ -178,10 +197,10 @@ class HuggingfaceDataset(object):
                         'epoch': self._train_epochs,
                     }
                     batch = {
-                        'input_tokens': np.array(token_buffer[:chunk_size], dtype=
+                        'input_tokens': np.array(token_buffer[:chunk_size], dtype=self.config.batch_token_dtype).reshape(
                             self.config.batch_size, -1
                         ),
-                        'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=
+                        'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=self.config.batch_token_dtype).reshape(
                            self.config.batch_size, -1
                        ),
                        'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
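For reference, the new `{field}` syntax lets a dataset row carry pre-tokenized ids as a base64 string, which the new TextProcessor branch decodes with np.frombuffer. A minimal round-trip sketch (not from the commit; the field name 'tokens' is just an example):

import base64
import numpy as np

# Producer side: pack token ids with the same dtype as config.base64_token_dtype ('i4' = int32).
token_ids = np.array([1, 31414, 232, 2], dtype='i4')
example = {'tokens': base64.b64encode(token_ids.tobytes()).decode('ascii')}

# Consumer side: what the new TextProcessor branch does for a '{tokens}' field.
decoded = np.frombuffer(base64.b64decode(example['tokens']), dtype='i4').tolist()
assert decoded == token_ids.tolist()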
EasyLM/jax_utils.py
CHANGED
@@ -400,3 +400,4 @@ def get_weight_decay_mask(exclusions):
 def tree_apply(fns, tree):
     """ Apply a pytree of functions to the pytree. """
     return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)
+
EasyLM/models/gptj/gptj_serve.py
CHANGED
@@ -18,7 +18,7 @@ from transformers import GenerationConfig, FlaxLogitsProcessorList
 from EasyLM.checkpoint import StreamingCheckpointer
 from EasyLM.serving import LMServer
 from EasyLM.jax_utils import (
-    JaxRNG, next_rng, match_partition_rules, tree_apply,
+    JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply,
     set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
     with_sharding_constraint, FlaxTemperatureLogitsWarper
 )
@@ -43,12 +43,12 @@ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
     load_checkpoint='',
     tokenizer=GPTJConfig.get_tokenizer_config(),
     lm_server=LMServer.get_default_config(),
+    jax_distributed=JaxDistributedConfig.get_default_config(),
 )
 
 
 def main(argv):
-    jax.distributed.initialize()
+    JaxDistributedConfig.initialize(FLAGS.jax_distributed)
     set_random_seed(FLAGS.seed)
 
     prefix_tokenizer = GPTJConfig.get_tokenizer(
EasyLM/models/gptj/gptj_train.py
CHANGED
@@ -15,7 +15,7 @@ from EasyLM.data import DatasetFactory
 from EasyLM.checkpoint import StreamingCheckpointer
 from EasyLM.optimizers import OptimizerFactory
 from EasyLM.jax_utils import (
-    JaxRNG, next_rng, match_partition_rules,
+    JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
     cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
     set_random_seed, average_metrics, get_weight_decay_mask,
     make_shard_and_gather_fns, tree_apply
@@ -25,7 +25,6 @@ from EasyLM.models.gptj.gptj_model import GPTJConfig, FlaxGPTJForCausalLMModule
 
 FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
     seed=42,
-    initialize_jax_distributed=False,
     mesh_dim='1,-1,1',
     dtype='fp32',
     total_steps=10000,
@@ -45,13 +44,12 @@ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
     gptj=GPTJConfig.get_default_config(),
     logger=mlxu.WandBLogger.get_default_config(),
     log_all_worker=False,
+    jax_distributed=JaxDistributedConfig.get_default_config(),
 )
 
 
 def main(argv):
-        jax.distributed.initialize()
-
+    JaxDistributedConfig.initialize(FLAGS.jax_distributed)
     variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
     flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
     logger = mlxu.WandBLogger(
EasyLM/models/llama/convert_easylm_to_hf.py
CHANGED
@@ -77,6 +77,14 @@ LLAMA_STANDARD_CONFIGS = {
         'n_heads': 32,
         'norm_eps': 1e-6,
     },
+    '1b': {
+        'vocab_size': 64256,
+        'dim': 2048,
+        'intermediate_size': 5504,
+        'n_layers': 22,
+        'n_heads': 16,
+        'norm_eps': 1e-6,
+    },
     '3b': {
         'vocab_size': 64256,
         'dim': 3200,
@@ -132,7 +140,7 @@ def match_keywords(string, positives, negatives):
 
 def load_and_convert_checkpoint(path):
     _, flax_params = StreamingCheckpointer.load_trainstate_checkpoint(path)
-    flax_params = flatten_dict(flax_params['params']
+    flax_params = flatten_dict(flax_params['params'], sep='.')
     torch_params = {}
     for key, tensor in flax_params.items():
         if match_keywords(key, ["kernel"], ["norm", 'ln_f']):
@@ -219,7 +227,6 @@ def write_model(loaded, model_path, model_size):
     write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
 
     config = LlamaConfig(
-        vocab_size=params["vocab_size"],
         hidden_size=dim,
         intermediate_size=params["intermediate_size"],
         num_attention_heads=params["n_heads"],
@@ -235,12 +242,13 @@ def write_model(loaded, model_path, model_size):
 
     print("Loading the checkpoint in a Llama model.")
     model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16)
-    print("Model parameter count", model.num_parameters())
     # Avoid saving this as part of the config.
+    print("Model parameter count", model.num_parameters())
     del model.config._name_or_path
 
     print("Saving in the Transformers format.")
     model.save_pretrained(model_path)
+    model.save_pretrained(model_path, safe_serialization=True)
     shutil.rmtree(tmp_model_path)
 
 
@@ -252,21 +260,21 @@ def write_tokenizer(tokenizer_path, input_tokenizer_path):
         "bos_token": {
             "content": "<s>",
             "lstrip": False,
-            "normalized":
+            "normalized": True,
             "rstrip": False,
             "single_word": False
         },
         "eos_token": {
             "content": "</s>",
             "lstrip": False,
-            "normalized":
+            "normalized": True,
             "rstrip": False,
             "single_word": False
         },
         "unk_token": {
             "content": "<unk>",
             "lstrip": False,
-            "normalized":
+            "normalized": True,
             "rstrip": False,
             "single_word": False
         },
@@ -286,7 +294,7 @@ def write_tokenizer(tokenizer_path, input_tokenizer_path):
             "__type": "AddedToken",
             "content": "<s>",
             "lstrip": False,
-            "normalized":
+            "normalized": True,
             "rstrip": False,
             "single_word": False
         },
@@ -294,7 +302,7 @@ def write_tokenizer(tokenizer_path, input_tokenizer_path):
             "__type": "AddedToken",
             "content": "</s>",
             "lstrip": False,
-            "normalized":
+            "normalized": True,
             "rstrip": False,
             "single_word": False
         },
@@ -302,7 +310,7 @@ def write_tokenizer(tokenizer_path, input_tokenizer_path):
             "__type": "AddedToken",
             "content": "<unk>",
             "lstrip": False,
-            "normalized":
+            "normalized": True,
             "rstrip": False,
             "single_word": False
         },
@@ -313,7 +321,7 @@ def write_tokenizer(tokenizer_path, input_tokenizer_path):
 
 
 def main(argv):
-    assert FLAGS.load_checkpoint != "" and FLAGS.output_dir != ""
+    assert FLAGS.load_checkpoint != "" and FLAGS.output_dir != ""# and FLAGS.tokenizer_path != ""
     assert FLAGS.model_size in LLAMA_STANDARD_CONFIGS
     # write_tokenizer(
     #     tokenizer_path=FLAGS.output_dir,
EasyLM/models/llama/convert_hf_to_easylm.py
ADDED
@@ -0,0 +1,196 @@
+"""
+Usage:
+python convert_hf_to_easylm.py  \
+   --checkpoint_dir     /path/hf_format_dir/    \
+   --output_file /path/easylm_format.stream   \
+   --model_size 7b \
+   --streaming
+"""
+import time
+from pathlib import Path
+import argparse
+
+import mlxu
+import torch
+import flax
+
+from EasyLM.checkpoint import StreamingCheckpointer
+
+LLAMA_STANDARD_CONFIGS = {
+    '1b': {
+        'dim': 2048,
+        'intermediate_size': 5504,
+        'n_layers': 22,
+        'n_heads': 16,
+        'norm_eps': 1e-6,
+    },
+    '3b': {
+        'dim': 3200,
+        'intermediate_size': 8640,
+        'n_layers': 26,
+        'n_heads': 32,
+        'norm_eps': 1e-6,
+    },
+    "7b": {
+        "dim": 4096,
+        "intermediate_size": 11008,
+        "n_layers": 32,
+        "n_heads": 32,
+        "norm_eps": 1e-6,
+    },
+    "13b": {
+        "dim": 5120,
+        "intermediate_size": 13824,
+        "n_layers": 40,
+        "n_heads": 40,
+        "norm_eps": 1e-6,
+    },
+    "30b": {
+        "dim": 6656,
+        "intermediate_size": 17920,
+        "n_layers": 60,
+        "n_heads": 52,
+        "norm_eps": 1e-6,
+    },
+    "65b": {
+        "dim": 8192,
+        "intermediate_size": 22016,
+        "n_layers": 80,
+        "n_heads": 64,
+        "norm_eps": 1e-5,
+    },
+}
+
+
+def inverse_permute(params, w):
+    n_layers = params["n_layers"]
+    n_heads = params["n_heads"]
+    dim = params["dim"]
+    reshaped_w = w.reshape(n_heads, 2, dim // n_heads // 2, dim)
+    transposed_w = reshaped_w.transpose(0, 2, 1, 3)
+    inverted_w = transposed_w.reshape(dim, dim)
+    return inverted_w
+
+
+def main(args):
+    start = time.time()
+    params = LLAMA_STANDARD_CONFIGS[args.model_size]
+
+    ckpt_paths = sorted(Path(args.checkpoint_dir).glob("*.bin"))
+    ckpt = {}
+    for i, ckpt_path in enumerate(ckpt_paths):
+        checkpoint = torch.load(ckpt_path, map_location="cpu")
+        for k, v in checkpoint.items():
+            if k.startswith("model."):
+                k = k[6:]
+            ckpt[k] = v
+    print(f"Start convert weight to easylm format...")
+    jax_weights = {
+        "transformer": {
+            "wte": {"embedding": ckpt["embed_tokens.weight"].numpy()},
+            "ln_f": {"kernel": ckpt["norm.weight"].numpy()},
+            "h": {
+                "%d"
+                % (layer): {
+                    "attention": {
+                        "wq": {
+                            "kernel": inverse_permute(
+                                params,
+                                ckpt[f"layers.{layer}.self_attn.q_proj.weight"].numpy(),
+                            ).transpose()
+                        },
+                        "wk": {
+                            "kernel": inverse_permute(
+                                params,
+                                ckpt[f"layers.{layer}.self_attn.k_proj.weight"].numpy(),
+                            ).transpose()
+                        },
+                        "wv": {
+                            "kernel": ckpt[f"layers.{layer}.self_attn.v_proj.weight"]
+                            .numpy()
+                            .transpose()
+                        },
+                        "wo": {
+                            "kernel": ckpt[f"layers.{layer}.self_attn.o_proj.weight"]
+                            .numpy()
+                            .transpose()
+                        },
+                    },
+                    "feed_forward": {
+                        "w1": {
+                            "kernel": ckpt[f"layers.{layer}.mlp.gate_proj.weight"]
+                            .numpy()
+                            .transpose()
+                        },
+                        "w2": {
+                            "kernel": ckpt[f"layers.{layer}.mlp.down_proj.weight"]
+                            .numpy()
+                            .transpose()
+                        },
+                        "w3": {
+                            "kernel": ckpt[f"layers.{layer}.mlp.up_proj.weight"]
+                            .numpy()
+                            .transpose()
+                        },
+                    },
+                    "attention_norm": {
+                        "kernel": ckpt[f"layers.{layer}.input_layernorm.weight"].numpy()
+                    },
+                    "ffn_norm": {
+                        "kernel": ckpt[
+                            f"layers.{layer}.post_attention_layernorm.weight"
+                        ].numpy()
+                    },
+                }
+                for layer in range(params["n_layers"])
+            },
+        },
+        "lm_head": {"kernel": ckpt["lm_head.weight"].numpy().transpose()},
+    }
+    print(f"Convert weight to easylm format finished...")
+    print(f"Start to save...")
+
+    if args.streaming:
+        StreamingCheckpointer.save_train_state_to_file(jax_weights, args.output_file)
+    else:
+        with mlxu.open_file(args.output_file, "wb") as fout:
+            fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))
+
+    print(
+        f"Save finished!!! take time: {time.time() - start} save path: {args.output_file}"
+    )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="hf to easylm format script")
+
+    parser.add_argument(
+        "--checkpoint_dir",
+        type=str,
+        help="Need to be converted model weight dir. it is a dir",
+    )
+    parser.add_argument(
+        "--output_file", type=str, help="Save model weight file path, it is a file."
+    )
+    parser.add_argument(
+        "--model_size",
+        type=str,
+        default="7b",
+        choices=["7b", "13b", "30b", "65b"],
+        help="model size",
+    )
+    parser.add_argument(
+        "--streaming",
+        action="store_true",
+        default=True,
+        help="whether is model weight saved stream format",
+    )
+
+    args = parser.parse_args()
+
+    print(f"checkpoint_dir: {args.checkpoint_dir}")
+    print(f"output_file: {args.output_file}")
+    print(f"model_size: {args.model_size}")
+    print(f"streaming: {args.streaming}")
+
+    main(args)
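As a side note, inverse_permute in the new script undoes the head-interleaving that the Hugging Face LLaMA conversion applies to the q/k projection weights for rotary embeddings. The sketch below is not part of the commit: it copies inverse_permute's reshape logic and pairs it with a forward permute written from the same layout, only to show that the two compose to the identity; the forward permute is an assumption about the HF layout, not code from either repository.

import numpy as np

def permute(params, w):
    # Assumed forward permutation (HF-style rotary layout).
    n_heads, dim = params["n_heads"], params["dim"]
    return (
        w.reshape(n_heads, dim // n_heads // 2, 2, dim)
        .transpose(0, 2, 1, 3)
        .reshape(dim, dim)
    )

def inverse_permute(params, w):
    # Same reshape/transpose logic as in convert_hf_to_easylm.py.
    n_heads, dim = params["n_heads"], params["dim"]
    reshaped_w = w.reshape(n_heads, 2, dim // n_heads // 2, dim)
    transposed_w = reshaped_w.transpose(0, 2, 1, 3)
    inverted_w = transposed_w.reshape(dim, dim)
    return inverted_w

params = {"n_heads": 4, "dim": 16}
w = np.random.randn(16, 16)
assert np.allclose(inverse_permute(params, permute(params, w)), w)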
EasyLM/models/llama/llama_model.py
CHANGED
@@ -3,6 +3,7 @@ from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple, Union
 import json
 import tempfile
+from functools import partial
 
 import numpy as np
 import jax
@@ -15,8 +16,10 @@ from flax.linen import combine_masks, make_causal_mask
 from flax.linen.attention import dot_product_attention_weights
 from flax.traverse_util import flatten_dict, unflatten_dict
 from flax.linen import partitioning as nn_partitioning
+import einops
 
 import sentencepiece as spm
+from transformers import AutoTokenizer
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 from transformers.tokenization_utils import PreTrainedTokenizer
@@ -28,6 +31,7 @@ from ml_collections import ConfigDict
 from ml_collections.config_dict import config_dict
 from mlxu import function_args_to_config, load_pickle, open_file
 
+from EasyLM.bpt import blockwise_ffn, blockwise_attn
 from EasyLM.jax_utils import (
     with_sharding_constraint, get_jax_mesh, get_gradient_checkpoint_policy
 )
@@ -82,6 +86,18 @@ LLAMA_STANDARD_CONFIGS = {
         'use_cache': True,
         'tie_word_embeddings': False,
     },
+    '1b': {
+        'vocab_size': 64256,
+        'hidden_size': 2048,
+        'intermediate_size': 5504,
+        'num_hidden_layers': 22,
+        'num_attention_heads': 16,
+        'max_sequence_length': 2048,
+        'initializer_range': 0.02,
+        'rms_norm_eps': 1e-6,
+        'use_cache': True,
+        'tie_word_embeddings': False,
+    },
     '3b': {
         'vocab_size': 64256,
         'hidden_size': 3200,
@@ -219,7 +235,14 @@ class LLaMAConfig(PretrainedConfig):
         embd_pdrop=0.0,
         attn_pdrop=0.0,
         tie_word_embeddings=False,
+        remat_block='',
+        remat_attention='',
+        remat_mlp='',
+        scan_attention=False,
+        scan_mlp=False,
+        scan_query_chunk_size=1024,
+        scan_key_chunk_size=1024,
+        scan_mlp_chunk_size=1024,
         fcm_min_ratio=0.0,
         fcm_max_ratio=0.0,
         **kwargs,
@@ -236,7 +259,14 @@ class LLaMAConfig(PretrainedConfig):
         self.resid_pdrop = resid_pdrop
         self.embd_pdrop = embd_pdrop
         self.attn_pdrop = attn_pdrop
+        self.remat_block = remat_block
+        self.remat_attention = remat_attention
+        self.remat_mlp = remat_mlp
+        self.scan_attention = scan_attention
+        self.scan_mlp = scan_mlp
+        self.scan_query_chunk_size = scan_query_chunk_size
+        self.scan_key_chunk_size = scan_key_chunk_size
+        self.scan_mlp_chunk_size = scan_mlp_chunk_size
         self.fcm_min_ratio = fcm_min_ratio
         self.fcm_max_ratio = fcm_max_ratio
         super().__init__(
@@ -302,6 +332,7 @@ class LLaMAConfig(PretrainedConfig):
     def get_tokenizer_config(updates=None):
         config = ConfigDict()
         config.vocab_file = ''
+        config.pretrained_model_name_or_path = ''
         config.add_bos_token = False
         config.add_eos_token = False
 
@@ -312,14 +343,23 @@ class LLaMAConfig(PretrainedConfig):
     @classmethod
     def get_tokenizer(cls, config, padding_side='left', truncation_side='right'):
         config = cls.get_tokenizer_config(config)
-        assert config.vocab_file != '', 'vocab_file must be specified'
+        assert config.vocab_file != '' and config.pretrained_model_name_or_path != '', 'vocab_file or pretrained_model_name_or_path must be specified'
+        if config.pretrained_model_name_or_path != '':
+            tokenizer = AutoTokenizer.from_pretrained(
+                config.pretrained_model_name_or_path,
+                add_bos_token=config.add_bos_token,
+                add_eos_token=config.add_eos_token,
+                padding_side=padding_side,
+                truncation_side=truncation_side,
+            )
+        else:
+            tokenizer = LLaMATokenizer(
+                vocab_file=config.vocab_file,
+                add_bos_token=config.add_bos_token,
+                add_eos_token=config.add_eos_token,
+                padding_side=padding_side,
+                truncation_side=truncation_side,
+            )
         return tokenizer
 
     @classmethod
@@ -515,53 +555,82 @@ class FlaxLLaMAAttention(nn.Module):
 
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, dtype=self.dtype)
 
+        dropout_rng = None
+        if not deterministic and self.config.attn_pdrop > 0.0:
+            dropout_rng = self.make_rng("dropout")
+
+        if self.config.scan_attention and not (self.has_variable("cache", "cached_key") or init_cache):
+            # doesn't need blockwise attention if we are doing autoregressive decoding since no quadratic memory
+
+            # attention mask without nxn materlization, blockwise_attn will handle the rest
+            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
+            # transform boolean mask into float mask
+            attention_bias = lax.select(
+                attention_mask > 0,
+                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
+            )
+            attn_weights = None
+            attn_output = blockwise_attn(
+                xq,
+                xk,
+                xv,
+                bias=attention_bias,
+                deterministic=deterministic,
+                dropout_rng=dropout_rng,
+                attn_pdrop=self.config.attn_pdrop,
+                causal=True,
+                query_chunk_size=self.config.scan_query_chunk_size,
+                key_chunk_size=self.config.scan_key_chunk_size,
+                dtype=self.dtype,
+                policy=get_gradient_checkpoint_policy('nothing_saveable'),
+                precision=self.precision,
+                float32_logits=True,
+                prevent_cse=True,
+            )
+            attn_output = with_sharding_constraint(attn_output, PS(("dp", "fsdp"), None, "mp", None))
         else:
+            query_length, key_length = xq.shape[1], xk.shape[1]
+
+            if self.has_variable("cache", "cached_key"):
+                mask_shift = self.variables["cache"]["cache_index"]
+                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+                causal_mask = lax.dynamic_slice(
+                    self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
+                )
+            else:
+                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
+
+            batch_size = hidden_states.shape[0]
+            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
+
+            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
+            attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask)
+
+            # During fast autoregressive decoding, we feed one position at a time,
+            # and cache the keys and values step by step.
+            if self.has_variable("cache", "cached_key") or init_cache:
+                xk, xv, attention_mask = self._concatenate_to_cache(xk, xv, xq, attention_mask)
+
+            # transform boolean mask into float mask
+            attention_bias = lax.select(
+                attention_mask > 0,
+                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
+            )
+            attn_weights = dot_product_attention_weights(
+                xq,
+                xk,
+                bias=attention_bias,
+                dropout_rng=dropout_rng,
+                dropout_rate=self.config.attn_pdrop,
+                deterministic=deterministic,
+                dtype=jnp.promote_types(self.dtype, jnp.float32),
+                precision=self.precision,
+            )
+            attn_weights = with_sharding_constraint(attn_weights, PS(("dp", "fsdp"), "mp", None, None))
+            attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, xv, precision=self.precision)
 
         attn_output = self._merge_heads(attn_output)
         attn_output = self.wo(attn_output)
         attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
@@ -617,13 +686,28 @@ class FlaxLLaMABlock(nn.Module):
     precision: Optional[Union[jax.lax.Precision, str]]=None
 
     def setup(self) -> None:
+        attention_module = FlaxLLaMAAttention
+        mlp_module = FlaxLLaMAMLP
+        if self.config.remat_attention != '':
+            attention_module = remat(
+                FlaxLLaMAAttention, static_argnums=(3, 4, 5),
+                policy=get_gradient_checkpoint_policy(self.config.remat_attention),
+                prevent_cse=True,
+            )
+        if self.config.remat_mlp != '':
+            mlp_module = remat(
+                FlaxLLaMAMLP, static_argnums=(1,),
+                policy=get_gradient_checkpoint_policy(self.config.remat_mlp),
+                prevent_cse=True,
+            )
+
+        self.attention = attention_module(
             self.config,
             dtype=self.dtype,
             param_dtype=self.param_dtype,
             precision=self.precision,
         )
-        self.feed_forward =
+        self.feed_forward = mlp_module(
             self.config,
             dtype=self.dtype,
             param_dtype=self.param_dtype,
@@ -654,20 +738,32 @@ class FlaxLLaMABlock(nn.Module):
     ):
         attn_outputs = self.attention(
             self.attention_norm(hidden_states),
-            attention_mask
-            position_ids
-            deterministic
-            init_cache
-            output_attentions
-            fcm_mask
+            attention_mask,
+            position_ids,
+            deterministic,
+            init_cache,
+            output_attentions,
+            fcm_mask,
         )
         attn_output = attn_outputs[0]
         hidden_states = hidden_states + attn_output
 
+        feed_forward_input = self.ffn_norm(hidden_states)
+
+        if self.config.scan_mlp:
+            feed_forward_hidden_states = blockwise_ffn(
+                self.feed_forward,
+                feed_forward_input,
+                self.config.scan_mlp_chunk_size,
+                deterministic,
+            )
+        else:
+            feed_forward_hidden_states = self.feed_forward(
+                feed_forward_input,
+                deterministic,
+            )
+        feed_forward_hidden_states = with_sharding_constraint(feed_forward_hidden_states, PS(("dp", "fsdp"), None, "mp"))
+
         hidden_states = hidden_states + feed_forward_hidden_states
 
         return (hidden_states,) + attn_outputs[1:]
@@ -828,14 +924,19 @@ class FlaxLLaMABlockCollection(nn.Module):
 
     def setup(self):
         block = FlaxLLaMABlock
-        if self.config.
-            policy=get_gradient_checkpoint_policy(self.config.
-            block = FlaxLLaMACheckpointBlock
+        if self.config.remat_block != '':
+            block = remat(
+                FlaxLLaMABlock, static_argnums=(3, 4, 5),
+                policy=get_gradient_checkpoint_policy(self.config.remat_block)
             )
         self.blocks = [
-            block(
+            block(
+                self.config,
+                name=str(i),
+                dtype=self.dtype,
+                param_dtype=self.param_dtype,
+                precision=self.precision
+            ) for i in range(self.config.num_hidden_layers)
         ]
 
     def __call__(
@@ -862,7 +963,7 @@ class FlaxLLaMABlockCollection(nn.Module):
             )
             fcm_mask = jax.random.uniform(
                 self.make_rng('fcm'),
-                shape=(batch_size, 1,
+                shape=(batch_size, 1, 1, seq_length)
            ) > fcm_ratio
            fcm_mask = fcm_mask.at[:, :, :, 0].set(True)
            fcm_mask = fcm_mask.astype('bool')
@@ -1034,7 +1135,7 @@ class FlaxLLaMAForCausalLMModule(nn.Module):
 class FlaxLLaMAForCausalLM(FlaxLLaMAPreTrainedModel):
     module_class = FlaxLLaMAForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
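To summarize the new knobs, the blockwise paths are opt-in through the fields added to LLaMAConfig.__init__. A minimal sketch of enabling them follows; the keyword names come from the new signature, while building a config by hand like this is illustrative rather than a prescribed EasyLM workflow.

from EasyLM.models.llama.llama_model import LLaMAConfig

config = LLaMAConfig(
    scan_attention=True,               # route attention through blockwise_attn (skipped during decoding)
    scan_query_chunk_size=1024,
    scan_key_chunk_size=1024,
    scan_mlp=True,                     # route the feed-forward through blockwise_ffn
    scan_mlp_chunk_size=1024,
    remat_block='nothing_saveable',    # gradient-checkpoint whole transformer blocks
)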
EasyLM/models/llama/llama_serve.py
CHANGED
@@ -14,7 +14,7 @@ from transformers import GenerationConfig, FlaxLogitsProcessorList
 from EasyLM.checkpoint import StreamingCheckpointer
 from EasyLM.serving import LMServer
 from EasyLM.jax_utils import (
-    JaxRNG, next_rng, match_partition_rules, tree_apply,
+    JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply,
     set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
     with_sharding_constraint, FlaxTemperatureLogitsWarper
 )
@@ -37,12 +37,12 @@ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
     load_checkpoint='',
     tokenizer=LLaMAConfig.get_tokenizer_config(),
     lm_server=LMServer.get_default_config(),
+    jax_distributed=JaxDistributedConfig.get_default_config(),
 )
 
 
 def main(argv):
-    jax.distributed.initialize()
+    JaxDistributedConfig.initialize(FLAGS.jax_distributed)
     set_random_seed(FLAGS.seed)
 
     prefix_tokenizer = LLaMAConfig.get_tokenizer(
EasyLM/models/llama/llama_train.py
CHANGED
@@ -15,7 +15,7 @@ from EasyLM.data import DatasetFactory
 from EasyLM.checkpoint import StreamingCheckpointer
 from EasyLM.optimizers import OptimizerFactory
 from EasyLM.jax_utils import (
-    JaxRNG, next_rng, match_partition_rules,
+    JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
     cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
     set_random_seed, average_metrics, get_weight_decay_mask,
     make_shard_and_gather_fns, with_sharding_constraint,
@@ -27,10 +27,8 @@ from EasyLM.models.llama.llama_model import (
 
 FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
     seed=42,
-    initialize_jax_distributed=False,
     mesh_dim='1,-1,1',
     dtype='fp32',
-    param_dtype='fp32',
     total_steps=10000,
     load_llama_config='',
     update_llama_config='',
@@ -48,13 +46,12 @@ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
     llama=LLaMAConfig.get_default_config(),
     logger=mlxu.WandBLogger.get_default_config(),
     log_all_worker=False,
+    jax_distributed=JaxDistributedConfig.get_default_config(),
 )
 
 
 def main(argv):
-    jax.distributed.initialize()
-
+    JaxDistributedConfig.initialize(FLAGS.jax_distributed)
     variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
     flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
     logger = mlxu.WandBLogger(
@@ -66,7 +63,6 @@ def main(argv):
 
     tokenizer = LLaMAConfig.get_tokenizer(FLAGS.tokenizer)
     dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
-
     if FLAGS.load_dataset_state != '':
         dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
 
@@ -90,6 +86,7 @@ def main(argv):
         eos_token_id=dataset.tokenizer.eos_token_id,
     ))
     if llama_config.vocab_size < dataset.vocab_size:
+        print("Updating model config vocab size from", llama_config.vocab_size, "to", dataset.vocab_size)
         llama_config.update(dict(vocab_size=dataset.vocab_size))
 
     model = FlaxLLaMAForCausalLMModule(
@@ -250,7 +247,7 @@ def main(argv):
             metrics.update(average_metrics(eval_metric_list))
 
         if FLAGS.log_freq > 0 and (step + 1) % FLAGS.log_freq == 0:
-            log_metrics = {"step": step}
+            log_metrics = {"step": step + 1}
             log_metrics.update(metrics)
             log_metrics.update(dataset_metrics)
             log_metrics = jax.device_get(log_metrics)
EasyLM/models/roberta/roberta_train.py
CHANGED
@@ -17,7 +17,7 @@ from EasyLM.data import DatasetFactory
 from EasyLM.checkpoint import StreamingCheckpointer
 from EasyLM.optimizers import OptimizerFactory
 from EasyLM.jax_utils import (
-    JaxRNG, next_rng, match_partition_rules, get_float_dtype_by_name,
+    JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, get_float_dtype_by_name,
     cross_entropy_loss_and_accuracy, named_tree_map, global_norm,
     set_random_seed, average_metrics, get_weight_decay_mask,
     make_shard_and_gather_fns, tree_apply
@@ -29,7 +29,6 @@ from EasyLM.models.roberta.roberta_model import (
 
 FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
     seed=42,
-    initialize_jax_distributed=False,
     mesh_dim='-1,1,1',
     dtype='fp32',
     mask_token_probability=0.15,
@@ -50,13 +49,12 @@ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
     roberta=RobertaConfig.get_default_config(),
     logger=mlxu.WandBLogger.get_default_config(),
     log_all_worker=False,
+    jax_distributed=JaxDistributedConfig.get_default_config(),
 )
 
 
 def main(argv):
-    jax.distributed.initialize()
-
+    JaxDistributedConfig.initialize(FLAGS.jax_distributed)
     variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
     flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
     logger = mlxu.WandBLogger(
EasyLM/optimizers.py
CHANGED
@@ -193,7 +193,6 @@ class AdamWOptimizerFactory(object):

         return optimizer, optimizer_info

-
 class LionOptimizerFactory(object):
     """ Lion optimizer with cosine schedule. """

@@ -250,7 +249,7 @@ class LionOptimizerFactory(object):


 class OptaxScheduledWeightDecayState(NamedTuple):
-    count:
+    count: jax.Array


 def optax_add_scheduled_weight_decay(schedule_fn, mask=None):
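For context, the count field typed above is the step counter that optax threads through its update loop. Below is a minimal illustrative sketch (an assumed reconstruction, not the exact EasyLM implementation) of how a scheduled weight decay transformation can be built around such a state; schedule_fn is assumed to map the step count to a decay coefficient and mask to be a callable over the params pytree.

import jax
import jax.numpy as jnp
import optax
from typing import NamedTuple


class OptaxScheduledWeightDecayState(NamedTuple):
    count: jax.Array


def scheduled_weight_decay(schedule_fn, mask=None):
    # Returns an optax.GradientTransformation that adds schedule_fn(step) * param
    # to the updates, optionally restricted by a boolean mask over the params.
    def init_fn(params):
        del params
        return OptaxScheduledWeightDecayState(count=jnp.zeros([], jnp.int32))

    def update_fn(updates, state, params=None):
        if params is None:
            raise ValueError('params are required for weight decay')
        coeff = schedule_fn(state.count)  # decay coefficient for this step
        decay = jax.tree_util.tree_map(lambda p: coeff * p, params)
        if mask is not None:
            # Zero the decay term wherever the mask marks a parameter as excluded.
            decay = jax.tree_util.tree_map(
                lambda m, d: d if m else jnp.zeros_like(d), mask(params), decay
            )
        updates = jax.tree_util.tree_map(lambda u, d: u + d, updates, decay)
        return updates, OptaxScheduledWeightDecayState(count=state.count + 1)

    return optax.GradientTransformation(init_fn, update_fn)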
EasyLM/scripts/benchmark_attention.py
ADDED
@@ -0,0 +1,150 @@
+from functools import partial
+from time import time
+import os
+import numpy as np
+import jax
+import jax.flatten_util
+import jax.numpy as jnp
+import mlxu
+from EasyLM.bpt import blockwise_attn
+from EasyLM.jax_utils import (
+    get_float_dtype_by_name, set_random_seed, next_rng, JaxRNG
+)
+
+
+FLAGS, _ = mlxu.define_flags_with_default(
+    seed=42,
+    dtype='fp32',
+    embed_dim=2048,
+    n_heads=16,
+    ref_attn_seq_len=2048,
+    eff_attn_seq_len=16384,
+    batch_size=1,
+    query_chunk_size=2048,
+    key_chunk_size=2048,
+    warmup_steps=40,
+    steps=200,
+)
+
+
+def main(argv):
+
+    def random_kqv(rng_key, seq_len):
+        rng_generator = JaxRNG(rng_key)
+        kqv = []
+        for i in range(3):
+            kqv.append(
+                jax.random.normal(
+                    rng_generator(),
+                    (FLAGS.batch_size, seq_len, FLAGS.n_heads, FLAGS.embed_dim // FLAGS.n_heads),
+                    dtype=get_float_dtype_by_name(FLAGS.dtype)
+                )
+            )
+        return tuple(kqv)
+
+    def reference_attn(query, key, value):
+        dtype = get_float_dtype_by_name(FLAGS.dtype)
+        query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
+        logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
+        mask_value = jnp.finfo(logits.dtype).min
+        _, q_seq_len, _, _ = query.shape
+        _, kv_seq_len, _, _ = key.shape
+        mask_shape = (q_seq_len, kv_seq_len)
+        row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
+        col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
+        causal_mask = (row_ids < col_ids)[None, None, :, :]
+        logits = logits + jnp.where(causal_mask, mask_value, 0.0)
+        weights = jax.nn.softmax(logits, axis=-1)
+        out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
+        return out
+
+    def efficient_attention(query, key, value):
+        dtype = get_float_dtype_by_name(FLAGS.dtype)
+        return blockwise_attn(
+            query, key, value,
+            bias=None,
+            deterministic=True,
+            dropout_rng=None,
+            attn_pdrop=0.0,
+            causal=True,
+            query_chunk_size=FLAGS.query_chunk_size,
+            key_chunk_size=FLAGS.key_chunk_size,
+            dtype=get_float_dtype_by_name(FLAGS.dtype),
+            policy=jax.checkpoint_policies.nothing_saveable(),
+            precision=None,
+            float32_logits=True,
+            prevent_cse=True,
+        )
+
+
+    @partial(jax.jit, static_argnums=(1,))
+    def reference_attn_forward_backward(rng_key, seq_len):
+        @partial(jax.grad, argnums=(0, 1, 2))
+        @partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable())
+        def grad_fn(query, key, value):
+            out = reference_attn(query, key, value)
+            return jnp.mean(out)
+
+        query, key, value = random_kqv(rng_key, seq_len)
+        return jax.flatten_util.ravel_pytree(
+            grad_fn(query, key, value)[1]
+        )[0].mean()
+
+    @partial(jax.jit, static_argnums=(1,))
+    def efficient_attn_forward_backward(rng_key, seq_len):
+        @partial(jax.grad, argnums=(0, 1, 2))
+        def grad_fn(query, key, value):
+            out = efficient_attention(query, key, value)
+            return jnp.mean(out)
+
+        query, key, value = random_kqv(rng_key, seq_len)
+        return jax.flatten_util.ravel_pytree(
+            grad_fn(query, key, value)[1]
+        )[0].mean()
+
+
+    set_random_seed(FLAGS.seed)
+
+    jax.block_until_ready(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
+    jax.block_until_ready(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
+
+    all_results = []
+    for i in range(FLAGS.warmup_steps):
+        all_results.append(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
+    jax.block_until_ready(all_results)
+
+    start_time = time()
+    all_results = []
+    for i in range(FLAGS.steps):
+        all_results.append(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len))
+
+    jax.block_until_ready(all_results)
+    elapsed_time_ref_attn = time() - start_time
+    print(f'Reference attention: {elapsed_time_ref_attn:.3f} seconds')
+
+
+    all_results = []
+    for i in range(FLAGS.warmup_steps):
+        all_results.append(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
+    jax.block_until_ready(all_results)
+
+
+    start_time = time()
+    all_results = []
+    for i in range(FLAGS.steps):
+        all_results.append(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len))
+
+    jax.block_until_ready(all_results)
+    elapsed_time_efficient_attn = time() - start_time
+    print(f'Efficient attention: {elapsed_time_efficient_attn:.3f} seconds')
+
+    flops_ratio = (FLAGS.eff_attn_seq_len / FLAGS.ref_attn_seq_len) ** 2
+    efficiency = elapsed_time_ref_attn / elapsed_time_efficient_attn * flops_ratio
+    print(f'Efficiency: {efficiency:.3f}')
+
+
+if __name__ == '__main__':
+    mlxu.run(main)
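The script's final efficiency number normalizes the wall-clock ratio by the quadratic growth of attention FLOPs with sequence length. A small standalone illustration of the same arithmetic (a hypothetical helper, not part of the script):

def attention_benchmark_efficiency(t_ref, ref_seq_len, t_eff, eff_seq_len):
    # Attention cost scales roughly with seq_len ** 2, so the blockwise run at a
    # longer sequence length is credited with proportionally more work.
    flops_ratio = (eff_seq_len / ref_seq_len) ** 2
    return t_ref / t_eff * flops_ratio

# Example: if 16k-token blockwise attention takes 4x as long as the 2k-token
# reference, the normalized efficiency is (1 / 4) * 64 = 16.
print(attention_benchmark_efficiency(1.0, 2048, 4.0, 16384))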
EasyLM/scripts/lm_eval_harness.py
CHANGED
@@ -20,6 +20,8 @@ from EasyLM.serving import LMClient
 FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
     tasks='wsc,piqa,winogrande,openbookqa,logiqa',
     shots=0,
+    limit=0,
+    write_out=False,
     lm_client=LMClient.get_default_config(),
     logger=mlxu.WandBLogger.get_default_config(),
 )
@@ -51,7 +53,9 @@ def main(argv):
     model = LMEvalHarnessInterface(LMClient(FLAGS.lm_client))
     task_list = FLAGS.tasks.split(',')
     results = evaluator.evaluate(
-        model, tasks.get_task_dict(task_list), False, FLAGS.shots,
+        model, tasks.get_task_dict(task_list), False, FLAGS.shots,
+        limit=None if FLAGS.limit <= 0 else FLAGS.limit,
+        write_out=FLAGS.write_out,
     )
     logger.log(flatten_dict(results['results'], sep='/'))
     pprint.pprint(results)