yairschiff committed on
Commit
a951351
1 Parent(s): 5ddd9c9

Upload Caduceus

Browse files
Files changed (3) hide show
  1. model.safetensors +2 -2
  2. modeling_caduceus.py +3 -3
  3. modeling_rcps.py +1 -1
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5cf9b060f9fe236a72eeb4df44d905816bbf3ce013462524edcfbb7272b8c947
3
- size 2174536
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:16d3acb10a57ce482dd0799e59fd8616b83ce414143b04d68d95a9ab8cd8180e
3
+ size 2173880
modeling_caduceus.py CHANGED
@@ -158,7 +158,7 @@ class CaduceusMixerModel(nn.Module):
158
  self.rcps = config.rcps
159
  self.residual_in_fp32 = config.residual_in_fp32
160
 
161
- self.embeddings = torch.compile(CaduceusEmbeddings(config, **factory_kwargs))
162
 
163
  # Mamba changes the order of residual and layer norm:
164
  # Instead of LN -> Attn / MLP -> Add, we do:
@@ -377,12 +377,12 @@ class CaduceusForMaskedLM(CaduceusPreTrainedModel):
377
  factory_kwargs = {"device": device, "dtype": dtype}
378
  self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
379
  if config.rcps:
380
- self.lm_head = torch.compile(RCPSLMHead(
381
  complement_map=self.config.complement_map, # Use caduceus config as it might have been updated
382
  vocab_size=self.config.vocab_size, # Use caduceus config as it might have been updated
383
  true_dim=config.d_model,
384
  dtype=dtype
385
- ))
386
  else:
387
  self.lm_head = nn.Linear(
388
  config.d_model,
 
158
  self.rcps = config.rcps
159
  self.residual_in_fp32 = config.residual_in_fp32
160
 
161
+ self.embeddings = CaduceusEmbeddings(config, **factory_kwargs)
162
 
163
  # Mamba changes the order of residual and layer norm:
164
  # Instead of LN -> Attn / MLP -> Add, we do:
 
377
  factory_kwargs = {"device": device, "dtype": dtype}
378
  self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
379
  if config.rcps:
380
+ self.lm_head = RCPSLMHead(
381
  complement_map=self.config.complement_map, # Use caduceus config as it might have been updated
382
  vocab_size=self.config.vocab_size, # Use caduceus config as it might have been updated
383
  true_dim=config.d_model,
384
  dtype=dtype
385
+ )
386
  else:
387
  self.lm_head = nn.Linear(
388
  config.d_model,
modeling_rcps.py CHANGED
@@ -144,7 +144,7 @@ class RCPSMambaBlock(nn.Module):
144
  super().__init__()
145
  self.residual_in_fp32 = residual_in_fp32
146
  self.fused_add_norm = fused_add_norm
147
- self.mixer = torch.compile(RCPSWrapper(mixer_cls(dim)))
148
  norm_f = norm_cls(dim)
149
  self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f)
150
 
 
144
  super().__init__()
145
  self.residual_in_fp32 = residual_in_fp32
146
  self.fused_add_norm = fused_add_norm
147
+ self.mixer = RCPSWrapper(mixer_cls(dim))
148
  norm_f = norm_cls(dim)
149
  self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f)
150