zxdu20 commited on
Commit
e22cddf
1 Parent(s): e1494f2

Fix encode method

Browse files
Files changed (1) hide show
  1. tokenization_chatglm.py +8 -6
tokenization_chatglm.py CHANGED
@@ -398,19 +398,21 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
398
 
399
  # Initialize attention mask if not present.
400
  if return_attention_mask:
401
- context_length = required_input.index(bos_token_id)
 
 
 
402
  attention_mask = np.ones((1, seq_length, seq_length))
403
  attention_mask = np.tril(attention_mask)
404
  attention_mask[:, :, :context_length] = 1
405
  attention_mask = np.bool_(attention_mask < 0.5)
406
  encoded_inputs["attention_mask"] = attention_mask
407
 
408
- if return_attention_mask:
409
- mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
410
- mask_position = required_input.index(mask_token)
411
- context_length = required_input.index(bos_token_id)
412
  position_ids = np.arange(seq_length, dtype=np.int64)
413
- position_ids[context_length:] = mask_position
 
 
 
414
  block_position_ids = np.concatenate(
415
  [np.zeros(context_length, dtype=np.int64),
416
  np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
 
398
 
399
  # Initialize attention mask if not present.
400
  if return_attention_mask:
401
+ if bos_token_id in required_input:
402
+ context_length = required_input.index(bos_token_id)
403
+ else:
404
+ context_length = seq_length
405
  attention_mask = np.ones((1, seq_length, seq_length))
406
  attention_mask = np.tril(attention_mask)
407
  attention_mask[:, :, :context_length] = 1
408
  attention_mask = np.bool_(attention_mask < 0.5)
409
  encoded_inputs["attention_mask"] = attention_mask
410
 
 
 
 
 
411
  position_ids = np.arange(seq_length, dtype=np.int64)
412
+ mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
413
+ if mask_token in required_input:
414
+ mask_position = required_input.index(mask_token)
415
+ position_ids[context_length:] = mask_position
416
  block_position_ids = np.concatenate(
417
  [np.zeros(context_length, dtype=np.int64),
418
  np.arange(1, seq_length - context_length + 1, dtype=np.int64)])