cognitivess committed
Commit: 97f6008
Parent(s): d7df90a
Update cognitivess_model/modeling_flax_Cognitivess.py

cognitivess_model/modeling_flax_Cognitivess.py CHANGED
@@ -1,6 +1,5 @@
 # coding=utf-8
-# Copyright
-#
+# Copyright 2023 Cognitivess and the HuggingFace Inc. team. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -14,6 +13,7 @@
 # limitations under the License.
 """Flax Cognitivess model."""
 
+from functools import partial
 from typing import Optional, Tuple
 
 import flax.linen as nn
@@ -26,22 +26,17 @@ from flax.linen.attention import dot_product_attention_weights
 from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import lax
 
-from ...modeling_flax_outputs import
-
-
-    FlaxCausalLMOutput,
-    FlaxCausalLMOutputWithCrossAttentions,
-)
-from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, logging
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
+from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_Cognitivess import CognitivessConfig
 
 
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "CognitivessConfig"
-
-
+_CHECKPOINT_FOR_DOC = "afmck/testing-Cognitivess-tiny"
+_REAL_CHECKPOINT_FOR_DOC = "openlm-research/open_Cognitivess_3b_v2"
 
 Cognitivess_START_DOCSTRING = r"""
 
@@ -127,7 +122,27 @@ Cognitivess_INPUTS_DOCSTRING = r"""
 """
 
 
-
+def create_sinusoidal_positions(num_pos, dim):
+    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
+    freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32")
+
+    emb = np.concatenate((freqs, freqs), axis=-1)
+    out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1)
+    return jnp.array(out[:, :, :num_pos])
+
+
+def rotate_half(tensor):
+    """Rotates half the hidden dims of the input."""
+    rotate_half_tensor = jnp.concatenate(
+        (-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1
+    )
+    return rotate_half_tensor
+
+
+def apply_rotary_pos_emb(tensor, sin_pos, cos_pos):
+    return (tensor * cos_pos) + (rotate_half(tensor) * sin_pos)
+
+
 class FlaxCognitivessRMSNorm(nn.Module):
     config: CognitivessConfig
     dtype: jnp.dtype = jnp.float32
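
The block added above moves the Llama-style rotary position embedding helpers to the top of the file: `create_sinusoidal_positions` precomputes a sine/cosine table, and `apply_rotary_pos_emb` mixes each vector with a half-rotated copy of itself. As a rough, self-contained sketch of how `rotate_half` and `apply_rotary_pos_emb` act on a query tensor (the shapes and angle construction below are illustrative only, not taken from this file):

    import jax.numpy as jnp

    def rotate_half(tensor):
        # Same formula as the helper in the diff: swap and negate the two halves.
        return jnp.concatenate(
            (-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1
        )

    def apply_rotary_pos_emb(tensor, sin_pos, cos_pos):
        return (tensor * cos_pos) + (rotate_half(tensor) * sin_pos)

    # Toy shapes: batch=1, seq_len=4, heads=2, head_dim=6 (illustrative values).
    batch, seq_len, num_heads, head_dim = 1, 4, 2, 6
    query = jnp.ones((batch, seq_len, num_heads, head_dim))

    # One angle per (position, feature pair), broadcast over batch and heads.
    positions = jnp.arange(seq_len)[:, None]                                  # (seq_len, 1)
    inv_freq = 1.0 / (10000 ** (jnp.arange(0, head_dim, 2) / head_dim))       # (head_dim // 2,)
    angles = jnp.concatenate((positions * inv_freq, positions * inv_freq), axis=-1)  # (seq_len, head_dim)
    sin_pos = jnp.sin(angles)[None, :, None, :]                               # (1, seq_len, 1, head_dim)
    cos_pos = jnp.cos(angles)[None, :, None, :]

    rotated_query = apply_rotary_pos_emb(query, sin_pos, cos_pos)
    print(rotated_query.shape)  # (1, 4, 2, 6)

In the module itself, `FlaxCognitivessRotaryEmbedding` presumably looks up the precomputed table by `position_ids` and applies the same transform to both key and query, as the later `key, query = self.rotary_emb(key, query, position_ids)` call suggests.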
@@ -146,7 +161,6 @@ class FlaxCognitivessRMSNorm(nn.Module):
         return self.weight * jnp.asarray(hidden_states, dtype=self.dtype)
 
 
-# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaRotaryEmbedding with Llama->Cognitivess
 class FlaxCognitivessRotaryEmbedding(nn.Module):
     config: CognitivessConfig
     dtype: jnp.dtype = jnp.float32
@@ -168,86 +182,46 @@ class FlaxCognitivessRotaryEmbedding(nn.Module):
         return key, query
 
 
-# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaMLP with Llama->Cognitivess
-class FlaxCognitivessMLP(nn.Module):
-    config: CognitivessConfig
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        embed_dim = self.config.hidden_size
-        inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim
-
-        kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
-        self.act = ACT2FN[self.config.hidden_act]
-
-        self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
-        self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
-        self.up_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
-
-    def __call__(self, hidden_states):
-        up_proj_states = self.up_proj(hidden_states)
-        gate_states = self.act(self.gate_proj(hidden_states))
-
-        hidden_states = self.down_proj(up_proj_states * gate_states)
-        return hidden_states
-
-
-# Copied from transformers.models.llama.modeling_flax_llama.apply_rotary_pos_emb
-def apply_rotary_pos_emb(tensor, sin_pos, cos_pos):
-    return (tensor * cos_pos) + (rotate_half(tensor) * sin_pos)
-
-
-# Copied from transformers.models.llama.modeling_flax_llama.create_sinusoidal_positions
-def create_sinusoidal_positions(num_pos, dim):
-    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
-    freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32")
-
-    emb = np.concatenate((freqs, freqs), axis=-1)
-    out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1)
-    return jnp.array(out[:, :, :num_pos])
-
-
-# Copied from transformers.models.llama.modeling_flax_llama.rotate_half
-def rotate_half(tensor):
-    """Rotates half the hidden dims of the input."""
-    rotate_half_tensor = jnp.concatenate(
-        (-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1
-    )
-    return rotate_half_tensor
-
-
 class FlaxCognitivessAttention(nn.Module):
     config: CognitivessConfig
     dtype: jnp.dtype = jnp.float32
+    causal: bool = True
+    is_cross_attention: bool = False
 
     def setup(self):
         config = self.config
-        self.
+        self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.head_dim = self.
+        self.head_dim = self.embed_dim // self.num_heads
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-        self.max_position_embeddings = config.max_position_embeddings
         self.attention_softmax_in_fp32 = self.dtype is not jnp.float32
-
-
+
+        dense = partial(
+            nn.Dense,
+            use_bias=config.attention_bias,
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+        )
+
+        self.q_proj = dense(self.num_heads * self.head_dim)
+        self.k_proj = dense(self.num_key_value_heads * self.head_dim)
+        self.v_proj = dense(self.num_key_value_heads * self.head_dim)
+        self.o_proj = dense(self.embed_dim)
+        if (self.head_dim * self.num_heads) != self.embed_dim:
             raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.embed_dim}"
                 f" and `num_heads`: {self.num_heads})."
             )
-
-        self.
-        self.v_proj = nn.Dense(self.num_key_value_heads * self.head_dim, use_bias=False, dtype=self.dtype)
-        self.o_proj = nn.Dense(self.hidden_size, use_bias=False, dtype=self.dtype)
-        casual_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
-        self.causal_mask = jnp.triu(casual_mask, k=-config.sliding_window)
+
+        self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
         self.rotary_emb = FlaxCognitivessRotaryEmbedding(config, dtype=self.dtype)
 
     def _split_heads(self, hidden_states, num_heads):
         return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))
 
     def _merge_heads(self, hidden_states):
-        return hidden_states.reshape(hidden_states.shape[:2] + (self.
+        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
 
     @nn.compact
     # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoSelfAttention._concatenate_to_cache
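
In the rewritten `setup`, a single `functools.partial` over `nn.Dense` pre-binds the options shared by all four projections (`use_bias=config.attention_bias`, `dtype`, `kernel_init`), and the key/value projections are sized by `num_key_value_heads` instead of `num_heads`, i.e. grouped-query attention. A minimal sketch of the same pattern with made-up sizes (32 query heads over 8 key/value heads; none of these numbers come from this repo's config):

    from functools import partial

    import flax.linen as nn
    import jax
    import jax.numpy as jnp

    hidden_size, num_heads, num_key_value_heads = 4096, 32, 8
    head_dim = hidden_size // num_heads  # 128

    # partial() fixes the shared keyword arguments once; each projection then
    # only states its output width (use_bias is hard-coded False just for this sketch).
    dense = partial(nn.Dense, use_bias=False, dtype=jnp.float32,
                    kernel_init=jax.nn.initializers.normal(0.02))

    q_proj = dense(num_heads * head_dim)            # hidden -> 32 * 128
    k_proj = dense(num_key_value_heads * head_dim)  # hidden -> 8 * 128
    v_proj = dense(num_key_value_heads * head_dim)
    o_proj = dense(hidden_size)

    x = jnp.ones((2, 16, hidden_size))              # (batch, seq, hidden)
    params = q_proj.init(jax.random.PRNGKey(0), x)
    print(q_proj.apply(params, x).shape)            # (2, 16, 4096)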
@@ -284,23 +258,25 @@ class FlaxCognitivessAttention(nn.Module):
 
     def __call__(
         self,
-        hidden_states
-        attention_mask
-        position_ids
+        hidden_states,
+        attention_mask,
+        position_ids,
         deterministic: bool = True,
-        output_attentions: bool = False,
         init_cache: bool = False,
-
-
-
-
+        output_attentions: bool = False,
+    ):
+        query = self.q_proj(hidden_states)
+        key = self.k_proj(hidden_states)
+        value = self.v_proj(hidden_states)
+
+        query = self._split_heads(query, self.num_heads)
+        key = self._split_heads(key, self.num_key_value_heads)
+        value = self._split_heads(value, self.num_key_value_heads)
 
-
-
-
+        key, query = self.rotary_emb(key, query, position_ids)
+
+        query_length, key_length = query.shape[1], key.shape[1]
 
-        key_states, query_states = self.rotary_emb(key_states, query_states, position_ids)
-        query_length, key_length = query_states.shape[1], key_states.shape[1]
         if self.has_variable("cache", "cached_key"):
             mask_shift = self.variables["cache"]["cache_index"]
             max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
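
The new `__call__` body projects the hidden states once per stream and reshapes each result per head: queries get `num_heads` heads, keys and values only `num_key_value_heads`, with the rotary embedding applied afterwards. A toy illustration of the `_split_heads` reshape (sizes invented for the example):

    import jax.numpy as jnp

    batch, seq_len, hidden_size = 2, 5, 12
    num_heads, num_key_value_heads = 3, 1
    head_dim = hidden_size // num_heads  # 4

    def split_heads(hidden_states, heads):
        # Same reshape as FlaxCognitivessAttention._split_heads.
        return hidden_states.reshape(hidden_states.shape[:2] + (heads, head_dim))

    query = jnp.ones((batch, seq_len, num_heads * head_dim))          # (2, 5, 12)
    key = jnp.ones((batch, seq_len, num_key_value_heads * head_dim))  # (2, 5, 4)

    print(split_heads(query, num_heads).shape)          # (2, 5, 3, 4)
    print(split_heads(key, num_key_value_heads).shape)  # (2, 5, 1, 4)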
@@ -312,16 +288,23 @@
 
         batch_size = hidden_states.shape[0]
         causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
+
         attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
         attention_mask = combine_masks(attention_mask, causal_mask)
 
+        dropout_rng = None
+        if not deterministic and self.config.attention_dropout > 0.0:
+            dropout_rng = self.make_rng("dropout")
+
+        # During fast autoregressive decoding, we feed one position at a time,
+        # and cache the keys and values step by step.
         if self.has_variable("cache", "cached_key") or init_cache:
-
-
-
-
-        value_states = jnp.repeat(value_states, self.num_key_value_groups, axis=2)
+            key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
+
+        key = jnp.repeat(key, self.num_key_value_groups, axis=2)
+        value = jnp.repeat(value, self.num_key_value_groups, axis=2)
 
+        # transform boolean mask into float mask
         attention_bias = lax.select(
             attention_mask > 0,
             jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
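
Because keys and values carry fewer heads than the queries, the hunk above tiles them with `jnp.repeat(..., self.num_key_value_groups, axis=2)` after the cache update, so the standard dot-product attention again sees one key/value head per query head. A self-contained sketch of that expansion with toy shapes:

    import jax.numpy as jnp

    batch, seq_len, head_dim = 2, 5, 4
    num_heads, num_key_value_heads = 8, 2
    num_key_value_groups = num_heads // num_key_value_heads  # 4 query heads share each KV head

    key = jnp.ones((batch, seq_len, num_key_value_heads, head_dim))
    value = jnp.ones((batch, seq_len, num_key_value_heads, head_dim))

    # Repeat along the head axis so the downstream attention sees num_heads heads.
    key = jnp.repeat(key, num_key_value_groups, axis=2)
    value = jnp.repeat(value, num_key_value_groups, axis=2)
    print(key.shape, value.shape)  # (2, 5, 8, 4) (2, 5, 8, 4)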
@@ -331,18 +314,19 @@
         # usual dot product attention
         attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype
         attn_weights = dot_product_attention_weights(
-
-
+            query,
+            key,
             bias=attention_bias,
-
+            dropout_rng=dropout_rng,
             dropout_rate=self.config.attention_dropout,
+            deterministic=deterministic,
             dtype=attention_dtype,
         )
 
         if self.attention_softmax_in_fp32:
             attn_weights = attn_weights.astype(self.dtype)
 
-        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights,
+        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
         attn_output = self._merge_heads(attn_output)
         attn_output = self.o_proj(attn_output)
 
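
The softmax weights now come from Flax's `dot_product_attention_weights`, with the dropout RNG and `deterministic` flag threaded through, and the weighted sum over values is the `"...hqk,...khd->...qhd"` einsum. A small shape-level sketch (toy tensors, nowhere near the model's real dimensions):

    import jax.numpy as jnp
    from flax.linen.attention import dot_product_attention_weights

    # (batch, length, heads, head_dim) for query/key/value.
    query = jnp.ones((1, 3, 2, 4))
    key = jnp.ones((1, 3, 2, 4))
    value = jnp.ones((1, 3, 2, 4))

    # Weights have shape (batch, heads, q_len, kv_len); the einsum contracts the
    # kv_len axis against the values, giving (batch, q_len, heads, head_dim).
    attn_weights = dot_product_attention_weights(query, key, deterministic=True)
    attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
    print(attn_weights.shape, attn_output.shape)  # (1, 2, 3, 3) (1, 3, 2, 4)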
@@ -350,7 +334,29 @@
         return outputs
 
 
-
+class FlaxCognitivessMLP(nn.Module):
+    config: CognitivessConfig
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        embed_dim = self.config.hidden_size
+        inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim
+
+        kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
+        self.act = ACT2FN[self.config.hidden_act]
+
+        self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
+        self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
+        self.up_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
+
+    def __call__(self, hidden_states):
+        up_proj_states = self.up_proj(hidden_states)
+        gate_states = self.act(self.gate_proj(hidden_states))
+
+        hidden_states = self.down_proj(up_proj_states * gate_states)
+        return hidden_states
+
+
 class FlaxCognitivessDecoderLayer(nn.Module):
     config: CognitivessConfig
     dtype: jnp.dtype = jnp.float32
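
`FlaxCognitivessMLP`, now placed after the attention block, is the usual gated feed-forward: the activated gate projection multiplies the up projection elementwise before the down projection maps back to the hidden size. A compact stand-in with hard-coded toy sizes (and `silu` assumed as the activation, which in the real module is whatever `ACT2FN[config.hidden_act]` resolves to):

    import flax.linen as nn
    import jax
    import jax.numpy as jnp

    class GatedMLP(nn.Module):
        """Minimal stand-in for FlaxCognitivessMLP with invented sizes."""

        hidden_size: int = 16
        intermediate_size: int = 64

        @nn.compact
        def __call__(self, x):
            gate = nn.silu(nn.Dense(self.intermediate_size, use_bias=False)(x))
            up = nn.Dense(self.intermediate_size, use_bias=False)(x)
            return nn.Dense(self.hidden_size, use_bias=False)(gate * up)

    x = jnp.ones((2, 8, 16))
    mlp = GatedMLP()
    params = mlp.init(jax.random.PRNGKey(0), x)
    print(mlp.apply(params, x).shape)  # (2, 8, 16)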
@@ -526,7 +532,6 @@ class FlaxCognitivessPreTrainedModel(FlaxPreTrainedModel):
         return outputs
 
 
-# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaLayerCollection with Llama->Cognitivess
 class FlaxCognitivessLayerCollection(nn.Module):
     config: CognitivessConfig
     dtype: jnp.dtype = jnp.float32
@@ -573,7 +578,6 @@ class FlaxCognitivessLayerCollection(nn.Module):
         return outputs
 
 
-# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaModule with Llama->Cognitivess
 class FlaxCognitivessModule(nn.Module):
     config: CognitivessConfig
     dtype: jnp.dtype = jnp.float32
@@ -644,13 +648,12 @@ class FlaxCognitivessModel(FlaxCognitivessPreTrainedModel):
 append_call_sample_docstring(
     FlaxCognitivessModel,
     _CHECKPOINT_FOR_DOC,
-
+    FlaxBaseModelOutput,
     _CONFIG_FOR_DOC,
     real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
 )
 
 
-# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaForCausalLMModule with Llama->Cognitivess
 class FlaxCognitivessForCausalLMModule(nn.Module):
     config: CognitivessConfig
     dtype: jnp.dtype = jnp.float32
@@ -701,7 +704,6 @@ class FlaxCognitivessForCausalLMModule(nn.Module):
     """,
     Cognitivess_START_DOCSTRING,
 )
-
 # Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJForCausalLM with GPTJ->Cognitivess
 class FlaxCognitivessForCausalLM(FlaxCognitivessPreTrainedModel):
     module_class = FlaxCognitivessForCausalLMModule
@@ -736,7 +738,7 @@ class FlaxCognitivessForCausalLM(FlaxCognitivessPreTrainedModel):
 append_call_sample_docstring(
     FlaxCognitivessForCausalLM,
     _CHECKPOINT_FOR_DOC,
-
+    FlaxCausalLMOutput,
     _CONFIG_FOR_DOC,
     real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
 )