cognitivess committed on
Commit 97f6008
1 Parent(s): d7df90a

Update cognitivess_model/modeling_flax_Cognitivess.py

cognitivess_model/modeling_flax_Cognitivess.py CHANGED
@@ -1,6 +1,5 @@
 # coding=utf-8
-# Copyright 2024 Cognitivess AI and the HuggingFace Inc. team. All rights reserved.
-#
+# Copyright 2023 Cognitivess and the HuggingFace Inc. team. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -14,6 +13,7 @@
 # limitations under the License.
 """Flax Cognitivess model."""

+from functools import partial
 from typing import Optional, Tuple

 import flax.linen as nn
@@ -26,22 +26,17 @@ from flax.linen.attention import dot_product_attention_weights
 from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import lax

-from ...modeling_flax_outputs import (
-    FlaxBaseModelOutput,
-    FlaxBaseModelOutputWithPast,
-    FlaxCausalLMOutput,
-    FlaxCausalLMOutputWithCrossAttentions,
-)
-from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, logging
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
+from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_Cognitivess import CognitivessConfig


 logger = logging.get_logger(__name__)

 _CONFIG_FOR_DOC = "CognitivessConfig"
-_REAL_CHECKPOINT_FOR_DOC = "CognitivessAI/cognitivess"
-_CHECKPOINT_FOR_DOC = "ksmcg/Cognitivess-tiny"
+_CHECKPOINT_FOR_DOC = "afmck/testing-Cognitivess-tiny"
+_REAL_CHECKPOINT_FOR_DOC = "openlm-research/open_Cognitivess_3b_v2"

 Cognitivess_START_DOCSTRING = r"""

@@ -127,7 +122,27 @@ Cognitivess_INPUTS_DOCSTRING = r"""
 """


-# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaRMSNorm with Llama->Cognitivess
+def create_sinusoidal_positions(num_pos, dim):
+    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
+    freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32")
+
+    emb = np.concatenate((freqs, freqs), axis=-1)
+    out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1)
+    return jnp.array(out[:, :, :num_pos])
+
+
+def rotate_half(tensor):
+    """Rotates half the hidden dims of the input."""
+    rotate_half_tensor = jnp.concatenate(
+        (-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1
+    )
+    return rotate_half_tensor
+
+
+def apply_rotary_pos_emb(tensor, sin_pos, cos_pos):
+    return (tensor * cos_pos) + (rotate_half(tensor) * sin_pos)
+
+
 class FlaxCognitivessRMSNorm(nn.Module):
     config: CognitivessConfig
     dtype: jnp.dtype = jnp.float32
@@ -146,7 +161,6 @@ class FlaxCognitivessRMSNorm(nn.Module):
         return self.weight * jnp.asarray(hidden_states, dtype=self.dtype)


-# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaRotaryEmbedding with Llama->Cognitivess
 class FlaxCognitivessRotaryEmbedding(nn.Module):
     config: CognitivessConfig
     dtype: jnp.dtype = jnp.float32
@@ -168,86 +182,46 @@ class FlaxCognitivessRotaryEmbedding(nn.Module):
         return key, query


-# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaMLP with Llama->Cognitivess
-class FlaxCognitivessMLP(nn.Module):
-    config: CognitivessConfig
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        embed_dim = self.config.hidden_size
-        inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim
-
-        kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
-        self.act = ACT2FN[self.config.hidden_act]
-
-        self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
-        self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
-        self.up_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
-
-    def __call__(self, hidden_states):
-        up_proj_states = self.up_proj(hidden_states)
-        gate_states = self.act(self.gate_proj(hidden_states))
-
-        hidden_states = self.down_proj(up_proj_states * gate_states)
-        return hidden_states
-
-
-# Copied from transformers.models.llama.modeling_flax_llama.apply_rotary_pos_emb
-def apply_rotary_pos_emb(tensor, sin_pos, cos_pos):
-    return (tensor * cos_pos) + (rotate_half(tensor) * sin_pos)
-
-
-# Copied from transformers.models.llama.modeling_flax_llama.create_sinusoidal_positions
-def create_sinusoidal_positions(num_pos, dim):
-    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
-    freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32")
-
-    emb = np.concatenate((freqs, freqs), axis=-1)
-    out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1)
-    return jnp.array(out[:, :, :num_pos])
-
-
-# Copied from transformers.models.llama.modeling_flax_llama.rotate_half
-def rotate_half(tensor):
-    """Rotates half the hidden dims of the input."""
-    rotate_half_tensor = jnp.concatenate(
-        (-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1
-    )
-    return rotate_half_tensor
-
-
 class FlaxCognitivessAttention(nn.Module):
     config: CognitivessConfig
     dtype: jnp.dtype = jnp.float32
+    causal: bool = True
+    is_cross_attention: bool = False

     def setup(self):
         config = self.config
-        self.hidden_size = config.hidden_size
+        self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
+        self.head_dim = self.embed_dim // self.num_heads
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-        self.max_position_embeddings = config.max_position_embeddings
         self.attention_softmax_in_fp32 = self.dtype is not jnp.float32
-        self.rope_theta = config.rope_theta
-        if (self.head_dim * self.num_heads) != self.hidden_size:
+
+        dense = partial(
+            nn.Dense,
+            use_bias=config.attention_bias,
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+        )
+
+        self.q_proj = dense(self.num_heads * self.head_dim)
+        self.k_proj = dense(self.num_key_value_heads * self.head_dim)
+        self.v_proj = dense(self.num_key_value_heads * self.head_dim)
+        self.o_proj = dense(self.embed_dim)
+        if (self.head_dim * self.num_heads) != self.embed_dim:
             raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.embed_dim}"
                 f" and `num_heads`: {self.num_heads})."
             )
-        self.q_proj = nn.Dense(self.num_heads * self.head_dim, use_bias=False, dtype=self.dtype)
-        self.k_proj = nn.Dense(self.num_key_value_heads * self.head_dim, use_bias=False, dtype=self.dtype)
-        self.v_proj = nn.Dense(self.num_key_value_heads * self.head_dim, use_bias=False, dtype=self.dtype)
-        self.o_proj = nn.Dense(self.hidden_size, use_bias=False, dtype=self.dtype)
-        casual_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
-        self.causal_mask = jnp.triu(casual_mask, k=-config.sliding_window)
+
+        self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
         self.rotary_emb = FlaxCognitivessRotaryEmbedding(config, dtype=self.dtype)

     def _split_heads(self, hidden_states, num_heads):
         return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))

     def _merge_heads(self, hidden_states):
-        return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))
+        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

     @nn.compact
     # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoSelfAttention._concatenate_to_cache
@@ -284,23 +258,25 @@ class FlaxCognitivessAttention(nn.Module):

     def __call__(
         self,
-        hidden_states: jnp.ndarray,
-        attention_mask: Optional[jnp.ndarray] = None,
-        position_ids: Optional[jnp.ndarray] = None,
+        hidden_states,
+        attention_mask,
+        position_ids,
         deterministic: bool = True,
-        output_attentions: bool = False,
         init_cache: bool = False,
-    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        output_attentions: bool = False,
+    ):
+        query = self.q_proj(hidden_states)
+        key = self.k_proj(hidden_states)
+        value = self.v_proj(hidden_states)
+
+        query = self._split_heads(query, self.num_heads)
+        key = self._split_heads(key, self.num_key_value_heads)
+        value = self._split_heads(value, self.num_key_value_heads)

-        query_states = self._split_heads(query_states, self.num_heads)
-        key_states = self._split_heads(key_states, self.num_key_value_heads)
-        value_states = self._split_heads(value_states, self.num_key_value_heads)
+        key, query = self.rotary_emb(key, query, position_ids)
+
+        query_length, key_length = query.shape[1], key.shape[1]

-        key_states, query_states = self.rotary_emb(key_states, query_states, position_ids)
-        query_length, key_length = query_states.shape[1], key_states.shape[1]
         if self.has_variable("cache", "cached_key"):
             mask_shift = self.variables["cache"]["cache_index"]
             max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
@@ -312,16 +288,23 @@ class FlaxCognitivessAttention(nn.Module):

         batch_size = hidden_states.shape[0]
         causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
+
         attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
         attention_mask = combine_masks(attention_mask, causal_mask)

+        dropout_rng = None
+        if not deterministic and self.config.attention_dropout > 0.0:
+            dropout_rng = self.make_rng("dropout")
+
+        # During fast autoregressive decoding, we feed one position at a time,
+        # and cache the keys and values step by step.
         if self.has_variable("cache", "cached_key") or init_cache:
-            key_states, value_states, attention_mask = self._concatenate_to_cache(
-                key_states, value_states, query_states, attention_mask
-            )
-        key_states = jnp.repeat(key_states, self.num_key_value_groups, axis=2)
-        value_states = jnp.repeat(value_states, self.num_key_value_groups, axis=2)
+            key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
+
+        key = jnp.repeat(key, self.num_key_value_groups, axis=2)
+        value = jnp.repeat(value, self.num_key_value_groups, axis=2)

+        # transform boolean mask into float mask
         attention_bias = lax.select(
             attention_mask > 0,
             jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
@@ -331,18 +314,19 @@ class FlaxCognitivessAttention(nn.Module):
         # usual dot product attention
         attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype
         attn_weights = dot_product_attention_weights(
-            query_states,
-            key_states,
+            query,
+            key,
             bias=attention_bias,
-            deterministic=deterministic,
+            dropout_rng=dropout_rng,
             dropout_rate=self.config.attention_dropout,
+            deterministic=deterministic,
             dtype=attention_dtype,
         )

         if self.attention_softmax_in_fp32:
             attn_weights = attn_weights.astype(self.dtype)

-        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
         attn_output = self._merge_heads(attn_output)
         attn_output = self.o_proj(attn_output)

@@ -350,7 +334,29 @@ class FlaxCognitivessAttention(nn.Module):
         return outputs


-# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaDecoderLayer with Llama->Cognitivess
+class FlaxCognitivessMLP(nn.Module):
+    config: CognitivessConfig
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        embed_dim = self.config.hidden_size
+        inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim
+
+        kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
+        self.act = ACT2FN[self.config.hidden_act]
+
+        self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
+        self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
+        self.up_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
+
+    def __call__(self, hidden_states):
+        up_proj_states = self.up_proj(hidden_states)
+        gate_states = self.act(self.gate_proj(hidden_states))
+
+        hidden_states = self.down_proj(up_proj_states * gate_states)
+        return hidden_states
+
+
 class FlaxCognitivessDecoderLayer(nn.Module):
     config: CognitivessConfig
     dtype: jnp.dtype = jnp.float32
@@ -526,7 +532,6 @@ class FlaxCognitivessPreTrainedModel(FlaxPreTrainedModel):
         return outputs


-# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaLayerCollection with Llama->Cognitivess
 class FlaxCognitivessLayerCollection(nn.Module):
     config: CognitivessConfig
     dtype: jnp.dtype = jnp.float32
@@ -573,7 +578,6 @@ class FlaxCognitivessLayerCollection(nn.Module):
         return outputs


-# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaModule with Llama->Cognitivess
 class FlaxCognitivessModule(nn.Module):
     config: CognitivessConfig
     dtype: jnp.dtype = jnp.float32
@@ -644,13 +648,12 @@ class FlaxCognitivessModel(FlaxCognitivessPreTrainedModel):
 append_call_sample_docstring(
     FlaxCognitivessModel,
     _CHECKPOINT_FOR_DOC,
-    FlaxBaseModelOutputWithPast,
+    FlaxBaseModelOutput,
     _CONFIG_FOR_DOC,
     real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
 )


-# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaForCausalLMModule with Llama->Cognitivess
 class FlaxCognitivessForCausalLMModule(nn.Module):
     config: CognitivessConfig
     dtype: jnp.dtype = jnp.float32
@@ -701,7 +704,6 @@ class FlaxCognitivessForCausalLMModule(nn.Module):
     """,
     Cognitivess_START_DOCSTRING,
 )
-
 # Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJForCausalLM with GPTJ->Cognitivess
 class FlaxCognitivessForCausalLM(FlaxCognitivessPreTrainedModel):
     module_class = FlaxCognitivessForCausalLMModule
@@ -736,7 +738,7 @@ class FlaxCognitivessForCausalLM(FlaxCognitivessPreTrainedModel):
 append_call_sample_docstring(
     FlaxCognitivessForCausalLM,
     _CHECKPOINT_FOR_DOC,
-    FlaxCausalLMOutputWithCrossAttentions,
+    FlaxCausalLMOutput,
     _CONFIG_FOR_DOC,
     real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
 )
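
The rewritten FlaxCognitivessAttention above uses grouped-query attention: queries keep the full head count, keys and values use the smaller num_key_value_heads, and jnp.repeat(..., self.num_key_value_groups, axis=2) expands them back before the dot product. A minimal standalone sketch of that shape bookkeeping (batch size, sequence length, and head counts below are illustrative, not taken from the model config):

import jax.numpy as jnp

batch, seq_len, head_dim = 2, 5, 4                        # illustrative sizes
num_heads, num_key_value_heads = 8, 2                     # grouped-query setup
num_key_value_groups = num_heads // num_key_value_heads   # -> 4

# Queries carry all attention heads; keys/values carry only the key/value heads.
query = jnp.zeros((batch, seq_len, num_heads, head_dim))
key = jnp.zeros((batch, seq_len, num_key_value_heads, head_dim))
value = jnp.zeros((batch, seq_len, num_key_value_heads, head_dim))

# Repeating along the head axis gives every query head a matching key/value head,
# mirroring the jnp.repeat calls in FlaxCognitivessAttention.__call__.
key = jnp.repeat(key, num_key_value_groups, axis=2)
value = jnp.repeat(value, num_key_value_groups, axis=2)

print(query.shape, key.shape, value.shape)  # (2, 5, 8, 4) for all three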