taka-yamakoshi commited on
Commit
2f141a3
1 Parent(s): ce466e4
Files changed (1) hide show
  1. app.py +11 -11
app.py CHANGED
@@ -199,17 +199,17 @@ if __name__=='__main__':
199
  input_ids_dict = {}
200
  masked_ids_option_1 = {}
201
  masked_ids_option_2 = {}
202
- for sent_id in range(2):
203
- option_1_locs[f'sent_{sent_id+1}'], option_2_locs[f'sent_{sent_id+1}'] = separate_options(st.session_state[f'option_locs_{sent_id}'])
204
- pron_locs[f'sent_{sent_id+1}'] = st.session_state[f'mask_locs_{sent_id+1}']
205
- input_ids_dict[f'sent_{sent_id+1}'] = tokenizer(st.session_state[f'sent_{sent_id+1}']).input_ids
206
-
207
- masked_ids_option_1[f'sent_{sent_id+1}'] = mask_out(input_ids_dict[f'sent_{sent_id+1}'],
208
- pron_locs[f'sent_{sent_id+1}'],
209
- option_1_locs[f'sent_{sent_id+1}'],mask_id)
210
- masked_ids_option_2[f'sent_{sent_id+1}'] = mask_out(input_ids_dict[f'sent_{sent_id+1}'],
211
- pron_locs[f'sent_{sent_id+1}'],
212
- option_2_locs[f'sent_{sent_id+1}'],mask_id)
213
 
214
  for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
215
  st.write(' '.join([tokenizer.decode([token]) for toke in token_ids]))
 
199
  input_ids_dict = {}
200
  masked_ids_option_1 = {}
201
  masked_ids_option_2 = {}
202
+ for sent_id in [1,2]:
203
+ option_1_locs[f'sent_{sent_id}'], option_2_locs[f'sent_{sent_id}'] = separate_options(st.session_state[f'option_locs_{sent_id}'])
204
+ pron_locs[f'sent_{sent_id}'] = st.session_state[f'mask_locs_{sent_id}']
205
+ input_ids_dict[f'sent_{sent_id}'] = tokenizer(st.session_state[f'sent_{sent_id}']).input_ids
206
+
207
+ masked_ids_option_1[f'sent_{sent_id}'] = mask_out(input_ids_dict[f'sent_{sent_id}'],
208
+ pron_locs[f'sent_{sent_id}'],
209
+ option_1_locs[f'sent_{sent_id}'],mask_id)
210
+ masked_ids_option_2[f'sent_{sent_id}'] = mask_out(input_ids_dict[f'sent_{sent_id}'],
211
+ pron_locs[f'sent_{sent_id}'],
212
+ option_2_locs[f'sent_{sent_id}'],mask_id)
213
 
214
  for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
215
  st.write(' '.join([tokenizer.decode([token]) for toke in token_ids]))