dar-tau commited on
Commit
01e48f0
·
verified ·
1 Parent(s): f355a3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -46
app.py CHANGED
@@ -75,7 +75,7 @@ def initialize_gpu():
75
  pass
76
 
77
 
78
- def reset_model(model_name):
79
  # extract model info
80
  model_args = deepcopy(model_info[model_name])
81
  model_path = model_args.pop('model_path')
@@ -91,6 +91,7 @@ def reset_model(model_name):
91
  global_state.model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
92
  global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
93
  gc.collect()
 
94
 
95
 
96
  def get_hidden_states(raw_original_prompt):
@@ -189,58 +190,59 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
189
  # with gr.Column(scale=1):
190
  # gr.Markdown('<span style="font-size:180px;">🤔</span>')
191
 
192
- with gr.Group():
193
- model_chooser = gr.Radio(choices=list(model_info.keys()), value=model_name)
194
-
195
- gr.Markdown('## Choose Your Interpretation Prompt')
196
- with gr.Group('Interpretation'):
197
- interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
198
- gr.Examples([[p] for p in suggested_interpretation_prompts], [interpretation_prompt], cache_examples=False)
199
-
200
-
201
- gr.Markdown('## The Prompt to Analyze')
202
- for info in dataset_info:
203
- with gr.Tab(info['name']):
204
- num_examples = 10
205
- dataset = load_dataset(info['hf_repo'], split='train', streaming=True)
206
- if 'filter' in info:
207
- dataset = dataset.filter(info['filter'])
208
- dataset = dataset.shuffle(buffer_size=2000).take(num_examples)
209
- dataset = [[row[info['text_col']]] for row in dataset]
210
- gr.Examples(dataset, [original_prompt_raw], cache_examples=False)
211
-
212
  with gr.Group():
213
- original_prompt_raw.render()
214
- original_prompt_btn = gr.Button('Output Token List', variant='primary')
215
-
216
- gr.Markdown('### Here go the tokens of the prompt (click on the one to explore)')
 
 
 
217
 
218
- with gr.Row():
219
- for btn in tokens_container:
220
- btn.render()
221
-
222
 
223
- with gr.Accordion(open=False, label='Generation Settings'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  with gr.Row():
225
- num_tokens = gr.Slider(1, 100, step=1, value=20, label='Max. # of Tokens')
226
- repetition_penalty = gr.Slider(1., 10., value=1, label='Repetition Penalty')
227
- length_penalty = gr.Slider(0, 5, value=0, label='Length Penalty')
228
- # num_beams = gr.Slider(1, 20, value=1, step=1, label='Number of Beams')
229
- do_sample = gr.Checkbox(label='With sampling')
230
- with gr.Accordion(label='Sampling Parameters'):
231
  with gr.Row():
232
- temperature = gr.Slider(0., 5., value=0.6, label='Temperature')
233
- top_k = gr.Slider(1, 1000, value=50, step=1, label='top k')
234
- top_p = gr.Slider(0., 1., value=0.95, label='top p')
235
-
236
- progress_dummy = gr.Markdown('', elem_id='progress_dummy')
237
- interpretation_bubbles = [gr.Textbox('', container=False, visible=False,
238
- elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
239
- ) for i in range(MAX_NUM_LAYERS)]
240
-
 
 
 
 
 
 
 
241
 
242
  # event listeners
243
- model_chooser.change(reset_model, [model_chooser], [])
244
 
245
  for i, btn in enumerate(tokens_container):
246
  btn.click(partial(run_interpretation, i=i), [interpretation_prompt,
 
75
  pass
76
 
77
 
78
+ def reset_model(model_name, demo_blocks):
79
  # extract model info
80
  model_args = deepcopy(model_info[model_name])
81
  model_path = model_args.pop('model_path')
 
91
  global_state.model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
92
  global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
93
  gc.collect()
94
+ return demo_blocks
95
 
96
 
97
  def get_hidden_states(raw_original_prompt):
 
190
  # with gr.Column(scale=1):
191
  # gr.Markdown('<span style="font-size:180px;">🤔</span>')
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  with gr.Group():
194
+ model_chooser = gr.Radio(choices=list(model_info.keys()), value=model_name)
195
+
196
+ with gr.Blocks() as demo_blocks:
197
+ gr.Markdown('## Choose Your Interpretation Prompt')
198
+ with gr.Group('Interpretation'):
199
+ interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
200
+ gr.Examples([[p] for p in suggested_interpretation_prompts], [interpretation_prompt], cache_examples=False)
201
 
 
 
 
 
202
 
203
+ gr.Markdown('## The Prompt to Analyze')
204
+ for info in dataset_info:
205
+ with gr.Tab(info['name']):
206
+ num_examples = 10
207
+ dataset = load_dataset(info['hf_repo'], split='train', streaming=True)
208
+ if 'filter' in info:
209
+ dataset = dataset.filter(info['filter'])
210
+ dataset = dataset.shuffle(buffer_size=2000).take(num_examples)
211
+ dataset = [[row[info['text_col']]] for row in dataset]
212
+ gr.Examples(dataset, [original_prompt_raw], cache_examples=False)
213
+
214
+ with gr.Group():
215
+ original_prompt_raw.render()
216
+ original_prompt_btn = gr.Button('Output Token List', variant='primary')
217
+
218
+ gr.Markdown('### Here go the tokens of the prompt (click on the one to explore)')
219
+
220
  with gr.Row():
221
+ for btn in tokens_container:
222
+ btn.render()
223
+
224
+
225
+ with gr.Accordion(open=False, label='Generation Settings'):
 
226
  with gr.Row():
227
+ num_tokens = gr.Slider(1, 100, step=1, value=20, label='Max. # of Tokens')
228
+ repetition_penalty = gr.Slider(1., 10., value=1, label='Repetition Penalty')
229
+ length_penalty = gr.Slider(0, 5, value=0, label='Length Penalty')
230
+ # num_beams = gr.Slider(1, 20, value=1, step=1, label='Number of Beams')
231
+ do_sample = gr.Checkbox(label='With sampling')
232
+ with gr.Accordion(label='Sampling Parameters'):
233
+ with gr.Row():
234
+ temperature = gr.Slider(0., 5., value=0.6, label='Temperature')
235
+ top_k = gr.Slider(1, 1000, value=50, step=1, label='top k')
236
+ top_p = gr.Slider(0., 1., value=0.95, label='top p')
237
+
238
+ progress_dummy = gr.Markdown('', elem_id='progress_dummy')
239
+ interpretation_bubbles = [gr.Textbox('', container=False, visible=False,
240
+ elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
241
+ ) for i in range(MAX_NUM_LAYERS)]
242
+
243
 
244
  # event listeners
245
+ model_chooser.change(reset_model, [model_chooser, demo_blocks], [demo_blocks])
246
 
247
  for i, btn in enumerate(tokens_container):
248
  btn.click(partial(run_interpretation, i=i), [interpretation_prompt,