JonasGeiping committed
Commit 9216f55 · verified · 1 parent: ff774db

Update raven_modeling_minimal.py

Files changed (1)
  1. raven_modeling_minimal.py +26 -11
raven_modeling_minimal.py CHANGED
@@ -11,7 +11,7 @@ from .raven_config_minimal import RavenConfig
 from transformers.cache_utils import Cache, DynamicCache
 
 ###################### Huggingface Glue code I ##################################################################
-from transformers import PreTrainedModel
+from transformers import PreTrainedModel, GenerationMixin
 from transformers.utils import ModelOutput
 from transformers.generation.utils import GenerateDecoderOnlyOutput
 
@@ -32,7 +32,8 @@ class RavenPreTrainedModel(PreTrainedModel):
     _supports_static_cache = False
 
     def _init_weights(self, module):
-        print("Random Initialization not implemented.")
+        if not torch.rand((1,)).is_meta:
+            print("Random Initialization not implemented.")
 
 
 @dataclass
@@ -241,7 +242,6 @@ class CausalSelfAttention(torch.nn.Module):
         if past_key_values is not None:
             k, v = past_key_values.update(k, v, step_idx)
 
-        return_attn = False
         if return_attn:
             y, attention_map = self.compute_eager_sdpa(q, k, v, attn_mask=mask)
         else:
@@ -310,7 +310,7 @@ class SandwichBlock(torch.nn.Module):
         return x, attn_map
 
 
-class RavenForCausalLM(RavenPreTrainedModel):
+class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
     def __init__(
         self,
         config: RavenConfig,
@@ -368,7 +368,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
             "return_latents": True,
            "return_attention": False,
            "return_head": False,
-            "return_stats": True,
+            "return_stats": False,
        },
        use_cache: bool = False,
        cache_position: Optional[torch.Tensor] = None,
@@ -396,7 +396,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
        # Non-recurrent prelude
        for block_idx, block in enumerate(self.transformer.prelude):
            input_embeds, attn_map = block(
-                input_embeds, freqs_cis, block_idx, attention_mask, past_key_values, return_attn
+                input_embeds, freqs_cis, block_idx, attention_mask, past_key_values, return_attn=return_attn
            )
            attn_maps[block_idx] = attn_map
 
@@ -410,12 +410,13 @@ class RavenForCausalLM(RavenPreTrainedModel):
            past_key_values,
            num_steps,
            attn_maps,
+            return_attn=return_attn,
        )
        latent_states = x.clone().detach()
 
        # Coda layers
        for block_idx, block in enumerate(self.transformer.coda, start=1):
-            x, attn_map = block(x, freqs_cis, -block_idx, attention_mask, past_key_values, return_attn)
+            x, attn_map = block(x, freqs_cis, -block_idx, attention_mask, past_key_values, return_attn=return_attn)
            attn_maps[-block_idx] = attn_map
        x = self.transformer.ln_f(x)
 
@@ -452,6 +453,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
        past_key_values: Optional[Cache] = None,
        num_steps: Optional[torch.Tensor] = None,
        attn_maps: dict = {},
+        return_attn: bool = False,
    ):
        x = xk = self.initialize_state(input_embeds) if input_states is None else input_states.clone()
        if num_steps is None:
@@ -469,13 +471,13 @@ class RavenForCausalLM(RavenPreTrainedModel):
        for step in range(num_steps_no_grad):
            xk = x
            x, block_idx, attn_maps = self.core_block_forward(
-                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps
+                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps, return_attn
            )
 
        for step in range(num_steps_with_grad):
            xk = x
            x, block_idx, attn_maps = self.core_block_forward(
-                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps
+                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps, return_attn
            )
        return self.transformer.ln_f(x), num_steps_no_grad, num_steps_with_grad, xk.detach(), block_idx, attn_maps
 
@@ -488,10 +490,11 @@ class RavenForCausalLM(RavenPreTrainedModel):
        past_key_values,
        block_idx: Union[torch.Tensor, int],
        attn_maps: dict = {},
+        return_attn: bool = False,
    ):
        x = self.transformer.adapter(torch.cat([x, input_embeds], dim=-1))
        for idx, block in enumerate(self.transformer.core_block, start=1):
-            x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=len(attn_maps) > 0)
+            x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=return_attn)
            attn_maps[block_idx + idx] = attn_map
        return x, block_idx + idx, attn_maps
 
@@ -624,7 +627,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
        model_inputs["cache_position"] = cache_position
        current_input_length = input_ids.shape[1]
        if past_key_values is not None:
-            if type(past_key_values) == DynamicCache:
+            if type(past_key_values) != HuginnDynamicCache:
                # Need to use custom cache, detect and replace HF dynamic cache if generate injects it
                assert past_key_values.get_seq_length() == 0
                past_key_values = HuginnDynamicCache()
@@ -644,6 +647,18 @@ class RavenForCausalLM(RavenPreTrainedModel):
            model_inputs[key] = value
        return model_inputs
 
+    @torch.no_grad()
+    def generate(self, *args, **kwargs):
+        """Dispatcher - use HF generate in all normal cases."""
+        if any(
+            k in kwargs
+            for k in ("continuous_compute", "latent_dampening", "criterion", "exit_threshold", "cache_kwargs")
+        ):
+            print("Dispatching to custom generate function call")
+            return self.generate_with_adaptive_compute(*args, **kwargs)
+        else:
+            return super().generate(*args, **kwargs)
+
    @torch.no_grad()
    def generate_minimal(
        self,
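
For reference, a minimal standalone sketch of the keyword-based dispatch the new generate method introduces: any of the custom sampling kwargs routes a call to generate_with_adaptive_compute, while plain calls fall through to the inherited GenerationMixin.generate. The class and helper names below are placeholders for illustration, not the model's own API.

# Illustrative sketch only; ToyDispatcher, adaptive_generate, and hf_generate are
# made-up stand-ins that mirror the dispatch logic added in this commit.
_CUSTOM_KWARGS = ("continuous_compute", "latent_dampening", "criterion", "exit_threshold", "cache_kwargs")


class ToyDispatcher:
    def generate(self, *args, **kwargs):
        # Route to the adaptive-compute sampler only when a custom kwarg is present.
        if any(k in kwargs for k in _CUSTOM_KWARGS):
            return self.adaptive_generate(*args, **kwargs)
        return self.hf_generate(*args, **kwargs)  # stands in for super().generate()

    def adaptive_generate(self, *args, **kwargs):
        return "adaptive-compute path"

    def hf_generate(self, *args, **kwargs):
        return "standard HF path"


d = ToyDispatcher()
assert d.generate(max_new_tokens=8) == "standard HF path"         # no custom kwargs
assert d.generate(exit_threshold=0.5) == "adaptive-compute path"  # custom kwarg present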