Update app.py
Browse files
app.py
CHANGED
@@ -104,11 +104,15 @@ def get_hidden_states(raw_original_prompt, force_hidden_states=False):
|
|
104 |
|
105 |
@spaces.GPU
|
106 |
def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
|
107 |
-
temperature, top_k, top_p, repetition_penalty, length_penalty, i,
|
108 |
num_beams=1):
|
109 |
model = global_state.model
|
110 |
tokenizer = global_state.tokenizer
|
111 |
print(f'run {model}')
|
|
|
|
|
|
|
|
|
112 |
if global_state.wait_with_hidden_states and global_state.local_state.hidden_states is None:
|
113 |
get_hidden_states(raw_original_prompt, force_hidden_states=True)
|
114 |
interpreted_vectors = torch.tensor(global_state.local_state.hidden_states[:, i]).to(model.device).to(model.dtype)
|
@@ -251,6 +255,7 @@ with gr.Blocks(theme=gr.themes.Glass(), css='styles.css') as demo:
|
|
251 |
with gr.Row():
|
252 |
for btn in tokens_container:
|
253 |
btn.render()
|
|
|
254 |
|
255 |
progress_dummy = gr.Markdown('', elem_id='progress_dummy')
|
256 |
interpretation_bubbles = [gr.Textbox('', container=False, visible=False) for i in range(MAX_NUM_LAYERS)]
|
@@ -259,7 +264,8 @@ with gr.Blocks(theme=gr.themes.Glass(), css='styles.css') as demo:
|
|
259 |
for i, btn in enumerate(tokens_container):
|
260 |
btn.click(partial(run_interpretation, i=i), [raw_original_prompt, raw_interpretation_prompt,
|
261 |
num_tokens, do_sample, temperature,
|
262 |
-
top_k, top_p, repetition_penalty, length_penalty
|
|
|
263 |
], [progress_dummy, *interpretation_bubbles])
|
264 |
|
265 |
original_prompt_btn.click(get_hidden_states,
|
|
|
104 |
|
105 |
@spaces.GPU
|
106 |
def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
|
107 |
+
temperature, top_k, top_p, repetition_penalty, length_penalty, use_gpu, i,
|
108 |
num_beams=1):
|
109 |
model = global_state.model
|
110 |
tokenizer = global_state.tokenizer
|
111 |
print(f'run {model}')
|
112 |
+
if use_gpu:
|
113 |
+
model = model.cuda()
|
114 |
+
else:
|
115 |
+
model = model.cpu()
|
116 |
if global_state.wait_with_hidden_states and global_state.local_state.hidden_states is None:
|
117 |
get_hidden_states(raw_original_prompt, force_hidden_states=True)
|
118 |
interpreted_vectors = torch.tensor(global_state.local_state.hidden_states[:, i]).to(model.device).to(model.dtype)
|
|
|
255 |
with gr.Row():
|
256 |
for btn in tokens_container:
|
257 |
btn.render()
|
258 |
+
use_gpu = gr.Radio('Use GPU', value=True)
|
259 |
|
260 |
progress_dummy = gr.Markdown('', elem_id='progress_dummy')
|
261 |
interpretation_bubbles = [gr.Textbox('', container=False, visible=False) for i in range(MAX_NUM_LAYERS)]
|
|
|
264 |
for i, btn in enumerate(tokens_container):
|
265 |
btn.click(partial(run_interpretation, i=i), [raw_original_prompt, raw_interpretation_prompt,
|
266 |
num_tokens, do_sample, temperature,
|
267 |
+
top_k, top_p, repetition_penalty, length_penalty,
|
268 |
+
use_gpu
|
269 |
], [progress_dummy, *interpretation_bubbles])
|
270 |
|
271 |
original_prompt_btn.click(get_hidden_states,
|