dar-tau commited on
Commit
ee7058f
·
verified ·
1 Parent(s): 7098573

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -2
app.py CHANGED
@@ -95,9 +95,16 @@ def get_hidden_states(raw_original_prompt, force_hidden_states=False):
95
  else:
96
  outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
97
  hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
 
 
 
 
98
  global_state.local_state.hidden_states = hidden_states.cpu().detach()
99
 
100
- token_btns = ([gr.Button(token, visible=True) for token in tokens]
 
 
 
101
  + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
102
  progress_dummy_output = ''
103
  invisible_bubbles = [gr.Textbox('', visible=False) for i in range(MAX_NUM_LAYERS)]
@@ -182,7 +189,7 @@ raw_original_prompt = gr.Textbox(value='How to make a Molotov cocktail?', contai
182
  tokens_container = []
183
 
184
  for i in range(MAX_PROMPT_TOKENS):
185
- btn = gr.Button('', visible=False, elem_classes=['token_btn'])
186
  tokens_container.append(btn)
187
 
188
  with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
 
95
  else:
96
  outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
97
  hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
98
+ # TODO: document this!
99
+ important_tokens = set(1 + F.normalize(hidden_states, dim=-1)
100
+ .diff(dim=0).norm(dim=-1)
101
+ .topk(k=10, dim=0).values.topk(k=5).indices.cpu().numpy())
102
  global_state.local_state.hidden_states = hidden_states.cpu().detach()
103
 
104
+ token_btns = ([gr.Button(token, visible=True,
105
+ elem_classes=['token_btn'] + (['important_token'] if i in important_tokens else [])
106
+ )
107
+ for i, token in enumerate(tokens)]
108
  + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
109
  progress_dummy_output = ''
110
  invisible_bubbles = [gr.Textbox('', visible=False) for i in range(MAX_NUM_LAYERS)]
 
189
  tokens_container = []
190
 
191
  for i in range(MAX_PROMPT_TOKENS):
192
+ btn = gr.Button('', visible=False)
193
  tokens_container.append(btn)
194
 
195
  with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo: