ronald commited on
Commit
1639e25
1 Parent(s): 1cb6225
Files changed (1) hide show
  1. local_coh_ppl.py +2 -1
local_coh_ppl.py CHANGED
@@ -244,7 +244,8 @@ class LocalCohPPL(evaluate.Measurement):
244
  for i in range(1,len(stl)):
245
  ppl = torch.exp( loss_out[b,stl[i-1]:stl[i]].sum() / shift_attention_mask_batch[b,stl[i-1]:stl[i]].sum() ).detach().cpu().item()
246
  indv.append(ppl)
247
- norm_ppl.append( perplexity_all[b] / sum(indv) )
 
248
  #
249
  all_norm_ppl.extend(norm_ppl)
250
 
 
244
  for i in range(1,len(stl)):
245
  ppl = torch.exp( loss_out[b,stl[i-1]:stl[i]].sum() / shift_attention_mask_batch[b,stl[i-1]:stl[i]].sum() ).detach().cpu().item()
246
  indv.append(ppl)
247
+ nppl = perplexity_all[b] / sum(indv) if len(indv)>1 else 0.0
248
+ norm_ppl.append(nppl)
249
  #
250
  all_norm_ppl.extend(norm_ppl)
251