Lora committed
Commit e29168c · 1 Parent(s): 9cfeab8

zero indexed

Files changed (2)
  1. app.py +1 -20
  2. requirements.txt +4 -71
app.py CHANGED
@@ -29,7 +29,7 @@ def visualize_word(word, count=10, remove_space=False):
      for i in range(contents.shape[0]):
          logits = contents[i,:] @ lm_head.t() # (vocab,) [768] @ [768, 50257] -> [50257]
          sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-         sense_names.append('sense {}'.format(i+1))
+         sense_names.append('sense {}'.format(i))

          # currently a lot of repetition
          pos_sorted_words = [tokenizer.decode(sorted_indices[j]) for j in range(count)]
@@ -60,19 +60,6 @@ def visualize_word(word, count=10, remove_space=False):

      return pos_df, neg_df, tokens

- # argp = argparse.ArgumentParser()
- # argp.add_argument('vecs_path')
- # argp.add_argument('lm_head_path')
- # args = argp.parse_args()
-
- # Load tokenizer and parameters
- # tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
- # vecs = torch.load(args.vecs_path)
- # lm_head = torch.load(args.lm_head_path)
-
- # visualize_word(input('Enter a word:'), tokenizer, vecs, lm_head, count=5)
- # visualize_word("fish", vecs, lm_head, count=COUNT)
-
  with gr.Blocks() as demo:
      gr.Markdown("""
      ## Backpack visualization: senses lookup
@@ -100,11 +87,5 @@ with gr.Blocks() as demo:
          outputs= [pos_outputs, neg_outputs, token_breakdown],
      )

- # sentence.select(
- #   fn=visualize_word,
- #   inputs= [word, count],
- #   outputs= [pos_outputs, neg_outputs],
- # )
-
  demo.launch(share=False)

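For context, the commit message ("zero indexed") refers to the sense labels built in the loop above: each of a Backpack word's sense vectors is projected through the output head, its top-count tokens are decoded, and the label for sense i now starts at 0 instead of 1. Below is a minimal, self-contained sketch of that loop; the random contents and lm_head tensors are stand-ins for the real Backpack sense vectors and the GPT-2-sized output head, and 16 senses is only a placeholder.

import torch
import transformers

# Stand-ins for the app's real data: random tensors with the same shapes as the
# Backpack sense vectors (contents) and the output head (lm_head) used in app.py.
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
contents = torch.randn(16, 768)    # (num_senses, d_model) -- placeholder sense vectors
lm_head = torch.randn(50257, 768)  # (vocab, d_model)      -- placeholder output head
count = 5

sense_names, top_words = [], []
for i in range(contents.shape[0]):
    logits = contents[i, :] @ lm_head.t()                    # (vocab,)
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    sense_names.append('sense {}'.format(i))                 # zero-indexed label, as in this commit
    top_words.append([tokenizer.decode(sorted_indices[j]) for j in range(count)])

for name, words in zip(sense_names, top_words):
    print(name, words)
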
requirements.txt CHANGED
@@ -1,71 +1,4 @@
- aiofiles==23.1.0
- aiohttp==3.8.4
- aiosignal==1.3.1
- altair==4.2.2
- anyio==3.6.2
- async-timeout==4.0.2
- attrs==22.2.0
- certifi @ file:///Users/cbousseau/work/recipes/ci_py311/certifi_1677903144932/work/certifi
- charset-normalizer==3.1.0
- click==8.1.3
- contourpy==1.0.7
- cycler==0.11.0
- entrypoints==0.4
- fastapi==0.95.0
- ffmpy==0.3.0
- filelock==3.10.7
- fonttools==4.39.3
- frozenlist==1.3.3
- fsspec==2023.3.0
- gradio==3.24.1
- gradio_client==0.0.7
- h11==0.14.0
- httpcore==0.16.3
- httpx==0.23.3
- huggingface-hub==0.13.3
- idna==3.4
- Jinja2==3.1.2
- jsonschema==4.17.3
- kiwisolver==1.4.4
- linkify-it-py==2.0.0
- markdown-it-py==2.2.0
- MarkupSafe==2.1.2
- matplotlib==3.7.1
- mdit-py-plugins==0.3.3
- mdurl==0.1.2
- mpmath==1.3.0
- multidict==6.0.4
- networkx==3.1
- numpy==1.24.2
- orjson==3.8.9
- packaging==23.0
- pandas==2.0.0
- Pillow==9.5.0
- pydantic==1.10.7
- pydub==0.25.1
- pyparsing==3.0.9
- pyrsistent==0.19.3
- python-dateutil==2.8.2
- python-multipart==0.0.6
- pytz==2023.3
- PyYAML==6.0
- regex==2023.3.23
- requests==2.28.2
- rfc3986==1.5.0
- semantic-version==2.10.0
- six==1.16.0
- sniffio==1.3.0
- starlette==0.26.1
- sympy==1.11.1
- tokenizers==0.13.3
- toolz==0.12.0
- torch==2.0.0
- tqdm==4.65.0
- transformers==4.27.4
- typing_extensions==4.5.0
- tzdata==2023.3
- uc-micro-py==1.0.1
- urllib3==1.26.15
- uvicorn==0.21.1
- websockets==11.0
- yarl==1.8.2
+ torch
+ pandas
+ transformers
+ gradio