Lev McKinney commited on
Commit
4004daa
Β·
1 Parent(s): c35da92

fixed several bugs in app.py

Browse files
Files changed (3) hide show
  1. .dockerignore +1 -1
  2. README.md +0 -1
  3. app.py +7 -8
.dockerignore CHANGED
@@ -1,2 +1,2 @@
1
  lens
2
- .git
 
1
  lens
2
+ .git
README.md CHANGED
@@ -3,7 +3,6 @@ title: Tuned Lens
3
  emoji: πŸ”Ž
4
  colorFrom: pink
5
  colorTo: blue
6
- port: 7860
7
  sdk: docker
8
  pinned: false
9
  license: mit
 
3
  emoji: πŸ”Ž
4
  colorFrom: pink
5
  colorTo: blue
 
6
  sdk: docker
7
  pinned: false
8
  license: mit
app.py CHANGED
@@ -7,7 +7,7 @@ from plotly import graph_objects as go
7
 
8
  device = torch.device("cpu")
9
  print(f"Using device {device} for inference")
10
- model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped")
11
  model = model.to(device)
12
  tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
13
  tuned_lens = TunedLens.from_model_and_pretrained(
@@ -29,19 +29,19 @@ statistic_options_dict = {
29
 
30
 
31
  def make_plot(lens, text, statistic, token_cutoff):
32
- input_ids = tokenizer.encode(text, return_tensors="pt")
33
  input_ids = [tokenizer.bos_token_id] + input_ids
34
  targets = input_ids[1:] + [tokenizer.eos_token_id]
35
 
36
- if len(input_ids[0]) == 0:
37
  return go.Figure(layout=dict(title="Please enter some text."))
38
 
39
  if token_cutoff < 1:
40
  return go.Figure(layout=dict(title="Please provide valid token cut off."))
41
 
42
- start_pos=max(len(input_ids[0]) - token_cutoff, 0),
43
  pred_traj = PredictionTrajectory.from_lens_and_model(
44
- lens=lens,
45
  model=model,
46
  input_ids=input_ids,
47
  tokenizer=tokenizer,
@@ -49,7 +49,7 @@ def make_plot(lens, text, statistic, token_cutoff):
49
  start_pos=start_pos,
50
  )
51
 
52
- return getattr(pred_traj, statistic)().figure(
53
  title=f"{lens.__class__.__name__} ({model.name_or_path}) {statistic}",
54
  )
55
 
@@ -114,5 +114,4 @@ with gr.Blocks() as demo:
114
  demo.load(make_plot, [lens_options, text, statistic, token_cutoff], plot)
115
 
116
  if __name__ == "__main__":
117
- demo.launch()
118
-
 
7
 
8
  device = torch.device("cpu")
9
  print(f"Using device {device} for inference")
10
+ model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped", torch_dtype="auto")
11
  model = model.to(device)
12
  tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
13
  tuned_lens = TunedLens.from_model_and_pretrained(
 
29
 
30
 
31
  def make_plot(lens, text, statistic, token_cutoff):
32
+ input_ids = tokenizer.encode(text)
33
  input_ids = [tokenizer.bos_token_id] + input_ids
34
  targets = input_ids[1:] + [tokenizer.eos_token_id]
35
 
36
+ if len(input_ids) == 1:
37
  return go.Figure(layout=dict(title="Please enter some text."))
38
 
39
  if token_cutoff < 1:
40
  return go.Figure(layout=dict(title="Please provide valid token cut off."))
41
 
42
+ start_pos=max(len(input_ids) - token_cutoff, 0)
43
  pred_traj = PredictionTrajectory.from_lens_and_model(
44
+ lens=lens_options_dict[lens],
45
  model=model,
46
  input_ids=input_ids,
47
  tokenizer=tokenizer,
 
49
  start_pos=start_pos,
50
  )
51
 
52
+ return getattr(pred_traj, statistic_options_dict[statistic])().figure(
53
  title=f"{lens.__class__.__name__} ({model.name_or_path}) {statistic}",
54
  )
55
 
 
114
  demo.load(make_plot, [lens_options, text, statistic, token_cutoff], plot)
115
 
116
  if __name__ == "__main__":
117
+ demo.launch(server_name="0.0.0.0", server_port=7860)