Rocketknight1 HF staff committed on
Commit
4d54fdb
1 Parent(s): 0c34bae

Upload HyenaDNAForCausalLM

Browse files
Files changed (1) hide show
  1. modeling_hyena.py +16 -9
modeling_hyena.py CHANGED
@@ -349,8 +349,15 @@ class HyenaDNAPreTrainedModel(PreTrainedModel):
349
  supports_gradient_checkpointing = True
350
  _no_split_modules = ["HyenaBlock"]
351
  _skip_keys_device_placement = "past_key_values"
352
-
353
- def _init_weights(self, initializer_range=0.02):
 
 
 
 
 
 
 
354
  # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
355
  # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
356
  # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
@@ -368,8 +375,8 @@ class HyenaDNAPreTrainedModel(PreTrainedModel):
368
 
369
 
370
  class HyenaDNAModel(HyenaDNAPreTrainedModel):
371
- def __init__(self, config) -> None:
372
- super().__init__(config)
373
 
374
  self.backbone = HyenaLMBackbone(config)
375
  self.config = config
@@ -395,8 +402,8 @@ class HyenaDNAModel(HyenaDNAPreTrainedModel):
395
 
396
  class HyenaDNAForCausalLM(HyenaDNAPreTrainedModel):
397
 
398
- def __init__(self, config):
399
- super().__init__(config)
400
  self.hyena = HyenaDNAModel(config)
401
  vocab_size = config.vocab_size
402
  if vocab_size % config.pad_vocab_size_multiple != 0:
@@ -476,9 +483,9 @@ class HyenaDNAForCausalLM(HyenaDNAPreTrainedModel):
476
 
477
 
478
  class HyenaDNAForSequenceClassification(HyenaDNAPreTrainedModel):
479
- def __init__(self, config):
480
- super().__init__(config)
481
- self.num_labels = config.num_labels
482
  self.hyena = HyenaDNAModel(config)
483
  self.score = nn.Linear(config.d_model, self.num_labels, bias=False)
484
 
 
349
  supports_gradient_checkpointing = True
350
  _no_split_modules = ["HyenaBlock"]
351
  _skip_keys_device_placement = "past_key_values"
352
+ _keys_to_ignore_on_load_missing = [r"freq"] # Shared tensors that safetensors merges
353
+
354
+ def _init_weights(self, module, initializer_range=0.02):
355
+ if isinstance(module, nn.Linear):
356
+ nn.init.normal_(module.weight, std=initializer_range)
357
+ if module.bias is not None:
358
+ nn.init.zeros_(module.bias)
359
+ elif isinstance(module, nn.Embedding):
360
+ nn.init.normal_(module.weight, std=initializer_range)
361
  # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
362
  # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
363
  # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
 
375
 
376
 
377
  class HyenaDNAModel(HyenaDNAPreTrainedModel):
378
+ def __init__(self, config, **kwargs) -> None:
379
+ super().__init__(config, **kwargs)
380
 
381
  self.backbone = HyenaLMBackbone(config)
382
  self.config = config
 
402
 
403
  class HyenaDNAForCausalLM(HyenaDNAPreTrainedModel):
404
 
405
+ def __init__(self, config, **kwargs):
406
+ super().__init__(config, **kwargs)
407
  self.hyena = HyenaDNAModel(config)
408
  vocab_size = config.vocab_size
409
  if vocab_size % config.pad_vocab_size_multiple != 0:
 
483
 
484
 
485
  class HyenaDNAForSequenceClassification(HyenaDNAPreTrainedModel):
486
+ def __init__(self, config, **kwargs):
487
+ super().__init__(config, **kwargs)
488
+ self.num_labels = kwargs.get("num_labels", config.num_labels)
489
  self.hyena = HyenaDNAModel(config)
490
  self.score = nn.Linear(config.d_model, self.num_labels, bias=False)
491