cognitivess committed
Commit
7b57135
1 Parent(s): 75afd51

Create modeling_flax_Cognitivess.py

cognitivess_model/modeling_flax_Cognitivess.py ADDED
@@ -0,0 +1,742 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Cognitivess AI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Flax Cognitivess model."""
16
+
17
+ from typing import Optional, Tuple
18
+
19
+ import flax.linen as nn
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import numpy as np
23
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
24
+ from flax.linen import combine_masks, make_causal_mask
25
+ from flax.linen.attention import dot_product_attention_weights
26
+ from flax.traverse_util import flatten_dict, unflatten_dict
27
+ from jax import lax
28
+
29
+ from ...modeling_flax_outputs import (
30
+ FlaxBaseModelOutput,
31
+ FlaxBaseModelOutputWithPast,
32
+ FlaxCausalLMOutput,
33
+ FlaxCausalLMOutputWithCrossAttentions,
34
+ )
35
+ from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
36
+ from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
37
+ from .configuration_Cognitivess import CognitivessConfig
38
+
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+ _CONFIG_FOR_DOC = "CognitivessConfig"
43
+ _REAL_CHECKPOINT_FOR_DOC = "CognitivessAI/cognitivess"
44
+ _CHECKPOINT_FOR_DOC = "ksmcg/Cognitivess-tiny"
45
+
46
+ Cognitivess_START_DOCSTRING = r"""
47
+
48
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
49
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
50
+ etc.)
51
+
52
+ This model is also a Flax Linen
53
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
54
+ regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
55
+
56
+ Finally, this model supports inherent JAX features such as:
57
+
58
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
59
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
60
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
61
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
62
+
63
+ Parameters:
64
+ config ([`CognitivessConfig`]): Model configuration class with all the parameters of the model.
65
+ Initializing with a config file does not load the weights associated with the model, only the
66
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
67
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
68
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16`, or
69
+ `jax.numpy.bfloat16`.
70
+
71
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
72
+ specified, all the computation will be performed with the given `dtype`.
73
+
74
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
75
+ parameters.**
76
+
77
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
78
+ [`~FlaxPreTrainedModel.to_bf16`].
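+
+ For example, to run the forward pass in half precision (a minimal sketch; it assumes Flax-loadable
+ weights are available for the referenced checkpoint):
+
+ ```python
+ >>> import jax.numpy as jnp
+ >>> model = FlaxCognitivessForCausalLM.from_pretrained("CognitivessAI/cognitivess", dtype=jnp.bfloat16)
+ >>> # parameters stay in float32; cast them too with: model.params = model.to_bf16(model.params)
+ ```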
79
+ """
80
+
81
+ Cognitivess_INPUTS_DOCSTRING = r"""
82
+ Args:
83
+ input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
84
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
85
+ it.
86
+
87
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
88
+ [`PreTrainedTokenizer.__call__`] for details.
89
+
90
+ [What are input IDs?](../glossary#input-ids)
91
+ attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
92
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
93
+
94
+ - 1 for tokens that are **not masked**,
95
+ - 0 for tokens that are **masked**.
96
+
97
+ [What are attention masks?](../glossary#attention-mask)
98
+
99
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
100
+ [`PreTrainedTokenizer.__call__`] for details.
101
+
102
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
103
+ `past_key_values`).
104
+
105
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
106
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
107
+ information on the default strategy.
108
+
109
+ - 1 indicates the token is **not masked**,
110
+ - 0 indicates the token is **masked**.
111
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
112
+ Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
113
+ config.n_positions - 1]`.
114
+
115
+ [What are position IDs?](../glossary#position-ids)
116
+ past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
117
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
118
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
119
+ output_attentions (`bool`, *optional*):
120
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
121
+ tensors for more detail.
122
+ output_hidden_states (`bool`, *optional*):
123
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
124
+ more detail.
125
+ return_dict (`bool`, *optional*):
126
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
127
+ """
128
+
129
+
130
+ # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaRMSNorm with Llama->Cognitivess
131
+ class FlaxCognitivessRMSNorm(nn.Module):
132
+ config: CognitivessConfig
133
+ dtype: jnp.dtype = jnp.float32
134
+
135
+ def setup(self):
136
+ self.epsilon = self.config.rms_norm_eps
137
+ self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.config.hidden_size)
138
+
139
+ def __call__(self, hidden_states):
140
+ variance = jnp.asarray(hidden_states, dtype=jnp.float32)
141
+ variance = jnp.power(variance, 2)
142
+ variance = variance.mean(-1, keepdims=True)
143
+ # use `jax.numpy.sqrt` as `jax.lax.rsqrt` does not match `torch.rsqrt`
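+ # i.e. hidden_states / sqrt(mean(hidden_states**2) + eps), with the mean accumulated in float32 above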
144
+ hidden_states = hidden_states / jnp.sqrt(variance + self.epsilon)
145
+
146
+ return self.weight * jnp.asarray(hidden_states, dtype=self.dtype)
147
+
148
+
149
+ # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaRotaryEmbedding with Llama->Cognitivess
150
+ class FlaxCognitivessRotaryEmbedding(nn.Module):
151
+ config: CognitivessConfig
152
+ dtype: jnp.dtype = jnp.float32
153
+
154
+ def setup(self):
155
+ head_dim = self.config.hidden_size // self.config.num_attention_heads
156
+ self.sincos = create_sinusoidal_positions(self.config.max_position_embeddings, head_dim)
157
+
158
+ def __call__(self, key, query, position_ids):
159
+ sincos = self.sincos[position_ids]
160
+ sin_pos, cos_pos = jnp.split(sincos, 2, axis=-1)
161
+
162
+ key = apply_rotary_pos_emb(key, sin_pos, cos_pos)
163
+ query = apply_rotary_pos_emb(query, sin_pos, cos_pos)
164
+
165
+ key = jnp.asarray(key, dtype=self.dtype)
166
+ query = jnp.asarray(query, dtype=self.dtype)
167
+
168
+ return key, query
169
+
170
+
171
+ # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaMLP with Llama->Cognitivess
172
+ class FlaxCognitivessMLP(nn.Module):
173
+ config: CognitivessConfig
174
+ dtype: jnp.dtype = jnp.float32
175
+
176
+ def setup(self):
177
+ embed_dim = self.config.hidden_size
178
+ inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim
179
+
180
+ kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
181
+ self.act = ACT2FN[self.config.hidden_act]
182
+
183
+ self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
184
+ self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
185
+ self.up_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
186
+
187
+ def __call__(self, hidden_states):
188
+ up_proj_states = self.up_proj(hidden_states)
189
+ gate_states = self.act(self.gate_proj(hidden_states))
190
+
191
+ hidden_states = self.down_proj(up_proj_states * gate_states)
192
+ return hidden_states
193
+
194
+
195
+ # Copied from transformers.models.llama.modeling_flax_llama.apply_rotary_pos_emb
196
+ def apply_rotary_pos_emb(tensor, sin_pos, cos_pos):
197
+ return (tensor * cos_pos) + (rotate_half(tensor) * sin_pos)
198
+
199
+
200
+ # Copied from transformers.models.llama.modeling_flax_llama.create_sinusoidal_positions
201
+ def create_sinusoidal_positions(num_pos, dim):
202
+ inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
203
+ freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32")
204
+
205
+ emb = np.concatenate((freqs, freqs), axis=-1)
206
+ out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1)
207
+ return jnp.array(out[:, :, :num_pos])
208
+
209
+
210
+ # Copied from transformers.models.llama.modeling_flax_llama.rotate_half
211
+ def rotate_half(tensor):
212
+ """Rotates half the hidden dims of the input."""
213
+ rotate_half_tensor = jnp.concatenate(
214
+ (-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1
215
+ )
216
+ return rotate_half_tensor
217
+
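+ # Note: `apply_rotary_pos_emb` composes the helpers above into the standard RoPE rotation
+ # x * cos + rotate_half(x) * sin, applied to the query and key heads before the attention scores.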
218
+
219
+ class FlaxCognitivessAttention(nn.Module):
220
+ config: CognitivessConfig
221
+ dtype: jnp.dtype = jnp.float32
222
+
223
+ def setup(self):
224
+ config = self.config
225
+ self.hidden_size = config.hidden_size
226
+ self.num_heads = config.num_attention_heads
227
+ self.head_dim = self.hidden_size // self.num_heads
228
+ self.num_key_value_heads = config.num_key_value_heads
229
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
230
+ self.max_position_embeddings = config.max_position_embeddings
231
+ self.attention_softmax_in_fp32 = self.dtype is not jnp.float32
232
+ self.rope_theta = config.rope_theta
233
+ if (self.head_dim * self.num_heads) != self.hidden_size:
234
+ raise ValueError(
235
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
236
+ f" and `num_heads`: {self.num_heads})."
237
+ )
238
+ self.q_proj = nn.Dense(self.num_heads * self.head_dim, use_bias=False, dtype=self.dtype)
239
+ self.k_proj = nn.Dense(self.num_key_value_heads * self.head_dim, use_bias=False, dtype=self.dtype)
240
+ self.v_proj = nn.Dense(self.num_key_value_heads * self.head_dim, use_bias=False, dtype=self.dtype)
241
+ self.o_proj = nn.Dense(self.hidden_size, use_bias=False, dtype=self.dtype)
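+ # the causal mask below is additionally restricted to a sliding window: keys more than
+ # `config.sliding_window` positions in the past are masked out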
242
+ causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
243
+ self.causal_mask = jnp.triu(causal_mask, k=-config.sliding_window)
244
+ self.rotary_emb = FlaxCognitivessRotaryEmbedding(config, dtype=self.dtype)
245
+
246
+ def _split_heads(self, hidden_states, num_heads):
247
+ return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))
248
+
249
+ def _merge_heads(self, hidden_states):
250
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))
251
+
252
+ @nn.compact
253
+ # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoSelfAttention._concatenate_to_cache
254
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
255
+ """
256
+ This function takes projected key, value states from a single input token and concatenates the states to cached
257
+ states from previous steps. This function is slightly adapted from the official Flax repository:
258
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
259
+ """
260
+ # detect if we're initializing by absence of existing cache data.
261
+ is_initialized = self.has_variable("cache", "cached_key")
262
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
263
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
264
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
265
+
266
+ if is_initialized:
267
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
268
+ # update key, value caches with our new 1d spatial slices
269
+ cur_index = cache_index.value
270
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
271
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
272
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
273
+ cached_key.value = key
274
+ cached_value.value = value
275
+ num_updated_cache_vectors = query.shape[1]
276
+ cache_index.value = cache_index.value + num_updated_cache_vectors
277
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
278
+ pad_mask = jnp.broadcast_to(
279
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
280
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
281
+ )
282
+ attention_mask = combine_masks(pad_mask, attention_mask)
283
+ return key, value, attention_mask
284
+
285
+ def __call__(
286
+ self,
287
+ hidden_states: jnp.ndarray,
288
+ attention_mask: Optional[jnp.ndarray] = None,
289
+ position_ids: Optional[jnp.ndarray] = None,
290
+ deterministic: bool = True,
291
+ output_attentions: bool = False,
292
+ init_cache: bool = False,
293
+ ) -> Tuple[jnp.ndarray, jnp.ndarray]:
294
+ query_states = self.q_proj(hidden_states)
295
+ key_states = self.k_proj(hidden_states)
296
+ value_states = self.v_proj(hidden_states)
297
+
298
+ query_states = self._split_heads(query_states, self.num_heads)
299
+ key_states = self._split_heads(key_states, self.num_key_value_heads)
300
+ value_states = self._split_heads(value_states, self.num_key_value_heads)
301
+
302
+ key_states, query_states = self.rotary_emb(key_states, query_states, position_ids)
303
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
304
+ if self.has_variable("cache", "cached_key"):
305
+ mask_shift = self.variables["cache"]["cache_index"]
306
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
307
+ causal_mask = lax.dynamic_slice(
308
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
309
+ )
310
+ else:
311
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
312
+
313
+ batch_size = hidden_states.shape[0]
314
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
315
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
316
+ attention_mask = combine_masks(attention_mask, causal_mask)
317
+
318
+ if self.has_variable("cache", "cached_key") or init_cache:
319
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
320
+ key_states, value_states, query_states, attention_mask
321
+ )
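+ # grouped-query attention: each key/value head is shared by `num_key_value_groups` query heads,
+ # so the key/value states are repeated along the head axis to match the number of query heads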
322
+ key_states = jnp.repeat(key_states, self.num_key_value_groups, axis=2)
323
+ value_states = jnp.repeat(value_states, self.num_key_value_groups, axis=2)
324
+
325
+ attention_bias = lax.select(
326
+ attention_mask > 0,
327
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
328
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
329
+ )
330
+
331
+ # usual dot product attention
332
+ attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype
333
+ attn_weights = dot_product_attention_weights(
334
+ query_states,
335
+ key_states,
336
+ bias=attention_bias,
337
+ deterministic=deterministic,
338
+ dropout_rate=self.config.attention_dropout,
339
+ dtype=attention_dtype,
340
+ )
341
+
342
+ if self.attention_softmax_in_fp32:
343
+ attn_weights = attn_weights.astype(self.dtype)
344
+
345
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
346
+ attn_output = self._merge_heads(attn_output)
347
+ attn_output = self.o_proj(attn_output)
348
+
349
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
350
+ return outputs
351
+
352
+
353
+ # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaDecoderLayer with Llama->Cognitivess
354
+ class FlaxCognitivessDecoderLayer(nn.Module):
355
+ config: CognitivessConfig
356
+ dtype: jnp.dtype = jnp.float32
357
+
358
+ def setup(self):
359
+ self.input_layernorm = FlaxCognitivessRMSNorm(self.config, dtype=self.dtype)
360
+ self.self_attn = FlaxCognitivessAttention(self.config, dtype=self.dtype)
361
+ self.post_attention_layernorm = FlaxCognitivessRMSNorm(self.config, dtype=self.dtype)
362
+ self.mlp = FlaxCognitivessMLP(self.config, dtype=self.dtype)
363
+
364
+ def __call__(
365
+ self,
366
+ hidden_states,
367
+ attention_mask=None,
368
+ position_ids=None,
369
+ deterministic: bool = True,
370
+ init_cache: bool = False,
371
+ output_attentions: bool = False,
372
+ ):
373
+ residual = hidden_states
374
+ hidden_states = self.input_layernorm(hidden_states)
375
+ outputs = self.self_attn(
376
+ hidden_states,
377
+ attention_mask=attention_mask,
378
+ position_ids=position_ids,
379
+ deterministic=deterministic,
380
+ init_cache=init_cache,
381
+ output_attentions=output_attentions,
382
+ )
383
+ # residual connection
384
+ attn_output = outputs[0]
385
+ hidden_states = residual + attn_output
386
+
387
+ residual = hidden_states
388
+ hidden_states = self.post_attention_layernorm(hidden_states)
389
+ hidden_states = self.mlp(hidden_states)
390
+ # residual connection
391
+ hidden_states = residual + hidden_states
392
+
393
+ return (hidden_states,) + outputs[1:]
394
+
395
+
396
+ # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo->Cognitivess, GPT_NEO->Cognitivess, transformer->model
397
+ class FlaxCognitivessPreTrainedModel(FlaxPreTrainedModel):
398
+ """
399
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
400
+ models.
401
+ """
402
+
403
+ config_class = CognitivessConfig
404
+ base_model_prefix = "model"
405
+ module_class: nn.Module = None
406
+
407
+ def __init__(
408
+ self,
409
+ config: CognitivessConfig,
410
+ input_shape: Tuple = (1, 1),
411
+ seed: int = 0,
412
+ dtype: jnp.dtype = jnp.float32,
413
+ _do_init: bool = True,
414
+ **kwargs,
415
+ ):
416
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
417
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
418
+
419
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
420
+ # init input tensors
421
+ input_ids = jnp.zeros(input_shape, dtype="i4")
422
+ attention_mask = jnp.ones_like(input_ids)
423
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
424
+ params_rng, dropout_rng = jax.random.split(rng)
425
+ rngs = {"params": params_rng, "dropout": dropout_rng}
426
+
427
+ random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]
428
+
429
+ if params is not None:
430
+ random_params = flatten_dict(unfreeze(random_params))
431
+ params = flatten_dict(unfreeze(params))
432
+ for missing_key in self._missing_keys:
433
+ params[missing_key] = random_params[missing_key]
434
+ self._missing_keys = set()
435
+ return freeze(unflatten_dict(params))
436
+ else:
437
+ return random_params
438
+
439
+ def init_cache(self, batch_size, max_length):
440
+ r"""
441
+ Args:
442
+ batch_size (`int`):
443
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
444
+ max_length (`int`):
445
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
446
+ cache.
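+
+ Example (illustrative sketch): `past_key_values = model.init_cache(batch_size=1, max_length=128)`
+ can then be passed to the model call as `past_key_values`, together with explicit `position_ids`,
+ to decode one token at a time.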
447
+ """
448
+ # init input variables to retrieve cache
449
+ input_ids = jnp.ones((batch_size, max_length))
450
+ attention_mask = jnp.ones_like(input_ids)
451
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
452
+
453
+ init_variables = self.module.init(
454
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
455
+ )
456
+ return unfreeze(init_variables["cache"])
457
+
458
+ @add_start_docstrings_to_model_forward(Cognitivess_INPUTS_DOCSTRING)
459
+ def __call__(
460
+ self,
461
+ input_ids,
462
+ attention_mask=None,
463
+ position_ids=None,
464
+ params: dict = None,
465
+ past_key_values: dict = None,
466
+ dropout_rng: jax.random.PRNGKey = None,
467
+ train: bool = False,
468
+ output_attentions: Optional[bool] = None,
469
+ output_hidden_states: Optional[bool] = None,
470
+ return_dict: Optional[bool] = None,
471
+ ):
472
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
473
+ output_hidden_states = (
474
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
475
+ )
476
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
477
+
478
+ batch_size, sequence_length = input_ids.shape
479
+
480
+ if position_ids is None:
481
+ if past_key_values is not None:
482
+ raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
483
+
484
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
485
+
486
+ if attention_mask is None:
487
+ attention_mask = jnp.ones((batch_size, sequence_length))
488
+
489
+ # Handle any PRNG if needed
490
+ rngs = {}
491
+ if dropout_rng is not None:
492
+ rngs["dropout"] = dropout_rng
493
+
494
+ inputs = {"params": params or self.params}
495
+
496
+ # If past_key_values are passed, the cache is already initialized and a private flag init_cache has to be passed down to ensure the cache is used. The cache also has to be marked as mutable so that it can be changed by the FlaxCognitivessAttention module.
497
+ if past_key_values:
498
+ inputs["cache"] = past_key_values
499
+ mutable = ["cache"]
500
+ else:
501
+ mutable = False
502
+
503
+ outputs = self.module.apply(
504
+ inputs,
505
+ jnp.array(input_ids, dtype="i4"),
506
+ jnp.array(attention_mask, dtype="i4"),
507
+ jnp.array(position_ids, dtype="i4"),
508
+ not train,
509
+ False,
510
+ output_attentions,
511
+ output_hidden_states,
512
+ return_dict,
513
+ rngs=rngs,
514
+ mutable=mutable,
515
+ )
516
+
517
+ # add updated cache to model output
518
+ if past_key_values is not None and return_dict:
519
+ outputs, past_key_values = outputs
520
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
521
+ return outputs
522
+ elif past_key_values is not None and not return_dict:
523
+ outputs, past_key_values = outputs
524
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
525
+
526
+ return outputs
527
+
528
+
529
+ # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaLayerCollection with Llama->Cognitivess
530
+ class FlaxCognitivessLayerCollection(nn.Module):
531
+ config: CognitivessConfig
532
+ dtype: jnp.dtype = jnp.float32
533
+
534
+ def setup(self):
535
+ self.blocks = [
536
+ FlaxCognitivessDecoderLayer(self.config, dtype=self.dtype, name=str(i))
537
+ for i in range(self.config.num_hidden_layers)
538
+ ]
539
+
540
+ def __call__(
541
+ self,
542
+ hidden_states,
543
+ attention_mask=None,
544
+ position_ids=None,
545
+ deterministic: bool = True,
546
+ init_cache: bool = False,
547
+ output_attentions: bool = False,
548
+ output_hidden_states: bool = False,
549
+ return_dict: bool = False,
550
+ ):
551
+ all_attentions = () if output_attentions else None
552
+ all_hidden_states = () if output_hidden_states else None
553
+
554
+ for block in self.blocks:
555
+ if output_hidden_states:
556
+ all_hidden_states += (hidden_states,)
557
+ layer_outputs = block(
558
+ hidden_states,
559
+ attention_mask=attention_mask,
560
+ position_ids=position_ids,
561
+ deterministic=deterministic,
562
+ init_cache=init_cache,
563
+ output_attentions=output_attentions,
564
+ )
565
+ hidden_states = layer_outputs[0]
566
+
567
+ if output_attentions:
568
+ all_attentions += (layer_outputs[1],)
569
+
570
+ # this contains possible `None` values - `FlaxCognitivessModule` will filter them out
571
+ outputs = (hidden_states, all_hidden_states, all_attentions)
572
+
573
+ return outputs
574
+
575
+
576
+ # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaModule with Llama->Cognitivess
577
+ class FlaxCognitivessModule(nn.Module):
578
+ config: CognitivessConfig
579
+ dtype: jnp.dtype = jnp.float32
580
+
581
+ def setup(self):
582
+ self.hidden_size = self.config.hidden_size
583
+ embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range)
584
+ self.embed_tokens = nn.Embed(
585
+ self.config.vocab_size,
586
+ self.hidden_size,
587
+ embedding_init=embedding_init,
588
+ dtype=self.dtype,
589
+ )
590
+ self.layers = FlaxCognitivessLayerCollection(self.config, dtype=self.dtype)
591
+ self.norm = FlaxCognitivessRMSNorm(self.config, dtype=self.dtype)
592
+
593
+ def __call__(
594
+ self,
595
+ input_ids,
596
+ attention_mask=None,
597
+ position_ids=None,
598
+ deterministic=True,
599
+ init_cache: bool = False,
600
+ output_attentions: bool = False,
601
+ output_hidden_states: bool = False,
602
+ return_dict: bool = True,
603
+ ):
604
+ input_embeds = self.embed_tokens(input_ids.astype("i4"))
605
+
606
+ outputs = self.layers(
607
+ input_embeds,
608
+ position_ids=position_ids,
609
+ attention_mask=attention_mask,
610
+ deterministic=deterministic,
611
+ init_cache=init_cache,
612
+ output_attentions=output_attentions,
613
+ output_hidden_states=output_hidden_states,
614
+ return_dict=return_dict,
615
+ )
616
+
617
+ hidden_states = outputs[0]
618
+ hidden_states = self.norm(hidden_states)
619
+
620
+ if output_hidden_states:
621
+ all_hidden_states = outputs[1] + (hidden_states,)
622
+ outputs = (hidden_states, all_hidden_states) + outputs[2:]
623
+ else:
624
+ outputs = (hidden_states,) + outputs[1:]
625
+
626
+ if not return_dict:
627
+ return tuple(v for v in outputs if v is not None)
628
+
629
+ return FlaxBaseModelOutput(
630
+ last_hidden_state=hidden_states,
631
+ hidden_states=outputs[1],
632
+ attentions=outputs[-1],
633
+ )
634
+
635
+
636
+ @add_start_docstrings(
637
+ "The bare Cognitivess Model transformer outputting raw hidden-states without any specific head on top.",
638
+ Cognitivess_START_DOCSTRING,
639
+ )
640
+ class FlaxCognitivessModel(FlaxCognitivessPreTrainedModel):
641
+ module_class = FlaxCognitivessModule
642
+
643
+
644
+ append_call_sample_docstring(
645
+ FlaxCognitivessModel,
646
+ _CHECKPOINT_FOR_DOC,
647
+ FlaxBaseModelOutputWithPast,
648
+ _CONFIG_FOR_DOC,
649
+ real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
650
+ )
651
+
652
+
653
+ # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaForCausalLMModule with Llama->Cognitivess
654
+ class FlaxCognitivessForCausalLMModule(nn.Module):
655
+ config: CognitivessConfig
656
+ dtype: jnp.dtype = jnp.float32
657
+
658
+ def setup(self):
659
+ self.model = FlaxCognitivessModule(self.config, dtype=self.dtype)
660
+ self.lm_head = nn.Dense(
661
+ self.config.vocab_size,
662
+ use_bias=False,
663
+ dtype=self.dtype,
664
+ kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
665
+ )
666
+
667
+ def __call__(
668
+ self,
669
+ input_ids,
670
+ attention_mask=None,
671
+ position_ids=None,
672
+ deterministic: bool = True,
673
+ init_cache: bool = False,
674
+ output_attentions: bool = False,
675
+ output_hidden_states: bool = False,
676
+ return_dict: bool = True,
677
+ ):
678
+ outputs = self.model(
679
+ input_ids,
680
+ position_ids=position_ids,
681
+ attention_mask=attention_mask,
682
+ deterministic=deterministic,
683
+ init_cache=init_cache,
684
+ output_attentions=output_attentions,
685
+ output_hidden_states=output_hidden_states,
686
+ return_dict=return_dict,
687
+ )
688
+
689
+ hidden_states = outputs[0]
690
+ lm_logits = self.lm_head(hidden_states)
691
+
692
+ if not return_dict:
693
+ return (lm_logits,) + outputs[1:]
694
+
695
+ return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
696
+
697
+
698
+ @add_start_docstrings(
699
+ """
700
+ The Cognitivess Model transformer with a language modeling head (linear layer) on top.
701
+ """,
702
+ Cognitivess_START_DOCSTRING,
703
+ )
704
+
705
+ # Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJForCausalLM with GPTJ->Cognitivess
706
+ class FlaxCognitivessForCausalLM(FlaxCognitivessPreTrainedModel):
707
+ module_class = FlaxCognitivessForCausalLMModule
708
+
709
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
710
+ # initializing the cache
711
+ batch_size, seq_length = input_ids.shape
712
+
713
+ past_key_values = self.init_cache(batch_size, max_length)
714
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
715
+ # But since Cognitivess uses a causal mask, those positions are masked anyways.
716
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
717
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
718
+ if attention_mask is not None:
719
+ position_ids = attention_mask.cumsum(axis=-1) - 1
720
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
721
+ else:
722
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
723
+
724
+ return {
725
+ "past_key_values": past_key_values,
726
+ "attention_mask": extended_attention_mask,
727
+ "position_ids": position_ids,
728
+ }
729
+
730
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
731
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
732
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
733
+ return model_kwargs
734
+
735
+
736
+ append_call_sample_docstring(
737
+ FlaxCognitivessForCausalLM,
738
+ _CHECKPOINT_FOR_DOC,
739
+ FlaxCausalLMOutputWithCrossAttentions,
740
+ _CONFIG_FOR_DOC,
741
+ real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
742
+ )
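+
+
+ if __name__ == "__main__":
+     # Minimal smoke-test sketch (illustrative only, not part of the library API). It assumes this module
+     # is importable with its relative imports resolved and that the "CognitivessAI/cognitivess" checkpoint
+     # ships a tokenizer and Flax-loadable weights; adjust the repo id or pass `from_pt=True` otherwise.
+     from transformers import AutoTokenizer
+
+     tokenizer = AutoTokenizer.from_pretrained("CognitivessAI/cognitivess")
+     model = FlaxCognitivessForCausalLM.from_pretrained("CognitivessAI/cognitivess")
+
+     inputs = tokenizer("Hello, my name is", return_tensors="np")
+     # `generate` is inherited from FlaxPreTrainedModel (FlaxGenerationMixin); greedy decoding by default.
+     generated = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=20)
+     print(tokenizer.batch_decode(generated.sequences, skip_special_tokens=True))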