sunzeyeah committed on
Commit
d0274c1
1 Parent(s): fdc053e

update modeling and tokenization

Files changed (2)
  1. modeling_gptpangu.py +2 -2
  2. tokenization_gptpangu.py +32 -10
modeling_gptpangu.py CHANGED
@@ -460,7 +460,7 @@ class GPTPanguForCausalLM(GPTPanguPreTrainedModel):
 
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids = attention_mask.int().cumsum(-1).long() - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past:
                 position_ids = position_ids[:, -1].unsqueeze(-1)
@@ -521,7 +521,7 @@ class GPTPanguForCausalLM(GPTPanguPreTrainedModel):
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
             # Flatten the tokens
-            loss_fct = nn.CrossEntropyLoss()
+            loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
 
         if not return_dict:
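For context, both modeling changes can be exercised in isolation. The snippet below is a minimal sketch, not the repository code: the attention mask, pad_token_id, and vocab_size are made-up values (the model itself reads the padding id from self.config.pad_token_id). It shows how the rewritten position_ids line numbers real tokens from 0 in a left-padded batch, and how ignore_index keeps padded label positions out of the loss.

import torch
import torch.nn as nn

# Hypothetical left-padded batch: two sequences of length 5.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# New position_ids construction: cumsum in int32, then cast back to int64.
position_ids = attention_mask.int().cumsum(-1).long() - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# -> tensor([[1, 1, 0, 1, 2],
#            [0, 1, 2, 3, 4]])
# Padding positions get a dummy value of 1; real tokens count up from 0.

# Loss with ignore_index: padded label positions no longer contribute.
pad_token_id = 0   # assumed value for illustration; the model uses config.pad_token_id
vocab_size = 10
logits = torch.randn(2, 5, vocab_size)
labels = torch.tensor([[pad_token_id, pad_token_id, 3, 4, 5],
                       [2, 3, 4, 5, 6]])

shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))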
tokenization_gptpangu.py CHANGED
@@ -6,6 +6,13 @@ import numpy as np
 
 from transformers.tokenization_utils import PreTrainedTokenizer
 
+jieba.add_word('<s>')
+jieba.add_word('</s>')
+jieba.add_word('<eot>')
+jieba.add_word('<unk>')
+jieba.add_word('<sep>')
+jieba.add_word('<pad>')
+
 
 class GPTPanguTokenizer(PreTrainedTokenizer):
     # Ref: https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha/src/branch/master/tokenization_jieba.py
@@ -69,10 +76,25 @@ class GPTPanguTokenizer(PreTrainedTokenizer):
 
         if isinstance(tokens, str):
             return self._convert_token_to_id_with_added_voc(tokens)
-
-        new_seg = " ".join(tokens)
-        return self.sp.encode(new_seg)
-        # return tokens
+
+        special_tokens_index = [i for i, token in enumerate(tokens) if token in self.all_special_tokens]
+
+        ids = []
+        i = 0
+        for j in special_tokens_index:
+            new_seg = " ".join(tokens[i:j])
+            ids.extend(self.sp.encode(new_seg))
+            ids.append(self._convert_token_to_id(tokens[j]))
+            i = j + 1
+
+        new_seg = " ".join(tokens[i:])
+        ids.extend(self.sp.encode(new_seg))
+
+        return ids
+
+        # new_seg = " ".join(tokens)
+        # return self.sp.encode(new_seg)
+        # # return tokens
 
     def _convert_token_to_id(self, token):
         return self.sp.piece_to_id(token)
@@ -83,16 +105,16 @@ class GPTPanguTokenizer(PreTrainedTokenizer):
     def convert_ids_to_tokens(self, ids):
         return self.decode(ids)
 
-    def decode(self, tokens, **kwargs):
-        if isinstance(tokens, torch.Tensor) or isinstance(tokens, np.ndarray):
-            tokens = tokens.tolist()
+    def decode(self, ids, **kwargs):
+        if isinstance(ids, torch.Tensor) or isinstance(ids, np.ndarray):
+            ids = ids.tolist()
 
         if kwargs.get('skip_special_tokens', None) is True:
-            tokens = [token for token in tokens if token not in self.all_special_ids]
-        text = self.sp.decode(tokens)
+            ids = [token_id for token_id in ids if token_id not in self.all_special_ids]
+        text = self.sp.decode(ids)
         if isinstance(text, list):
             text = text[0]
-        text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')
+        text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')#.replace('⁇', self.unk_token)
         return text
 
     @property
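Roughly, the rewritten convert_tokens_to_ids splits the token list at special tokens, encodes each plain-token run with the sentencepiece model, and maps special tokens straight to their ids via _convert_token_to_id; decode now takes ids, optionally drops special ids, and restores whitespace from the placeholder characters '\u2582' (space) and '\u2583' (newline). The sketch below mirrors that control flow with the sentencepiece model replaced by a stub (stub_sp_encode) and invented ids, so none of the numbers correspond to the real vocabulary.

# Standalone sketch of the new convert_tokens_to_ids control flow.
# SPECIAL_TOKENS and all ids below are invented for illustration.
SPECIAL_TOKENS = {'<s>': 1, '</s>': 2, '<pad>': 3}

def stub_sp_encode(text):
    # Stand-in for self.sp.encode: one fake id per whitespace-separated piece.
    return [100 + len(piece) for piece in text.split()]

def convert_tokens_to_ids(tokens):
    special_tokens_index = [i for i, tok in enumerate(tokens) if tok in SPECIAL_TOKENS]
    ids = []
    i = 0
    for j in special_tokens_index:
        # Encode the run of ordinary tokens before the special token...
        ids.extend(stub_sp_encode(" ".join(tokens[i:j])))
        # ...then map the special token directly to its id.
        ids.append(SPECIAL_TOKENS[tokens[j]])
        i = j + 1
    # Trailing run after the last special token.
    ids.extend(stub_sp_encode(" ".join(tokens[i:])))
    return ids

def postprocess(text):
    # Mirror of the decode() cleanup: drop sentencepiece spaces and restore
    # the placeholders for space (\u2582) and newline (\u2583).
    return text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')

print(convert_tokens_to_ids(['<s>', '你好', '世界', '</s>']))
# -> [1, 102, 102, 2] with the stub encoder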