yairschiff
commited on
Ensure weights are tied for BiMamba (if applicable) when loaded from_pretrained
Browse files- modeling_caduceus.py +31 -1
modeling_caduceus.py
CHANGED
@@ -360,6 +360,24 @@ class Caduceus(CaduceusPreTrainedModel):
|
|
360 |
factory_kwargs = {"device": device, "dtype": dtype}
|
361 |
self.backbone = CaduceusMixerModel(config, **factory_kwargs, **kwargs)
|
362 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
363 |
def forward(
|
364 |
self,
|
365 |
input_ids: torch.LongTensor = None,
|
@@ -431,8 +449,12 @@ class CaduceusForMaskedLM(CaduceusPreTrainedModel):
|
|
431 |
raise NotImplementedError("Setting output embeddings for RCPS LM is not supported.")
|
432 |
self.lm_head = new_embeddings
|
433 |
|
|
|
|
|
|
|
434 |
def tie_weights(self):
|
435 |
"""Tie weights, accounting for RCPS."""
|
|
|
436 |
if self.config.rcps:
|
437 |
self.lm_head.set_weight(self.get_input_embeddings().weight)
|
438 |
else:
|
@@ -445,7 +467,7 @@ class CaduceusForMaskedLM(CaduceusPreTrainedModel):
|
|
445 |
def set_decoder(self, decoder):
|
446 |
"""Set decoder (backbone) for the model."""
|
447 |
self.caduceus = decoder
|
448 |
-
|
449 |
def forward(
|
450 |
self,
|
451 |
input_ids: torch.LongTensor = None,
|
@@ -536,6 +558,13 @@ class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
|
|
536 |
if self.pooling_strategy == "first": # Use embedding of first token in the sequence
|
537 |
return hidden_states.moveaxis(hidden_states, sequence_length_dim, 0)[0, ...]
|
538 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
539 |
def forward(
|
540 |
self,
|
541 |
input_ids: torch.LongTensor = None,
|
@@ -543,6 +572,7 @@ class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
|
|
543 |
labels: Optional[torch.LongTensor] = None,
|
544 |
output_hidden_states: Optional[bool] = None,
|
545 |
return_dict: Optional[bool] = None,
|
|
|
546 |
) -> Union[Tuple, SequenceClassifierOutput]:
|
547 |
r"""
|
548 |
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
|
|
360 |
factory_kwargs = {"device": device, "dtype": dtype}
|
361 |
self.backbone = CaduceusMixerModel(config, **factory_kwargs, **kwargs)
|
362 |
|
363 |
+
def maybe_weight_tie_mamba(self):
|
364 |
+
if getattr(self.config, 'bidirectional', False) and getattr(self.config, 'bidirectional_weight_tie', False):
|
365 |
+
if getattr(self.config, 'rcps', False):
|
366 |
+
for layer in self.backbone.layers:
|
367 |
+
layer.mixer.submodule.mamba_rev.in_proj.weight = layer.mixer.submodule.mamba_fwd.in_proj.weight
|
368 |
+
layer.mixer.submodule.mamba_rev.in_proj.bias = layer.mixer.submodule.mamba_fwd.in_proj.bias
|
369 |
+
layer.mixer.submodule.mamba_rev.out_proj.weight = layer.mixer.submodule.mamba_fwd.out_proj.weight
|
370 |
+
layer.mixer.submodule.mamba_rev.out_proj.bias = layer.mixer.submodule.mamba_fwd.out_proj.bias
|
371 |
+
else:
|
372 |
+
for layer in self.backbone.layers:
|
373 |
+
layer.mixer.mamba_rev.in_proj.weight = layer.mixer.mamba_fwd.in_proj.weight
|
374 |
+
layer.mixer.mamba_rev.in_proj.bias = layer.mixer.mamba_fwd.in_proj.bias
|
375 |
+
layer.mixer.mamba_rev.out_proj.weight = layer.mixer.mamba_fwd.out_proj.weight
|
376 |
+
layer.mixer.mamba_rev.out_proj.bias = layer.mixer.mamba_fwd.out_proj.bias
|
377 |
+
|
378 |
+
def tie_weights(self):
|
379 |
+
self.maybe_weight_tie_mamba()
|
380 |
+
|
381 |
def forward(
|
382 |
self,
|
383 |
input_ids: torch.LongTensor = None,
|
|
|
449 |
raise NotImplementedError("Setting output embeddings for RCPS LM is not supported.")
|
450 |
self.lm_head = new_embeddings
|
451 |
|
452 |
+
def maybe_weight_tie_mamba(self):
|
453 |
+
self.caduceus.maybe_weight_tie_mamba()
|
454 |
+
|
455 |
def tie_weights(self):
|
456 |
"""Tie weights, accounting for RCPS."""
|
457 |
+
self.maybe_weight_tie_mamba()
|
458 |
if self.config.rcps:
|
459 |
self.lm_head.set_weight(self.get_input_embeddings().weight)
|
460 |
else:
|
|
|
467 |
def set_decoder(self, decoder):
|
468 |
"""Set decoder (backbone) for the model."""
|
469 |
self.caduceus = decoder
|
470 |
+
|
471 |
def forward(
|
472 |
self,
|
473 |
input_ids: torch.LongTensor = None,
|
|
|
558 |
if self.pooling_strategy == "first": # Use embedding of first token in the sequence
|
559 |
return hidden_states.moveaxis(hidden_states, sequence_length_dim, 0)[0, ...]
|
560 |
|
561 |
+
def maybe_weight_tie_mamba(self):
|
562 |
+
self.caduceus.maybe_weight_tie_mamba()
|
563 |
+
|
564 |
+
def tie_weights(self):
|
565 |
+
self.maybe_weight_tie_mamba()
|
566 |
+
super().tie_weights()
|
567 |
+
|
568 |
def forward(
|
569 |
self,
|
570 |
input_ids: torch.LongTensor = None,
|
|
|
572 |
labels: Optional[torch.LongTensor] = None,
|
573 |
output_hidden_states: Optional[bool] = None,
|
574 |
return_dict: Optional[bool] = None,
|
575 |
+
**kwargs,
|
576 |
) -> Union[Tuple, SequenceClassifierOutput]:
|
577 |
r"""
|
578 |
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|