Lev McKinney commited on
Commit
dba1d6e
1 Parent(s): e49cdfa

lens migration script updated

Browse files
Files changed (1) hide show
  1. lens_migration.py +3 -3
lens_migration.py CHANGED
@@ -7,7 +7,7 @@ from logging import warn
7
  from pathlib import Path
8
  import json
9
 
10
- from tuned_lens.model_surgery import get_final_layer_norm, get_transformer_layers
11
  from tuned_lens.load_artifacts import load_lens_artifacts
12
  from tuned_lens.nn import TunedLens
13
  from transformers.models.bloom.modeling_bloom import BloomBlock
@@ -148,7 +148,7 @@ class TunedLensOld(th.nn.Module):
148
 
149
  # Currently we convert the decoder to full precision
150
  self.unembedding = deepcopy(model.get_output_embeddings()).float()
151
- if ln := get_final_layer_norm(model):
152
  self.layer_norm = deepcopy(ln).float()
153
  else:
154
  self.layer_norm = th.nn.Identity()
@@ -354,7 +354,7 @@ if __name__ == "__main__":
354
 
355
  tuned_lens_old = TunedLensOld.load(args.resource_id, map_location=device)
356
 
357
- tuned_lens = TunedLens.init_from_model(
358
  model, bias=tuned_lens_old.config['bias'], revision=revision
359
  )
360
 
 
7
  from pathlib import Path
8
  import json
9
 
10
+ from tuned_lens.model_surgery import get_final_norm, get_transformer_layers
11
  from tuned_lens.load_artifacts import load_lens_artifacts
12
  from tuned_lens.nn import TunedLens
13
  from transformers.models.bloom.modeling_bloom import BloomBlock
 
148
 
149
  # Currently we convert the decoder to full precision
150
  self.unembedding = deepcopy(model.get_output_embeddings()).float()
151
+ if ln := get_final_norm(model):
152
  self.layer_norm = deepcopy(ln).float()
153
  else:
154
  self.layer_norm = th.nn.Identity()
 
354
 
355
  tuned_lens_old = TunedLensOld.load(args.resource_id, map_location=device)
356
 
357
+ tuned_lens = TunedLens.from_model(
358
  model, bias=tuned_lens_old.config['bias'], revision=revision
359
  )
360