dar-tau commited on
Commit
1e4e3c2
·
verified ·
1 Parent(s): 9ab090f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -104,11 +104,15 @@ def get_hidden_states(raw_original_prompt, force_hidden_states=False):
104
 
105
  @spaces.GPU
106
  def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
107
- temperature, top_k, top_p, repetition_penalty, length_penalty, i,
108
  num_beams=1):
109
  model = global_state.model
110
  tokenizer = global_state.tokenizer
111
  print(f'run {model}')
 
 
 
 
112
  if global_state.wait_with_hidden_states and global_state.local_state.hidden_states is None:
113
  get_hidden_states(raw_original_prompt, force_hidden_states=True)
114
  interpreted_vectors = torch.tensor(global_state.local_state.hidden_states[:, i]).to(model.device).to(model.dtype)
@@ -251,6 +255,7 @@ with gr.Blocks(theme=gr.themes.Glass(), css='styles.css') as demo:
251
  with gr.Row():
252
  for btn in tokens_container:
253
  btn.render()
 
254
 
255
  progress_dummy = gr.Markdown('', elem_id='progress_dummy')
256
  interpretation_bubbles = [gr.Textbox('', container=False, visible=False) for i in range(MAX_NUM_LAYERS)]
@@ -259,7 +264,8 @@ with gr.Blocks(theme=gr.themes.Glass(), css='styles.css') as demo:
259
  for i, btn in enumerate(tokens_container):
260
  btn.click(partial(run_interpretation, i=i), [raw_original_prompt, raw_interpretation_prompt,
261
  num_tokens, do_sample, temperature,
262
- top_k, top_p, repetition_penalty, length_penalty
 
263
  ], [progress_dummy, *interpretation_bubbles])
264
 
265
  original_prompt_btn.click(get_hidden_states,
 
104
 
105
  @spaces.GPU
106
  def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
107
+ temperature, top_k, top_p, repetition_penalty, length_penalty, use_gpu, i,
108
  num_beams=1):
109
  model = global_state.model
110
  tokenizer = global_state.tokenizer
111
  print(f'run {model}')
112
+ if use_gpu:
113
+ model = model.cuda()
114
+ else:
115
+ model = model.cpu()
116
  if global_state.wait_with_hidden_states and global_state.local_state.hidden_states is None:
117
  get_hidden_states(raw_original_prompt, force_hidden_states=True)
118
  interpreted_vectors = torch.tensor(global_state.local_state.hidden_states[:, i]).to(model.device).to(model.dtype)
 
255
  with gr.Row():
256
  for btn in tokens_container:
257
  btn.render()
258
+ use_gpu = gr.Radio('Use GPU', value=True)
259
 
260
  progress_dummy = gr.Markdown('', elem_id='progress_dummy')
261
  interpretation_bubbles = [gr.Textbox('', container=False, visible=False) for i in range(MAX_NUM_LAYERS)]
 
264
  for i, btn in enumerate(tokens_container):
265
  btn.click(partial(run_interpretation, i=i), [raw_original_prompt, raw_interpretation_prompt,
266
  num_tokens, do_sample, temperature,
267
+ top_k, top_p, repetition_penalty, length_penalty,
268
+ use_gpu
269
  ], [progress_dummy, *interpretation_bubbles])
270
 
271
  original_prompt_btn.click(get_hidden_states,