iioSnail committed
Commit f320103
1 Parent(s): 538b751

Upload csc_tokenizer.py

Files changed (1):
  1. csc_tokenizer.py +19 -46
csc_tokenizer.py CHANGED
@@ -108,7 +108,7 @@ class ChineseBertTokenizer(BertTokenizerFast):
             return_token_type_ids=return_token_type_ids,
             return_attention_mask=return_attention_mask,
             return_overflowing_tokens=return_overflowing_tokens,
-            return_offsets_mapping=return_offsets_mapping,
+            return_offsets_mapping=True,
             return_length=return_length,
             verbose=verbose,
         )
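
This hunk hard-wires return_offsets_mapping=True so the fast tokenizer always has character offsets available for the pinyin lookup; a later hunk deletes offset_mapping from the output again when the caller did not request it. A minimal sketch of what those offsets look like, assuming a stock bert-base-chinese fast tokenizer (the checkpoint is an assumption for illustration, not part of this commit):

from transformers import BertTokenizerFast

tok = BertTokenizerFast.from_pretrained("bert-base-chinese")  # assumed checkpoint
text = "我喜欢北京"
enc = tok(text, return_offsets_mapping=True, return_tensors="pt")

offsets = enc.offset_mapping[0].tolist()
print(offsets)
# [[0, 0], [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [0, 0]]
# special tokens like [CLS]/[SEP] span zero characters

# Slicing the sentence by these offsets is exactly what the new
# sentence_to_tokens helper below does:
print([text[start:end] for start, end in offsets])
# ['', '我', '喜', '欢', '北', '京', '']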
@@ -117,61 +117,34 @@ class ChineseBertTokenizer(BertTokenizerFast):
 
         pinyin_ids = None
         if type(text) == str:
-            pinyin_ids = self.convert_ids_to_pinyin_ids(input_ids)
+            offsets = encoding.offset_mapping[0].tolist()
+            tokens = self.sentence_to_tokens(text, offsets)
+            pinyin_ids = [self.convert_sentence_to_pinyin_ids(text, tokens, offsets)]
 
-        if type(text) == list:
+        if type(text) == list or type(text) == tuple:
             pinyin_ids = []
-            for ids in input_ids:
-                pinyin_ids.append(self.convert_ids_to_pinyin_ids(ids))
+            for i, sentence in enumerate(text):
+                offsets = encoding.offset_mapping[i].tolist()
+                tokens = self.sentence_to_tokens(sentence, offsets)
+                pinyin_ids.append(self.convert_sentence_to_pinyin_ids(sentence, tokens, offsets))
 
         if torch.is_tensor(encoding.input_ids):
             pinyin_ids = torch.LongTensor(pinyin_ids)
 
         encoding['pinyin_ids'] = pinyin_ids
 
-        return encoding
-
-    def tokenize_sentence(self, sentence):
-        # convert sentence to ids
-        tokenizer_output = self.tokenizer.encode(sentence)
-        bert_tokens = tokenizer_output.ids
-        pinyin_tokens = self.convert_sentence_to_pinyin_ids(sentence, tokenizer_output)
-        # assert, token nums should be same as pinyin token nums
-        assert len(bert_tokens) <= self.max_length
-        assert len(bert_tokens) == len(pinyin_tokens)
-        # convert list to tensor
-        input_ids = torch.LongTensor(bert_tokens)
-        pinyin_ids = torch.LongTensor(pinyin_tokens).view(-1)
-        return input_ids, pinyin_ids
-
-    def convert_ids_to_pinyin_ids(self, ids: List[int]):
-        pinyin_ids = []
-        tokens = self.convert_ids_to_tokens(ids)
-        for token in tokens:
-            if len(token) > 1:
-                pinyin_ids.append([0] * 8)
-                continue
-
-            pinyin_string = pinyin(token, style=Style.TONE3, errors=lambda x: [['not chinese'] for _ in x])[0][0]
-
-            if pinyin_string == "not chinese":
-                pinyin_ids.append([0] * 8)
-                continue
-
-            if pinyin_string in self.pinyin2tensor:
-                pinyin_ids.append(self.pinyin2tensor[pinyin_string])
-            else:
-                ids = [0] * 8
-                for i, p in enumerate(pinyin_string):
-                    if p not in self.pinyin_dict["char2idx"]:
-                        ids = [0] * 8
-                        break
-                    ids[i] = self.pinyin_dict["char2idx"][p]
-                pinyin_ids.append(pinyin_ids)
-
-        return pinyin_ids
-
-    def convert_sentence_to_pinyin_ids(self, sentence: str, tokenizer_output: tokenizers.Encoding) -> List[List[int]]:
+        if not return_offsets_mapping:
+            del encoding['offset_mapping']
+
+        return encoding
+
+    def sentence_to_tokens(self, sentence, offsets):
+        tokens = []
+        for start, end in offsets:
+            tokens.append(sentence[start:end])
+        return tokens
+
+    def convert_sentence_to_pinyin_ids(self, sentence: str, tokens, offsets):
         # get pinyin of a sentence
         pinyin_list = pinyin(sentence, style=Style.TONE3, heteronym=True, errors=lambda x: [['not chinese'] for _ in x])
         pinyin_locs = {}
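
The deleted convert_ids_to_pinyin_ids recovered token strings with convert_ids_to_tokens, which can diverge from the original characters (word pieces, [UNK], normalization), and its fallback branch appended the list to itself (pinyin_ids.append(pinyin_ids)) instead of the computed ids. The replacement derives tokens from character offsets but keeps the same 8-slot encoding: each token becomes eight ids, one per letter or tone digit of its TONE3 pinyin, and anything non-Chinese becomes eight zeros. A runnable sketch of that encoding; pypinyin is the real dependency, while the toy char2idx is an assumption standing in for the tokenizer's own pinyin_dict["char2idx"]:

from pypinyin import pinyin, Style

# Toy alphabet-plus-tone-digit mapping (assumption); the real table ships
# with the tokenizer as pinyin_dict["char2idx"].
char2idx = {c: i + 1 for i, c in enumerate("abcdefghijklmnopqrstuvwxyz12345")}

def pinyin_to_ids(token: str) -> list:
    result = pinyin(token, style=Style.TONE3,
                    errors=lambda x: [['not chinese'] for _ in x])[0][0]
    if result == "not chinese":
        return [0] * 8              # non-Chinese input: eight zeros
    ids = [0] * 8
    for i, ch in enumerate(result):
        ids[i] = char2idx[ch]       # one slot per pinyin character
    return ids

print(pinyin_to_ids("北"))  # 'bei3' -> [2, 5, 9, 29, 0, 0, 0, 0]
print(pinyin_to_ids("x"))   # -> [0, 0, 0, 0, 0, 0, 0, 0]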
@@ -194,7 +167,7 @@ class ChineseBertTokenizer(BertTokenizerFast):
 
         # find chinese character location, and generate pinyin ids
         pinyin_ids = []
-        for idx, (token, offset) in enumerate(zip(tokenizer_output.tokens, tokenizer_output.offsets)):
+        for idx, (token, offset) in enumerate(zip(tokens, offsets)):
             if offset[1] - offset[0] != 1:
                 pinyin_ids.append([0] * 8)
                 continue
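
With all three hunks applied, pinyin_ids are derived from character offsets for single strings, lists, and tuples alike, and offset_mapping is stripped from the result unless the caller explicitly requested it, so the public signature is unchanged. A hypothetical end-to-end call; the repo id is an assumption, so substitute whichever Hub repo ships this csc_tokenizer.py:

from transformers import AutoTokenizer

# Repo id below is a placeholder assumption.
tokenizer = AutoTokenizer.from_pretrained("iioSnail/ChineseBERT-base",
                                          trust_remote_code=True)

enc = tokenizer(["我喜欢北京", "天气真好"], padding=True, return_tensors="pt")
print(enc["input_ids"].shape)   # e.g. torch.Size([2, 7])
print(enc["pinyin_ids"].shape)  # e.g. torch.Size([2, 7, 8]): eight ids per token
assert "offset_mapping" not in enc  # dropped unless explicitly requested

Requesting offsets unconditionally and discarding them afterwards lets the pinyin lookup key off exact character spans rather than round-tripped token strings, without changing what callers see.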
 