Lev McKinney commited on
Commit
c35da92
Β·
1 Parent(s): 7a724e0

upgraded app to use tuned_lens=0.1.0

Browse files
Files changed (3) hide show
  1. README.md +1 -0
  2. app.py +21 -15
  3. requirements.txt +1 -1
README.md CHANGED
@@ -3,6 +3,7 @@ title: Tuned Lens
3
  emoji: πŸ”Ž
4
  colorFrom: pink
5
  colorTo: blue
 
6
  sdk: docker
7
  pinned: false
8
  license: mit
 
3
  emoji: πŸ”Ž
4
  colorFrom: pink
5
  colorTo: blue
6
+ port: 7860
7
  sdk: docker
8
  pinned: false
9
  license: mit
app.py CHANGED
@@ -1,17 +1,20 @@
1
  import torch
2
  from tuned_lens.nn.lenses import TunedLens, LogitLens
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
- from tuned_lens.plotting import plot_lens
5
  import gradio as gr
6
  from plotly import graph_objects as go
7
 
8
  device = torch.device("cpu")
9
  print(f"Using device {device} for inference")
10
- model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped-v0")
11
  model = model.to(device)
12
- tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped-v0")
13
- tuned_lens = TunedLens.load("pythia-410m-deduped-v0", map_location=device)
14
- logit_lens = LogitLens(model)
 
 
 
15
 
16
  lens_options_dict = {
17
  "Tuned Lens": tuned_lens,
@@ -20,13 +23,15 @@ lens_options_dict = {
20
 
21
  statistic_options_dict = {
22
  "Entropy": "entropy",
23
- "Cross Entropy": "ce",
24
  "Forward KL": "forward_kl",
25
  }
26
 
27
 
28
  def make_plot(lens, text, statistic, token_cutoff):
29
  input_ids = tokenizer.encode(text, return_tensors="pt")
 
 
30
 
31
  if len(input_ids[0]) == 0:
32
  return go.Figure(layout=dict(title="Please enter some text."))
@@ -34,18 +39,19 @@ def make_plot(lens, text, statistic, token_cutoff):
34
  if token_cutoff < 1:
35
  return go.Figure(layout=dict(title="Please provide valid token cut off."))
36
 
37
- fig = plot_lens(
38
- model,
39
- tokenizer,
40
- lens_options_dict[lens],
41
- layer_stride=2,
42
  input_ids=input_ids,
43
- start_pos=max(len(input_ids[0]) - token_cutoff, 0),
44
- statistic=statistic_options_dict[statistic],
 
45
  )
46
 
47
- return fig
48
-
 
49
 
50
  preamble = """
51
  # The Tuned Lens πŸ”Ž
 
1
  import torch
2
  from tuned_lens.nn.lenses import TunedLens, LogitLens
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from tuned_lens.plotting import PredictionTrajectory
5
  import gradio as gr
6
  from plotly import graph_objects as go
7
 
8
  device = torch.device("cpu")
9
  print(f"Using device {device} for inference")
10
+ model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped")
11
  model = model.to(device)
12
+ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
13
+ tuned_lens = TunedLens.from_model_and_pretrained(
14
+ model=model,
15
+ map_location=device,
16
+ )
17
+ logit_lens = LogitLens.from_model(model)
18
 
19
  lens_options_dict = {
20
  "Tuned Lens": tuned_lens,
 
23
 
24
  statistic_options_dict = {
25
  "Entropy": "entropy",
26
+ "Cross Entropy": "cross_entropy",
27
  "Forward KL": "forward_kl",
28
  }
29
 
30
 
31
  def make_plot(lens, text, statistic, token_cutoff):
32
  input_ids = tokenizer.encode(text, return_tensors="pt")
33
+ input_ids = [tokenizer.bos_token_id] + input_ids
34
+ targets = input_ids[1:] + [tokenizer.eos_token_id]
35
 
36
  if len(input_ids[0]) == 0:
37
  return go.Figure(layout=dict(title="Please enter some text."))
 
39
  if token_cutoff < 1:
40
  return go.Figure(layout=dict(title="Please provide valid token cut off."))
41
 
42
+ start_pos=max(len(input_ids[0]) - token_cutoff, 0),
43
+ pred_traj = PredictionTrajectory.from_lens_and_model(
44
+ lens=lens,
45
+ model=model,
 
46
  input_ids=input_ids,
47
+ tokenizer=tokenizer,
48
+ targets=targets,
49
+ start_pos=start_pos,
50
  )
51
 
52
+ return getattr(pred_traj, statistic)().figure(
53
+ title=f"{lens.__class__.__name__} ({model.name_or_path}) {statistic}",
54
+ )
55
 
56
  preamble = """
57
  # The Tuned Lens πŸ”Ž
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
- tuned_lens==0.0.5
2
  gradio
 
1
+ tuned_lens==0.1.0
2
  gradio