xusenlin committed
Commit 26443e3
1 Parent(s): 738a065

Upload 9 files
added_tokens.json CHANGED
@@ -1,3 +1,3 @@
-{
-  "[UNK]": 39979
-}
+{
+  "[UNK]": 39979
+}
config.json CHANGED
@@ -1,30 +1,30 @@
-{
-  "_name_or_path": "uie_base_pytorch",
-  "architectures": [
-    "UIEModel"
-  ],
-  "attention_probs_dropout_prob": 0.1,
-  "auto_map": {
-    "AutoModel": "modeling_uie.UIEModel"
-  },
-  "classifier_dropout": null,
-  "hidden_act": "gelu",
-  "hidden_dropout_prob": 0.1,
-  "hidden_size": 768,
-  "initializer_range": 0.02,
-  "intermediate_size": 3072,
-  "layer_norm_eps": 1e-12,
-  "max_position_embeddings": 2048,
-  "model_type": "ernie",
-  "num_attention_heads": 12,
-  "num_hidden_layers": 12,
-  "pad_token_id": 0,
-  "position_embedding_type": "absolute",
-  "task_type_vocab_size": 3,
-  "torch_dtype": "float32",
-  "transformers_version": "4.39.1",
-  "type_vocab_size": 4,
-  "use_cache": true,
-  "use_task_id": true,
-  "vocab_size": 40000
-}
+{
+  "_name_or_path": "uie_base_pytorch",
+  "architectures": [
+    "UIEModel"
+  ],
+  "attention_probs_dropout_prob": 0.1,
+  "auto_map": {
+    "AutoModel": "modeling_uie.UIEModel"
+  },
+  "classifier_dropout": null,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "layer_norm_eps": 1e-12,
+  "max_position_embeddings": 2048,
+  "model_type": "ernie",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "pad_token_id": 0,
+  "position_embedding_type": "absolute",
+  "task_type_vocab_size": 3,
+  "torch_dtype": "float32",
+  "transformers_version": "4.44.2",
+  "type_vocab_size": 4,
+  "use_cache": true,
+  "use_task_id": true,
+  "vocab_size": 40000
+}
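
Note: the `auto_map` entry in this config registers the custom class shipped with the repo, so loading goes through Transformers' remote-code path. A minimal loading sketch (the repo id below is a placeholder, not part of the commit):

    from transformers import AutoModel, AutoTokenizer

    repo_id = "namespace/uie-base"  # placeholder: substitute this repo's actual id
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    # trust_remote_code=True lets AutoModel resolve "modeling_uie.UIEModel" from auto_map
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)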
decode_utils.py CHANGED
@@ -1,571 +1,697 @@
-import math
-import re
-from typing import (
-    List,
-    Union,
-    Any,
-    Optional,
-)
-
-import numpy as np
-import torch
-import torch.nn as nn
-from tqdm import tqdm
-from transformers import PreTrainedTokenizer
-
-
-def get_id_and_prob(spans, offset_map):
-    prompt_length = 0
-    for i in range(1, len(offset_map)):
-        if offset_map[i] != [0, 0]:
-            prompt_length += 1
-        else:
-            break
-
-    for i in range(1, prompt_length + 1):
-        offset_map[i][0] -= (prompt_length + 1)
-        offset_map[i][1] -= (prompt_length + 1)
-
-    sentence_id = []
-    prob = []
-    for start, end in spans:
-        prob.append(start[1] * end[1])
-        sentence_id.append(
-            (offset_map[start[0]][0], offset_map[end[0]][1]))
-    return sentence_id, prob
-
-
-def get_span(start_ids, end_ids, with_prob=False):
-    """
-    Get span set from position start and end list.
-    Args:
-        start_ids (List[int]/List[tuple]): The start index list.
-        end_ids (List[int]/List[tuple]): The end index list.
-        with_prob (bool): If True, each element for start_ids and end_ids is a tuple like: (index, probability).
-    Returns:
-        set: The span set without overlapping, every id can only be used once.
-    """
-    if with_prob:
-        start_ids = sorted(start_ids, key=lambda x: x[0])
-        end_ids = sorted(end_ids, key=lambda x: x[0])
-    else:
-        start_ids = sorted(start_ids)
-        end_ids = sorted(end_ids)
-
-    start_pointer = 0
-    end_pointer = 0
-    len_start = len(start_ids)
-    len_end = len(end_ids)
-    couple_dict = {}
-
-    # Pair the start/end token ids of each span (nearest match, assuming no overlapping spans)
-    while start_pointer < len_start and end_pointer < len_end:
-        if with_prob:
-            start_id = start_ids[start_pointer][0]
-            end_id = end_ids[end_pointer][0]
-        else:
-            start_id = start_ids[start_pointer]
-            end_id = end_ids[end_pointer]
-
-        if start_id == end_id:
-            couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
-            start_pointer += 1
-            end_pointer += 1
-            continue
-
-        if start_id < end_id:
-            couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
-            start_pointer += 1
-            continue
-
-        if start_id > end_id:
-            end_pointer += 1
-            continue
-
-    result = [(couple_dict[end], end) for end in couple_dict]
-    result = set(result)
-    return result
-
-
-def get_bool_ids_greater_than(probs, limit=0.5, return_prob=False):
-    """
-    Get idx of the last dimension in probability arrays, which is greater than a limitation.
-    Args:
-        probs (List[List[float]]): The input probability arrays.
-        limit (float): The limitation for probability.
-        return_prob (bool): Whether to return the probability
-    Returns:
-        List[List[int]]: The indices of the last dimension that meet the conditions.
-    """
-    probs = np.array(probs)
-    dim_len = len(probs.shape)
-    if dim_len > 1:
-        result = []
-        for p in probs:
-            result.append(get_bool_ids_greater_than(p, limit, return_prob))
-        return result
-    else:
-        result = []
-        for i, p in enumerate(probs):
-            if p > limit:
-                if return_prob:
-                    result.append((i, p))
-                else:
-                    result.append(i)
-        return result
-
-
-def dbc2sbc(s):
-    rs = ""
-    for char in s:
-        code = ord(char)
-        if code == 0x3000:
-            code = 0x0020
-        else:
-            code -= 0xfee0
-        if not (0x0021 <= code <= 0x7e):
-            rs += char
-            continue
-        rs += chr(code)
-    return rs
-
-
-def cut_chinese_sent(para):
-    """
-    Cut the Chinese sentences more precisely, reference to
-    "https://blog.csdn.net/blmoistawinde/article/details/82379256".
-    """
-    para = re.sub(r'([。!?\?])([^”’])', r'\1\n\2', para)
-    para = re.sub(r'(\.{6})([^”’])', r'\1\n\2', para)
-    para = re.sub(r'(\…{2})([^”’])', r'\1\n\2', para)
-    para = re.sub(r'([。!?\?][”’])([^,。!?\?])', r'\1\n\2', para)
-    para = para.rstrip()
-    return para.split("\n")
-
-
-def auto_splitter(input_texts, max_text_len, split_sentence=False):
-    """
-    Split the raw texts automatically for model inference.
-    Args:
-        input_texts (List[str]): input raw texts.
-        max_text_len (int): cutting length.
-        split_sentence (bool): If True, sentence-level split will be performed.
-    return:
-        short_input_texts (List[str]): the short input texts for model inference.
-        input_mapping (dict): mapping between raw text and short input texts.
-    """
-    input_mapping = {}
-    short_input_texts = []
-    cnt_short = 0
-    for cnt_org, text in enumerate(input_texts):
-        sens = cut_chinese_sent(text) if split_sentence else [text]
-        for sen in sens:
-            lens = len(sen)
-            if lens <= max_text_len:
-                short_input_texts.append(sen)
-                if cnt_org in input_mapping:
-                    input_mapping[cnt_org].append(cnt_short)
-                else:
-                    input_mapping[cnt_org] = [cnt_short]
-                cnt_short += 1
-            else:
-                temp_text_list = [sen[i: i + max_text_len] for i in range(0, lens, max_text_len)]
-
-                short_input_texts.extend(temp_text_list)
-                short_idx = cnt_short
-                cnt_short += math.ceil(lens / max_text_len)
-                temp_text_id = [short_idx + i for i in range(cnt_short - short_idx)]
-                if cnt_org in input_mapping:
-                    input_mapping[cnt_org].extend(temp_text_id)
-                else:
-                    input_mapping[cnt_org] = temp_text_id
-    return short_input_texts, input_mapping
-
-
-class UIEDecoder(nn.Module):
-
-    keys_to_ignore_on_gpu = ["offset_mapping", "texts"]
-
-    @torch.inference_mode()
-    def predict(
-        self,
-        tokenizer: PreTrainedTokenizer,
-        texts: Union[List[str], str],
-        schema: Optional[Any] = None,
-        batch_size: int = 64,
-        max_length: int = 512,
-        split_sentence: bool = False,
-        position_prob: float = 0.5,
-        is_english: bool = False,
-        disable_tqdm: bool = True,
-    ) -> List[Any]:
-        self.eval()
-        self.tokenizer = tokenizer
-        self.is_english = is_english
-        if schema is not None:
-            self.set_schema(schema)
-
-        texts = texts
-        if isinstance(texts, str):
-            texts = [texts]
-        return self._multi_stage_predict(
-            texts, batch_size, max_length, split_sentence, position_prob, disable_tqdm
-        )
-
-    def set_schema(self, schema):
-        if isinstance(schema, (dict, str)):
-            schema = [schema]
-        self._schema_tree = self._build_tree(schema)
-
-    def _multi_stage_predict(
-        self,
-        texts: List[str],
-        batch_size: int = 64,
-        max_length: int = 512,
-        split_sentence: bool = False,
-        position_prob: float = 0.5,
-        disable_tqdm: bool = True,
-    ) -> List[Any]:
-        """ Traverse the schema tree and do multi-stage prediction. """
-        results = [{} for _ in range(len(texts))]
-        if len(texts) < 1 or self._schema_tree is None:
-            return results
-
-        schema_list = self._schema_tree.children[:]
-        while len(schema_list) > 0:
-            node = schema_list.pop(0)
-            examples = []
-            input_map = {}
-            cnt = 0
-            idx = 0
-            if not node.prefix:
-                for data in texts:
-                    examples.append({"text": data, "prompt": dbc2sbc(node.name)})
-                    input_map[cnt] = [idx]
-                    idx += 1
-                    cnt += 1
-            else:
-                for pre, data in zip(node.prefix, texts):
-                    if len(pre) == 0:
-                        input_map[cnt] = []
-                    else:
-                        for p in pre:
-                            if self.is_english:
-                                if re.search(r'\[.*?\]$', node.name):
-                                    prompt_prefix = node.name[:node.name.find("[", 1)].strip()
-                                    cls_options = re.search(r'\[.*?\]$', node.name).group()
-                                    # Sentiment classification of xxx [positive, negative]
-                                    prompt = prompt_prefix + p + " " + cls_options
-                                else:
-                                    prompt = node.name + p
-                            else:
-                                prompt = p + node.name
-                            examples.append(
-                                {
-                                    "text": data,
-                                    "prompt": dbc2sbc(prompt)
-                                }
-                            )
-                        input_map[cnt] = [i + idx for i in range(len(pre))]
-                        idx += len(pre)
-                    cnt += 1
-
-            result_list = self._single_stage_predict(
-                examples, batch_size, max_length, split_sentence, position_prob, disable_tqdm
-            ) if examples else []
-            if not node.parent_relations:
-                relations = [[] for _ in range(len(texts))]
-                for k, v in input_map.items():
-                    for idx in v:
-                        if len(result_list[idx]) == 0:
-                            continue
-                        if node.name not in results[k].keys():
-                            results[k][node.name] = result_list[idx]
-                        else:
-                            results[k][node.name].extend(result_list[idx])
-                    if node.name in results[k].keys():
-                        relations[k].extend(results[k][node.name])
-            else:
-                relations = node.parent_relations
-                for k, v in input_map.items():
-                    for i in range(len(v)):
-                        if len(result_list[v[i]]) == 0:
-                            continue
-                        if "relations" not in relations[k][i].keys():
-                            relations[k][i]["relations"] = {node.name: result_list[v[i]]}
-                        elif node.name not in relations[k][i]["relations"].keys():
-                            relations[k][i]["relations"][node.name] = result_list[v[i]]
-                        else:
-                            relations[k][i]["relations"][node.name].extend(result_list[v[i]])
-
-                new_relations = [[] for _ in range(len(texts))]
-                for i in range(len(relations)):
-                    for j in range(len(relations[i])):
-                        if "relations" in relations[i][j].keys() and node.name in relations[i][j]["relations"].keys():
-                            for k in range(len(relations[i][j]["relations"][node.name])):
-                                new_relations[i].append(relations[i][j]["relations"][node.name][k])
-                relations = new_relations
-
-            prefix = [[] for _ in range(len(texts))]
-            for k, v in input_map.items():
-                for idx in v:
-                    for i in range(len(result_list[idx])):
-                        if self.is_english:
-                            prefix[k].append(" of " + result_list[idx][i]["text"])
-                        else:
-                            prefix[k].append(result_list[idx][i]["text"] + "的")
-
-            for child in node.children:
-                child.prefix = prefix
-                child.parent_relations = relations
-                schema_list.append(child)
-
-        return results
-
-    def _convert_ids_to_results(self, examples, sentence_ids, probs):
-        """ Convert ids to raw text in a single stage. """
-        results = []
-        for example, sentence_id, prob in zip(examples, sentence_ids, probs):
-            if len(sentence_id) == 0:
-                results.append([])
-                continue
-            result_list = []
-            text = example["text"]
-            prompt = example["prompt"]
-            for i in range(len(sentence_id)):
-                start, end = sentence_id[i]
-                if start < 0 and end >= 0:
-                    continue
-                if end < 0:
-                    start += len(prompt) + 1
-                    end += len(prompt) + 1
-                    result = {"text": prompt[start: end], "probability": prob[i]}
-                else:
-                    result = {"text": text[start: end], "start": start, "end": end, "probability": prob[i]}
-
-                result_list.append(result)
-            results.append(result_list)
-        return results
-
-    def _auto_splitter(self, input_texts, max_text_len, split_sentence=False):
-        """
-        Split the raw texts automatically for model inference.
-        Args:
-            input_texts (List[str]): input raw texts.
-            max_text_len (int): cutting length.
-            split_sentence (bool): If True, sentence-level split will be performed.
-        return:
-            short_input_texts (List[str]): the short input texts for model inference.
-            input_mapping (dict): mapping between raw text and short input texts.
-        """
-        input_mapping = {}
-        short_input_texts = []
-        cnt_short = 0
-        for cnt_org, text in enumerate(input_texts):
-            sens = cut_chinese_sent(text) if split_sentence else [text]
-            for sen in sens:
-                lens = len(sen)
-                if lens <= max_text_len:
-                    short_input_texts.append(sen)
-                    if cnt_org in input_mapping:
-                        input_mapping[cnt_org].append(cnt_short)
-                    else:
-                        input_mapping[cnt_org] = [cnt_short]
-                    cnt_short += 1
-                else:
-                    temp_text_list = [sen[i: i + max_text_len] for i in range(0, lens, max_text_len)]
-
-                    short_input_texts.extend(temp_text_list)
-                    short_idx = cnt_short
-                    cnt_short += math.ceil(lens / max_text_len)
-                    temp_text_id = [short_idx + i for i in range(cnt_short - short_idx)]
-                    if cnt_org in input_mapping:
-                        input_mapping[cnt_org].extend(temp_text_id)
-                    else:
-                        input_mapping[cnt_org] = temp_text_id
-        return short_input_texts, input_mapping
-
-    def _single_stage_predict(
-        self,
-        inputs: List[dict],
-        batch_size: int = 64,
-        max_length: int = 512,
-        split_sentence: bool = False,
-        position_prob: float = 0.5,
-        disable_tqdm: bool = True,
-    ):
-        input_texts = []
-        prompts = []
-        for i in range(len(inputs)):
-            input_texts.append(inputs[i]["text"])
-            prompts.append(inputs[i]["prompt"])
-        # max predict length should exclude the length of prompt and summary tokens
-        max_predict_len = max_length - len(max(prompts)) - 3
-
-        short_input_texts, input_mapping = self._auto_splitter(
-            input_texts, max_predict_len, split_sentence=split_sentence
-        )
-
-        short_texts_prompts = []
-        for k, v in input_mapping.items():
-            short_texts_prompts.extend([prompts[k] for _ in range(len(v))])
-        short_inputs = [
-            {
-                "text": short_input_texts[i],
-                "prompt": short_texts_prompts[i]
-            }
-            for i in range(len(short_input_texts))
-        ]
-
-        encoded_inputs = self.tokenizer(
-            text=short_texts_prompts,
-            text_pair=short_input_texts,
-            stride=2,
-            truncation=True,
-            max_length=max_length,
-            padding="longest",
-            add_special_tokens=True,
-            return_offsets_mapping=True,
-            return_tensors="np")
-        offset_maps = encoded_inputs["offset_mapping"]
-
-        start_prob_concat, end_prob_concat = [], []
-        if disable_tqdm:
-            batch_iterator = range(0, len(short_input_texts), batch_size)
-        else:
-            batch_iterator = tqdm(range(0, len(short_input_texts), batch_size), desc="Predicting", unit="batch")
-        for batch_start in batch_iterator:
-            batch = {
-                key:
-                np.array(value[batch_start: batch_start + batch_size], dtype="int64")
-                for key, value in encoded_inputs.items() if key not in self.keys_to_ignore_on_gpu
-            }
-
-            for k, v in batch.items():
-                batch[k] = torch.tensor(v, device=self.device)
-
-            outputs = self(**batch)
-            start_prob, end_prob = outputs[0], outputs[1]
-            if self.device != torch.device("cpu"):
-                start_prob, end_prob = start_prob.cpu(), end_prob.cpu()
-            start_prob_concat.append(start_prob.detach().numpy())
-            end_prob_concat.append(end_prob.detach().numpy())
-
-        start_prob_concat = np.concatenate(start_prob_concat)
-        end_prob_concat = np.concatenate(end_prob_concat)
-
-        start_ids_list = get_bool_ids_greater_than(start_prob_concat, limit=position_prob, return_prob=True)
-        end_ids_list = get_bool_ids_greater_than(end_prob_concat, limit=position_prob, return_prob=True)
-
-        input_ids = encoded_inputs['input_ids'].tolist()
-        sentence_ids, probs = [], []
-        for start_ids, end_ids, ids, offset_map in zip(start_ids_list, end_ids_list, input_ids, offset_maps):
-            span_list = get_span(start_ids, end_ids, with_prob=True)
-            sentence_id, prob = get_id_and_prob(span_list, offset_map.tolist())
-            sentence_ids.append(sentence_id)
-            probs.append(prob)
-
-        results = self._convert_ids_to_results(short_inputs, sentence_ids, probs)
-        results = self._auto_joiner(results, short_input_texts, input_mapping)
-        return results
-
-    def _auto_joiner(self, short_results, short_inputs, input_mapping):
-        concat_results = []
-        is_cls_task = False
-        for short_result in short_results:
-            if not short_result:
-                continue
-            elif 'start' not in short_result[0].keys() and 'end' not in short_result[0].keys():
-                is_cls_task = True
-                break
-            else:
-                break
-        for k, vs in input_mapping.items():
-            single_results = []
-            if is_cls_task:
-                cls_options = {}
-                for v in vs:
-                    if len(short_results[v]) == 0:
-                        continue
-                    if short_results[v][0]['text'] in cls_options:
-                        cls_options[short_results[v][0]["text"]][0] += 1
-                        cls_options[short_results[v][0]["text"]][1] += short_results[v][0]["probability"]
-
-                    else:
-                        cls_options[short_results[v][0]["text"]] = [1, short_results[v][0]["probability"]]
-
-                if cls_options:
-                    cls_res, cls_info = max(cls_options.items(), key=lambda x: x[1])
-                    concat_results.append(
-                        [
-                            {"text": cls_res, "probability": cls_info[1] / cls_info[0]}
-                        ]
-                    )
-
-                else:
-                    concat_results.append([])
-            else:
-                offset = 0
-                for v in vs:
-                    if v == 0:
-                        single_results = short_results[v]
-                        offset += len(short_inputs[v])
-                    else:
-                        for i in range(len(short_results[v])):
-                            if 'start' not in short_results[v][i] or 'end' not in short_results[v][i]:
-                                continue
-                            short_results[v][i]["start"] += offset
-                            short_results[v][i]["end"] += offset
-                        offset += len(short_inputs[v])
-                        single_results.extend(short_results[v])
-                concat_results.append(single_results)
-        return concat_results
-
-    @classmethod
-    def _build_tree(cls, schema, name='root'):
-        """
-        Build the schema tree.
-        """
-        schema_tree = SchemaTree(name)
-        for s in schema:
-            if isinstance(s, str):
-                schema_tree.add_child(SchemaTree(s))
-            elif isinstance(s, dict):
-                for k, v in s.items():
-                    if isinstance(v, str):
-                        child = [v]
-                    elif isinstance(v, list):
-                        child = v
-                    else:
-                        raise TypeError(
-                            f"Invalid schema, value for each key:value pairs should be list or string"
-                            f"but {type(v)} received")
-                    schema_tree.add_child(cls._build_tree(child, name=k))
-            else:
-                raise TypeError(f"Invalid schema, element should be string or dict, but {type(s)} received")
-
-        return schema_tree
-
-
-class SchemaTree(object):
-    """
-    Implementation of SchemaTree
-    """
-
-    def __init__(self, name='root', children=None):
-        self.name = name
-        self.children = []
-        self.prefix = None
-        self.parent_relations = None
-        if children is not None:
-            for child in children:
-                self.add_child(child)
-
-    def __repr__(self):
-        return self.name
-
-    def add_child(self, node):
-        assert isinstance(
-            node, SchemaTree
-        ), "The children of a node should be an instance of SchemaTree."
-        self.children.append(node)
+import logging
+import math
+import os
+import queue
+import re
+from multiprocessing import Queue
+from typing import (
+    List,
+    Tuple,
+    Union,
+    Dict,
+    Any,
+    Set,
+    TYPE_CHECKING,
+    Optional,
+    Literal,
+)
+
+import numpy as np
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from tqdm import tqdm
+from transformers import is_torch_npu_available
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer
+
+
+os.environ["PYTHONWARNINGS"] = "ignore"
+logger = logging.getLogger("FASTIE")
+
+
+def get_id_and_prob(spans, offset_map):
+    prompt_length = 0
+    for i in range(1, len(offset_map)):
+        if offset_map[i] != [0, 0]:
+            prompt_length += 1
+        else:
+            break
+
+    for i in range(1, prompt_length + 1):
+        offset_map[i][0] -= (prompt_length + 1)
+        offset_map[i][1] -= (prompt_length + 1)
+
+    sentence_id = []
+    prob = []
+    for start, end in spans:
+        prob.append(float(start[1] * end[1]))
+        sentence_id.append(
+            (offset_map[start[0]][0], offset_map[end[0]][1]))
+    return sentence_id, prob
+
+
+def get_span(
+    start_ids: Union[List[int], List[Tuple[int, float]]],
+    end_ids: Union[List[int], List[Tuple[int, float]]],
+    with_prob: bool = False
+) -> Set[Tuple[int, int]]:
+    """
+    Get span set from position start and end list.
+    Args:
+        start_ids (List[int]/List[tuple]): The start index list.
+        end_ids (List[int]/List[tuple]): The end index list.
+        with_prob (bool): If True, each element for start_ids and end_ids is a tuple like: (index, probability).
+    Returns:
+        set: The span set without overlapping, every id can only be used once.
+    """
+    if with_prob:
+        start_ids = sorted(start_ids, key=lambda x: x[0])
+        end_ids = sorted(end_ids, key=lambda x: x[0])
+    else:
+        start_ids = sorted(start_ids)
+        end_ids = sorted(end_ids)
+
+    start_pointer = 0
+    end_pointer = 0
+    len_start = len(start_ids)
+    len_end = len(end_ids)
+    couple_dict = {}
+
+    # Pair the start/end token ids of each span (nearest match, assuming no overlapping spans)
+    while start_pointer < len_start and end_pointer < len_end:
+        if with_prob:
+            start_id = start_ids[start_pointer][0]
+            end_id = end_ids[end_pointer][0]
+        else:
+            start_id = start_ids[start_pointer]
+            end_id = end_ids[end_pointer]
+
+        if start_id == end_id:
+            couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
+            start_pointer += 1
+            end_pointer += 1
+            continue
+
+        if start_id < end_id:
+            couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
+            start_pointer += 1
+            continue
+
+        if start_id > end_id:
+            end_pointer += 1
+            continue
+
+    result = [(couple_dict[end], end) for end in couple_dict]
+    result = set(result)
+    return result
+
+
+def get_bool_ids_greater_than(
+    probs: List[List[float]], limit: float = 0.5, return_prob: bool = False
+) -> List[List[int]]:
+    """
+    Get idx of the last dimension in probability arrays, which is greater than a limitation.
+    Args:
+        probs (List[List[float]]): The input probability arrays.
+        limit (float): The limitation for probability.
+        return_prob (bool): Whether to return the probability
+    Returns:
+        List[List[int]]: The indices of the last dimension that meet the conditions.
+    """
+    probs = np.array(probs)
+    dim_len = len(probs.shape)
+    if dim_len > 1:
+        result = []
+        for p in probs:
+            result.append(get_bool_ids_greater_than(p, limit, return_prob))
+        return result
+    else:
+        result = []
+        for i, p in enumerate(probs):
+            if p > limit:
+                if return_prob:
+                    result.append((i, p))
+                else:
+                    result.append(i)
+        return result
+
+
+def dbc2sbc(s) -> str:
+    rs = ""
+    for char in s:
+        code = ord(char)
+        if code == 0x3000:
+            code = 0x0020
+        else:
+            code -= 0xfee0
+        if not (0x0021 <= code <= 0x7e):
+            rs += char
+            continue
+        rs += chr(code)
+    return rs
+
+
+def cut_chinese_sent(para: str) -> List[str]:
+    """
+    Cut the Chinese sentences more precisely, reference to
+    "https://blog.csdn.net/blmoistawinde/article/details/82379256".
+    """
+    para = re.sub(r'([。!?\?])([^”’])', r'\1\n\2', para)
+    para = re.sub(r'(\.{6})([^”’])', r'\1\n\2', para)
+    para = re.sub(r'(\…{2})([^”’])', r'\1\n\2', para)
+    para = re.sub(r'([。!?\?][”’])([^,。!?\?])', r'\1\n\2', para)
+    para = para.rstrip()
+    return para.split("\n")
+
+
+class UIEDecoder(nn.Module):
+
+    keys_to_ignore_on_gpu = ["offset_mapping", "texts"]
+
+    @torch.inference_mode()
+    def predict(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        texts: Union[List[str], str],
+        schema: Optional[Any] = None,
+        batch_size: int = 64,
+        max_length: int = 512,
+        split_sentence: bool = False,
+        position_prob: float = 0.5,
+        language: Optional[str] = "zh",
+        show_progress_bar: bool = None,
+        device: Optional[str] = None,
+    ) -> List[Any]:
+        self.eval()
+        self.is_english = False if language.lower() in ["zh", "zh-cn", "chinese"] else True
+        if schema is not None:
+            self.set_schema(schema)
+
+        if show_progress_bar is None:
+            show_progress_bar = (
+                logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG
+            )
+        # Cast an individual text to a list with length 1
+        if isinstance(texts, str) or not hasattr(texts, "__len__"):
+            texts = [texts]
+
+        if device is None:
+            device = next(self.parameters()).device
+
+        self.to(device)
+
+        return self._multi_stage_predict(
+            tokenizer, texts, batch_size, max_length, split_sentence, position_prob, show_progress_bar
+        )
+
+    def set_schema(self, schema):
+        if isinstance(schema, (dict, str)):
+            schema = [schema]
+        self._schema_tree = self._build_tree(schema)
+
+    def _multi_stage_predict(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        texts: List[str],
+        batch_size: int = 64,
+        max_length: int = 512,
+        split_sentence: bool = False,
+        position_prob: float = 0.5,
+        show_progress_bar: bool = False,
+    ) -> List[Any]:
+        """ Traverse the schema tree and do multi-stage prediction. """
+        results = [{} for _ in range(len(texts))]
+        if len(texts) < 1 or self._schema_tree is None:
+            return results
+
+        schema_list = self._schema_tree.children[:]
+        while len(schema_list) > 0:
+            node = schema_list.pop(0)
+            examples = []
+            input_map = {}
+            cnt = 0
+            idx = 0
+            if not node.prefix:
+                for data in texts:
+                    examples.append({"text": data, "prompt": dbc2sbc(node.name)})
+                    input_map[cnt] = [idx]
+                    idx += 1
+                    cnt += 1
+            else:
+                for pre, data in zip(node.prefix, texts):
+                    if len(pre) == 0:
+                        input_map[cnt] = []
+                    else:
+                        for p in pre:
+                            if self.is_english:
+                                if re.search(r'\[.*?\]$', node.name):
+                                    prompt_prefix = node.name[:node.name.find("[", 1)].strip()
+                                    cls_options = re.search(r'\[.*?\]$', node.name).group()
+                                    # Sentiment classification of xxx [positive, negative]
+                                    prompt = prompt_prefix + p + " " + cls_options
+                                else:
+                                    prompt = node.name + p
+                            else:
+                                prompt = p + node.name
+                            examples.append(
+                                {
+                                    "text": data,
+                                    "prompt": dbc2sbc(prompt)
+                                }
+                            )
+                        input_map[cnt] = [i + idx for i in range(len(pre))]
+                        idx += len(pre)
+                    cnt += 1
+
+            result_list = self._single_stage_predict(
+                tokenizer, examples, batch_size, max_length, split_sentence, position_prob, show_progress_bar
+            ) if examples else []
+            if not node.parent_relations:
+                relations = [[] for _ in range(len(texts))]
+                for k, v in input_map.items():
+                    for idx in v:
+                        if len(result_list[idx]) == 0:
+                            continue
+                        if node.name not in results[k].keys():
+                            results[k][node.name] = result_list[idx]
+                        else:
+                            results[k][node.name].extend(result_list[idx])
+                    if node.name in results[k].keys():
+                        relations[k].extend(results[k][node.name])
+            else:
+                relations = node.parent_relations
+                for k, v in input_map.items():
+                    for i in range(len(v)):
+                        if len(result_list[v[i]]) == 0:
+                            continue
+                        if "relations" not in relations[k][i].keys():
+                            relations[k][i]["relations"] = {node.name: result_list[v[i]]}
+                        elif node.name not in relations[k][i]["relations"].keys():
+                            relations[k][i]["relations"][node.name] = result_list[v[i]]
+                        else:
+                            relations[k][i]["relations"][node.name].extend(result_list[v[i]])
+
+                new_relations = [[] for _ in range(len(texts))]
+                for i in range(len(relations)):
+                    for j in range(len(relations[i])):
+                        if "relations" in relations[i][j].keys() and node.name in relations[i][j]["relations"].keys():
+                            for k in range(len(relations[i][j]["relations"][node.name])):
+                                new_relations[i].append(relations[i][j]["relations"][node.name][k])
+                relations = new_relations
+
+            prefix = [[] for _ in range(len(texts))]
+            for k, v in input_map.items():
+                for idx in v:
+                    for i in range(len(result_list[idx])):
+                        if self.is_english:
+                            prefix[k].append(" of " + result_list[idx][i]["text"])
+                        else:
+                            prefix[k].append(result_list[idx][i]["text"] + "的")
+
+            for child in node.children:
+                child.prefix = prefix
+                child.parent_relations = relations
+                schema_list.append(child)
+
+        return results
+
+    def _convert_ids_to_results(self, examples, sentence_ids, probs):
+        """ Convert ids to raw text in a single stage. """
+        results = []
+        for example, sentence_id, prob in zip(examples, sentence_ids, probs):
+            if len(sentence_id) == 0:
+                results.append([])
+                continue
+            result_list = []
+            text = example["text"]
+            prompt = example["prompt"]
+            for i in range(len(sentence_id)):
+                start, end = sentence_id[i]
+                if start < 0 and end >= 0:
+                    continue
+                if end < 0:
+                    start += len(prompt) + 1
+                    end += len(prompt) + 1
+                    result = {"text": prompt[start: end], "probability": float(prob[i])}
+                else:
+                    result = {"text": text[start: end], "start": start, "end": end, "probability": float(prob[i])}
+
+                result_list.append(result)
+            results.append(result_list)
+        return results
+
+    def _auto_splitter(self, input_texts, max_text_len, split_sentence=False):
+        """
+        Split the raw texts automatically for model inference.
+        Args:
+            input_texts (List[str]): input raw texts.
+            max_text_len (int): cutting length.
+            split_sentence (bool): If True, sentence-level split will be performed.
+        return:
+            short_input_texts (List[str]): the short input texts for model inference.
+            input_mapping (dict): mapping between raw text and short input texts.
+        """
+        input_mapping = {}
+        short_input_texts = []
+        cnt_short = 0
+        for cnt_org, text in enumerate(input_texts):
+            sens = cut_chinese_sent(text) if split_sentence else [text]
+            for sen in sens:
+                lens = len(sen)
+                if lens <= max_text_len:
+                    short_input_texts.append(sen)
+                    if cnt_org in input_mapping:
+                        input_mapping[cnt_org].append(cnt_short)
+                    else:
+                        input_mapping[cnt_org] = [cnt_short]
+                    cnt_short += 1
+                else:
+                    temp_text_list = [sen[i: i + max_text_len] for i in range(0, lens, max_text_len)]
+
+                    short_input_texts.extend(temp_text_list)
+                    short_idx = cnt_short
+                    cnt_short += math.ceil(lens / max_text_len)
+                    temp_text_id = [short_idx + i for i in range(cnt_short - short_idx)]
+                    if cnt_org in input_mapping:
+                        input_mapping[cnt_org].extend(temp_text_id)
+                    else:
+                        input_mapping[cnt_org] = temp_text_id
+        return short_input_texts, input_mapping
+
+    def _single_stage_predict(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        inputs: List[dict],
+        batch_size: int = 64,
+        max_length: int = 512,
+        split_sentence: bool = False,
+        position_prob: float = 0.5,
+        show_progress_bar: bool = False,
+    ) -> List[Any]:
+        input_texts = []
+        prompts = []
+        for i in range(len(inputs)):
+            input_texts.append(inputs[i]["text"])
+            prompts.append(inputs[i]["prompt"])
+        # max predict length should exclude the length of prompt and summary tokens
+        max_predict_len = max_length - len(max(prompts)) - 3
+
+        short_input_texts, input_mapping = self._auto_splitter(
+            input_texts, max_predict_len, split_sentence=split_sentence
+        )
+
+        short_texts_prompts = []
+        for k, v in input_mapping.items():
+            short_texts_prompts.extend([prompts[k] for _ in range(len(v))])
+        short_inputs = [
+            {
+                "text": short_input_texts[i],
+                "prompt": short_texts_prompts[i]
+            }
+            for i in range(len(short_input_texts))
+        ]
+
+        encoded_inputs = tokenizer(
+            text=short_texts_prompts,
+            text_pair=short_input_texts,
+            stride=2,
+            truncation=True,
+            max_length=512,
+            padding="max_length",
+            add_special_tokens=True,
+            return_offsets_mapping=True,
+            return_tensors="np",
+        )
+        offset_maps = encoded_inputs["offset_mapping"]
+
+        start_prob_concat, end_prob_concat = [], []
+
+        batch_iterator = tqdm(range(0, len(short_input_texts), batch_size), desc="Batches", disable=not show_progress_bar)
+        for batch_start in batch_iterator:
+            batch = {
+                key:
+                np.array(value[batch_start: batch_start + batch_size], dtype="int64")
+                for key, value in encoded_inputs.items() if key not in self.keys_to_ignore_on_gpu
+            }
+
+            for k, v in batch.items():
+                batch[k] = torch.tensor(v, device=self.device)
+
+            outputs = self(**batch)
+            start_prob, end_prob = outputs[0], outputs[1]
+            if self.device != torch.device("cpu"):
+                start_prob, end_prob = start_prob.cpu(), end_prob.cpu()
+            start_prob_concat.append(start_prob.detach().numpy())
+            end_prob_concat.append(end_prob.detach().numpy())
+
+        start_prob_concat = np.concatenate(start_prob_concat)
+        end_prob_concat = np.concatenate(end_prob_concat)
+
+        start_ids_list = get_bool_ids_greater_than(start_prob_concat, limit=position_prob, return_prob=True)
+        end_ids_list = get_bool_ids_greater_than(end_prob_concat, limit=position_prob, return_prob=True)
+
+        input_ids = encoded_inputs["input_ids"].tolist()
+        sentence_ids, probs = [], []
+        for start_ids, end_ids, ids, offset_map in zip(start_ids_list, end_ids_list, input_ids, offset_maps):
+            span_list = get_span(start_ids, end_ids, with_prob=True)
+            sentence_id, prob = get_id_and_prob(span_list, offset_map.tolist())
+            sentence_ids.append(sentence_id)
+            probs.append(prob)
+
+        results = self._convert_ids_to_results(short_inputs, sentence_ids, probs)
+        results = self._auto_joiner(results, short_input_texts, input_mapping)
+        return results
+
+    def _auto_joiner(self, short_results, short_inputs, input_mapping):
+        concat_results = []
+        is_cls_task = False
+        for short_result in short_results:
+            if not short_result:
+                continue
+            elif 'start' not in short_result[0].keys() and 'end' not in short_result[0].keys():
+                is_cls_task = True
+                break
+            else:
+                break
+        for k, vs in input_mapping.items():
+            single_results = []
+            if is_cls_task:
+                cls_options = {}
+                for v in vs:
+                    if len(short_results[v]) == 0:
+                        continue
+                    if short_results[v][0]['text'] in cls_options:
+                        cls_options[short_results[v][0]["text"]][0] += 1
+                        cls_options[short_results[v][0]["text"]][1] += short_results[v][0]["probability"]
+
+                    else:
+                        cls_options[short_results[v][0]["text"]] = [1, short_results[v][0]["probability"]]
+
+                if cls_options:
+                    cls_res, cls_info = max(cls_options.items(), key=lambda x: x[1])
+                    concat_results.append(
+                        [
+                            {"text": cls_res, "probability": cls_info[1] / cls_info[0]}
+                        ]
+                    )
+
+                else:
+                    concat_results.append([])
+            else:
+                offset = 0
+                for v in vs:
+                    if v == 0:
+                        single_results = short_results[v]
+                        offset += len(short_inputs[v])
+                    else:
+                        for i in range(len(short_results[v])):
+                            if "start" not in short_results[v][i] or 'end' not in short_results[v][i]:
+                                continue
+                            short_results[v][i]["start"] += offset
+                            short_results[v][i]["end"] += offset
+                        offset += len(short_inputs[v])
+                        single_results.extend(short_results[v])
+                concat_results.append(single_results)
+        return concat_results
+
+    @classmethod
+    def _build_tree(cls, schema, name="root"):
+        """
+        Build the schema tree.
+        """
+        schema_tree = SchemaTree(name)
+        for s in schema:
+            if isinstance(s, str):
+                schema_tree.add_child(SchemaTree(s))
+            elif isinstance(s, dict):
+                for k, v in s.items():
+                    if isinstance(v, str):
+                        child = [v]
+                    elif isinstance(v, list):
+                        child = v
+                    else:
+                        raise TypeError(
+                            f"Invalid schema, value for each key:value pairs should be list or string"
+                            f"but {type(v)} received")
+                    schema_tree.add_child(cls._build_tree(child, name=k))
+            else:
+                raise TypeError(f"Invalid schema, element should be string or dict, but {type(s)} received")
+
+        return schema_tree
+
+    def start_multi_process_pool(self, target_devices: List[str] = None) -> Dict[
+        Literal["input", "output", "processes"], Any]:
+        """Start a multi-process pool and run prediction in several independent processes.
+        Recommended when predicting on multiple GPUs or CPUs; start only one process per GPU.
+
+        Args:
+            target_devices (List[str], optional): PyTorch target devices, e.g. ["cuda:0", "cuda:1", ...],
+                ["npu:0", "npu:1", ...], or ["cpu", "cpu", "cpu", "cpu"]. If target_devices is None and CUDA/NPU
+                is available, then all available CUDA/NPU devices will be used. If target_devices is None and
+                CUDA/NPU is not available, then 4 CPU devices will be used.
+
+        Returns:
+            Dict[str, Any]: A dictionary with the target processes, an input queue, and an output queue.
+        """
+        if target_devices is None:
+            if torch.cuda.is_available():
+                target_devices = ["cuda:{}".format(i) for i in range(torch.cuda.device_count())]
+            elif is_torch_npu_available():
+                target_devices = ["npu:{}".format(i) for i in range(torch.npu.device_count())]
+            else:
+                logger.info("CUDA/NPU is not available. Starting 4 CPU workers")
+                target_devices = ["cpu"] * 4
+
+        logger.info("Start multi-process pool on devices: {}".format(", ".join(map(str, target_devices))))
+
+        self.to("cpu")
+        self.share_memory()
+        ctx = mp.get_context("spawn")
+        input_queue = ctx.Queue()
+        output_queue = ctx.Queue()
+        processes = []
+
+        for device_id in target_devices:
+            p = ctx.Process(
+                target=UIEDecoder._predict_multi_process_worker,
+                args=(device_id, self, input_queue, output_queue),
+                daemon=True,
+            )
+            p.start()
+            processes.append(p)
+
+        return {"input": input_queue, "output": output_queue, "processes": processes}
+
+    @staticmethod
+    def stop_multi_process_pool(pool: Dict[Literal["input", "output", "processes"], Any]) -> None:
+        """
+        Stops all processes started with start_multi_process_pool.
+
+        Args:
+            pool (Dict[str, object]): A dictionary containing the input queue, output queue, and process list.
+
+        Returns:
+            None
+        """
+        for p in pool["processes"]:
+            p.terminate()
+
+        for p in pool["processes"]:
+            p.join()
+            p.close()
+
+        pool["input"].close()
+        pool["output"].close()
+
+    def predict_multi_process(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        texts: List[str],
+        pool: Dict[Literal["input", "output", "processes"], Any],
+        batch_size: int = 64,
+        max_length: int = 512,
+        split_sentence: bool = False,
+        language: Optional[str] = "zh",
+        position_prob: float = 0.5,
+        chunk_size: Optional[int] = None,
+    ) -> List[Any]:
+        if chunk_size is None:
+            chunk_size = min(math.ceil(len(texts) / len(pool["processes"]) / 10), 5000)
+
+        logger.debug(f"Chunk data into {math.ceil(len(texts) / chunk_size)} packages of size {chunk_size}")
+
+        input_queue = pool["input"]
+        last_chunk_id = 0
+        chunk = []
+
+        for text in texts:
+            chunk.append(text)
+            if len(chunk) >= chunk_size:
+                input_queue.put(
+                    [last_chunk_id, tokenizer, batch_size, chunk, max_length, split_sentence, language, position_prob]
+                )
+                last_chunk_id += 1
+                chunk = []
+
+        if len(chunk) > 0:
+            input_queue.put(
+                [last_chunk_id, tokenizer, batch_size, chunk, max_length, split_sentence, language, position_prob]
+            )
+            last_chunk_id += 1
+
+        output_queue = pool["output"]
+        results_list = sorted([output_queue.get() for _ in range(last_chunk_id)], key=lambda x: x[0])
+        return sum([result[1] for result in results_list], [])
+
+    @staticmethod
+    def _predict_multi_process_worker(
+        target_device: str, model: "UIEDecoder", input_queue: Queue, results_queue: Queue
+    ) -> None:
+        """
+        Internal working process to predict in multi-process setup
+        """
+        while True:
+            try:
+                chunk_id, tokenizer, batch_size, chunk, max_length, split_sentence, language, position_prob = (
+                    input_queue.get()
+                )
+                results = model.predict(
+                    tokenizer,
+                    chunk,
+                    batch_size=batch_size,
+                    max_length=max_length,
+                    split_sentence=split_sentence,
+                    language=language,
+                    show_progress_bar=False,
+                    device=target_device,
+                )
+
+                results_queue.put([chunk_id, results])
+            except queue.Empty:
+                break
+
+
+class SchemaTree(object):
+    """
+    Implementation of SchemaTree
+    """
+
+    def __init__(self, name='root', children=None):
+        self.name = name
+        self.children = []
+        self.prefix = None
+        self.parent_relations = None
+        if children is not None:
+            for child in children:
+                self.add_child(child)
+
+    def __repr__(self):
+        return self.name
+
+    def add_child(self, node):
+        assert isinstance(
+            node, SchemaTree
+        ), "The children of a node should be an instance of SchemaTree."
+        self.children.append(node)
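
For reference, a minimal usage sketch of the decoder API introduced above (the schema, texts, and device list are illustrative; `model` and `tokenizer` are assumed to be loaded as shown earlier):

    # Single-process extraction: predict() builds prompts from the schema tree,
    # splits long texts, and returns one result dict per input text.
    schema = {"竞赛名称": ["主办方", "承办方"]}  # an entity type with two relations
    model.set_schema(schema)
    results = model.predict(tokenizer, ["第十四届全国知识图谱大会由中国中文信息学会主办。"], language="zh")

    # Multi-device extraction: set the schema before starting the pool, because each
    # spawned worker receives a copy of the model and calls predict() without a schema.
    if __name__ == "__main__":
        pool = model.start_multi_process_pool(["cuda:0", "cuda:1"])  # assumed devices
        results = model.predict_multi_process(tokenizer, many_texts, pool)  # many_texts: placeholder list of strings
        model.stop_multi_process_pool(pool)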
modeling_uie.py CHANGED
@@ -1,162 +1,162 @@
1
- from dataclasses import dataclass
2
- from typing import Optional, Tuple
3
-
4
- import torch
5
- import torch.nn as nn
6
- from transformers import ErnieModel, ErniePreTrainedModel, PretrainedConfig
7
- from transformers.file_utils import ModelOutput
8
-
9
- from .decode_utils import UIEDecoder
10
-
11
-
12
- @dataclass
13
- class UIEModelOutput(ModelOutput):
14
- """
15
- Output class for outputs of UIE.
16
- losses (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
17
- Total spn extraction losses is the sum of a Cross-Entropy for the start and end positions.
18
- start_prob (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
19
- Span-start scores (after Sigmoid).
20
- end_prob (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
21
- Span-end scores (after Sigmoid).
22
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
23
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layers, +
24
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
25
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
26
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
27
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
28
- sequence_length)`.
29
- Attention weights after the attention softmax, used to compute the weighted average in the self-attention
30
- heads.
31
- """
32
- loss: Optional[torch.FloatTensor] = None
33
- start_prob: torch.FloatTensor = None
34
- end_prob: torch.FloatTensor = None
35
- start_positions: torch.FloatTensor = None
36
- end_positions: torch.FloatTensor = None
37
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
38
- attentions: Optional[Tuple[torch.FloatTensor]] = None
39
-
40
-
41
- class UIEModel(ErniePreTrainedModel, UIEDecoder):
42
- """
43
- UIE model based on Bert model.
44
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
45
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
46
- etc.)
47
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
48
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
49
- and behavior.
50
- Parameters:
51
- config ([`PretrainedConfig`]): Model configuration class with all the parameters of the model.
52
- Initializing with a config file does not load the weights associated with the model, only the
53
- configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
54
- """
55
-
56
- def __init__(self, config: PretrainedConfig):
57
- super(UIEModel, self).__init__(config)
58
- self.encoder = ErnieModel(config)
59
- self.config = config
60
- hidden_size = self.config.hidden_size
61
-
62
- self.linear_start = nn.Linear(hidden_size, 1)
63
- self.linear_end = nn.Linear(hidden_size, 1)
64
- self.sigmoid = nn.Sigmoid()
65
-
66
- self.post_init()
67
-
68
- def forward(
69
- self,
70
- input_ids: Optional[torch.Tensor] = None,
71
- token_type_ids: Optional[torch.Tensor] = None,
72
- position_ids: Optional[torch.Tensor] = None,
73
- attention_mask: Optional[torch.Tensor] = None,
74
- head_mask: Optional[torch.Tensor] = None,
75
- inputs_embeds: Optional[torch.Tensor] = None,
76
- start_positions: Optional[torch.Tensor] = None,
77
- end_positions: Optional[torch.Tensor] = None,
78
- output_attentions: Optional[bool] = None,
79
- output_hidden_states: Optional[bool] = None,
80
- ) -> UIEModelOutput:
81
- """
82
- Args:
83
- input_ids (`torch.LongTensor` of shape `({0})`):
84
- Indices of input sequence tokens in the vocabulary.
85
- Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and
86
- [`PreTrainedTokenizer.__call__`] for details.
87
- [What are input IDs?](../glossary#input-ids)
88
- attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
89
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
90
- - 1 for tokens that are **not masked**,
91
- - 0 for tokens that are **masked**.
92
- [What are attention masks?](../glossary#attention-mask)
93
- token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
94
- Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
95
- 1]`:
96
- - 0 corresponds to a *sentence A* token,
97
- - 1 corresponds to a *sentence B* token.
98
- [What are token type IDs?](../glossary#token-type-ids)
99
- position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
100
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
101
- config.max_position_embeddings - 1]`.
102
- [What are position IDs?](../glossary#position-ids)
103
- head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
104
- Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
105
- - 1 indicates the head is **not masked**,
106
- - 0 indicates the head is **masked**.
107
- inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
108
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
109
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
110
- model's internal embedding lookup matrix.
111
- start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
112
- Labels for position (index) of the start of the labelled span for computing the token classification loss.
113
- Positions are clamped to the length of the sequence (`sequence_length`). Position outsides of the sequence
114
- are not taken into account for computing the loss.
115
- end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
116
- Labels for position (index) of the end of the labelled span for computing the token classification loss.
117
- Positions are clamped to the length of the sequence (`sequence_length`). Position outsides of the sequence
118
- are not taken into account for computing the loss.
119
- output_attentions (`bool`, *optional*):
120
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
121
- tensors for more detail.
122
- output_hidden_states (`bool`, *optional*):
123
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
124
- more detail.
125
- return_dict (`bool`, *optional*):
126
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
127
- """
128
- outputs = self.encoder(
129
- input_ids=input_ids,
130
- token_type_ids=token_type_ids,
131
- position_ids=position_ids,
132
- attention_mask=attention_mask,
133
- head_mask=head_mask,
134
- inputs_embeds=inputs_embeds,
135
- output_attentions=output_attentions,
136
- output_hidden_states=output_hidden_states,
137
- )
138
- sequence_output = outputs[0]
139
-
140
- start_logits = self.linear_start(sequence_output)
141
- start_logits = torch.squeeze(start_logits, -1)
142
- start_prob = self.sigmoid(start_logits)
143
-
144
- end_logits = self.linear_end(sequence_output)
145
- end_logits = torch.squeeze(end_logits, -1)
146
- end_prob = self.sigmoid(end_logits)
147
-
148
- total_loss = None
149
- if start_positions is not None and end_positions is not None:
150
- loss_fct = nn.BCELoss()
151
- start_loss = loss_fct(start_prob, start_positions)
152
- end_loss = loss_fct(end_prob, end_positions)
153
-
154
- total_loss = (start_loss + end_loss) / 2.0
155
-
156
- return UIEModelOutput(
157
- loss=total_loss,
158
- start_prob=start_prob,
159
- end_prob=end_prob,
160
- hidden_states=outputs.hidden_states,
161
- attentions=outputs.attentions,
162
- )
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from transformers import ErnieModel, ErniePreTrainedModel, PretrainedConfig
7
+ from transformers.file_utils import ModelOutput
8
+
9
+ from .decode_utils import UIEDecoder
10
+
11
+
12
+ @dataclass
13
+ class UIEModelOutput(ModelOutput):
14
+ """
15
+ Output class for outputs of UIE.
16
+ losses (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
17
+ Total spn extraction losses is the sum of a Cross-Entropy for the start and end positions.
18
+ start_prob (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
19
+ Span-start scores (after Sigmoid).
20
+ end_prob (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
21
+ Span-end scores (after Sigmoid).
22
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
23
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layers, +
24
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
25
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
26
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
27
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
28
+ sequence_length)`.
29
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
30
+ heads.
31
+ """
32
+ loss: Optional[torch.FloatTensor] = None
33
+ start_prob: torch.FloatTensor = None
34
+ end_prob: torch.FloatTensor = None
35
+ start_positions: torch.FloatTensor = None
36
+ end_positions: torch.FloatTensor = None
37
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
38
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
39
+
40
+
41
+ class UIEModel(ErniePreTrainedModel, UIEDecoder):
42
+ """
43
+ UIE model based on Bert model.
44
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
45
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
46
+ etc.)
47
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
48
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
49
+ and behavior.
50
+ Parameters:
51
+ config ([`PretrainedConfig`]): Model configuration class with all the parameters of the model.
52
+ Initializing with a config file does not load the weights associated with the model, only the
53
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
54
+ """
55
+
56
+ def __init__(self, config: PretrainedConfig):
57
+ super(UIEModel, self).__init__(config)
58
+ self.encoder = ErnieModel(config)
59
+ self.config = config
60
+ hidden_size = self.config.hidden_size
61
+
62
+ self.linear_start = nn.Linear(hidden_size, 1)
63
+ self.linear_end = nn.Linear(hidden_size, 1)
64
+ self.sigmoid = nn.Sigmoid()
65
+
66
+ self.post_init()
67
+
68
+ def forward(
69
+ self,
70
+ input_ids: Optional[torch.Tensor] = None,
71
+ token_type_ids: Optional[torch.Tensor] = None,
72
+ position_ids: Optional[torch.Tensor] = None,
73
+ attention_mask: Optional[torch.Tensor] = None,
74
+ head_mask: Optional[torch.Tensor] = None,
75
+ inputs_embeds: Optional[torch.Tensor] = None,
76
+ start_positions: Optional[torch.Tensor] = None,
77
+ end_positions: Optional[torch.Tensor] = None,
78
+ output_attentions: Optional[bool] = None,
79
+ output_hidden_states: Optional[bool] = None,
80
+ ) -> UIEModelOutput:
+         """
+         Args:
+             input_ids (`torch.LongTensor` of shape `({0})`):
+                 Indices of input sequence tokens in the vocabulary.
+                 Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                 [`PreTrainedTokenizer.__call__`] for details.
+                 [What are input IDs?](../glossary#input-ids)
+             attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+                 Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+                 - 1 for tokens that are **not masked**,
+                 - 0 for tokens that are **masked**.
+                 [What are attention masks?](../glossary#attention-mask)
+             token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+                 Segment token indices to indicate first and second portions of the inputs. Indices are selected in
+                 `[0, 1]`:
+                 - 0 corresponds to a *sentence A* token,
+                 - 1 corresponds to a *sentence B* token.
+                 [What are token type IDs?](../glossary#token-type-ids)
+             position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+                 Indices of positions of each input sequence token in the position embeddings. Selected in the range
+                 `[0, config.max_position_embeddings - 1]`.
+                 [What are position IDs?](../glossary#position-ids)
+             head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+                 Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+                 - 1 indicates the head is **not masked**,
+                 - 0 indicates the head is **masked**.
+             inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
+                 Optionally, instead of passing `input_ids` you can choose to directly pass an embedded
+                 representation. This is useful if you want more control over how to convert `input_ids` indices
+                 into associated vectors than the model's internal embedding lookup matrix.
+             start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                 Labels for the position (index) of the start of the labelled span, used to compute the token
+                 classification loss. Positions are clamped to the length of the sequence (`sequence_length`).
+                 Positions outside of the sequence are not taken into account for computing the loss.
+             end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                 Labels for the position (index) of the end of the labelled span, used to compute the token
+                 classification loss. Positions are clamped to the length of the sequence (`sequence_length`).
+                 Positions outside of the sequence are not taken into account for computing the loss.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+             output_hidden_states (`bool`, *optional*):
+                 Whether or not to return the hidden states of all layers. See `hidden_states` under returned
+                 tensors for more detail.
+         """
+         outputs = self.encoder(
+             input_ids=input_ids,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             attention_mask=attention_mask,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+         )
+         sequence_output = outputs[0]
+
+         # Project each token's hidden state to a scalar and squash it into a
+         # per-token probability of being a span start / span end.
+         start_logits = self.linear_start(sequence_output)
+         start_logits = torch.squeeze(start_logits, -1)
+         start_prob = self.sigmoid(start_logits)
+
+         end_logits = self.linear_end(sequence_output)
+         end_logits = torch.squeeze(end_logits, -1)
+         end_prob = self.sigmoid(end_logits)
+
+         # Binary cross-entropy on the two pointer heads, averaged.
+         total_loss = None
+         if start_positions is not None and end_positions is not None:
+             loss_fct = nn.BCELoss()
+             start_loss = loss_fct(start_prob, start_positions)
+             end_loss = loss_fct(end_prob, end_positions)
+             total_loss = (start_loss + end_loss) / 2.0
+
+         return UIEModelOutput(
+             loss=total_loss,
+             start_prob=start_prob,
+             end_prob=end_prob,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
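
As a quick sanity check of the exported model, here is a minimal usage sketch. It assumes the repository has been downloaded to a local `./uie-base-pytorch` directory (a placeholder path, not the published model id); since `UIEModel` is custom code shipped with the repo, loading through `AutoModel` requires `trust_remote_code=True`.

import torch
from transformers import AutoModel, AutoTokenizer

path = "./uie-base-pytorch"  # placeholder: point this at the downloaded repo
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModel.from_pretrained(path, trust_remote_code=True)

# UIE encodes the extraction prompt and the text as a sentence pair.
inputs = tokenizer("时间", "2023年3月1日下午3点与编辑部同事开会", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Per-token probabilities that each position starts or ends an answer span.
print(outputs.start_prob.shape, outputs.end_prob.shape)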
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:32be889914931873a60d481926e9208625823baa2166d7f1f89359f22f15a778
+ size 471852986
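
The weights file above is a Git LFS pointer, so a clone without LFS resolves to a three-line stub instead of the ~472 MB binary. A minimal sketch for verifying the downloaded file against the pointer's `oid` and `size`:

import hashlib
import os

path = "pytorch_model.bin"
expected_oid = "32be889914931873a60d481926e9208625823baa2166d7f1f89359f22f15a778"
expected_size = 471852986

# A size mismatch usually means the LFS pointer was never resolved.
assert os.path.getsize(path) == expected_size, "size mismatch (pointer not resolved?)"

# Hash in 1 MiB chunks to keep memory use flat.
h = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)
assert h.hexdigest() == expected_oid, "sha256 mismatch"
print("pytorch_model.bin matches the LFS pointer")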
special_tokens_map.json CHANGED
@@ -1,7 +1,7 @@
- {
-   "cls_token": "[CLS]",
-   "mask_token": "[MASK]",
-   "pad_token": "[PAD]",
-   "sep_token": "[SEP]",
-   "unk_token": "[UNK]"
- }

+ {
+   "cls_token": "[CLS]",
+   "mask_token": "[MASK]",
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "unk_token": "[UNK]"
+ }
tokenizer.json CHANGED
@@ -1,19 +1,7 @@
  {
    "version": "1.0",
-   "truncation": {
-     "direction": "Right",
-     "max_length": 512,
-     "strategy": "LongestFirst",
-     "stride": 2
-   },
-   "padding": {
-     "strategy": "BatchLongest",
-     "direction": "Right",
-     "pad_to_multiple_of": null,
-     "pad_id": 0,
-     "pad_type_id": 0,
-     "pad_token": "[PAD]"
-   },
+   "truncation": null,
+   "padding": null,
    "added_tokens": [
      {
        "id": 0,
@@ -12241,7 +12229,7 @@
    "ー": 12081,
    "★": 12082,
    "’": 12083,
-   "$$": 12084,
+   "’’": 12084,
    "{": 12085,
    "}": 12086,
    "‘": 12087,
tokenizer_config.json CHANGED
@@ -1,57 +1,57 @@
- {
-   "added_tokens_decoder": {
-     "0": {
-       "content": "[PAD]",
-       "lstrip": false,
-       "normalized": false,
-       "rstrip": false,
-       "single_word": false,
-       "special": true
-     },
-     "1": {
-       "content": "[CLS]",
-       "lstrip": false,
-       "normalized": false,
-       "rstrip": false,
-       "single_word": false,
-       "special": true
-     },
-     "2": {
-       "content": "[SEP]",
-       "lstrip": false,
-       "normalized": false,
-       "rstrip": false,
-       "single_word": false,
-       "special": true
-     },
-     "3": {
-       "content": "[MASK]",
-       "lstrip": false,
-       "normalized": false,
-       "rstrip": false,
-       "single_word": false,
-       "special": true
-     },
-     "39979": {
-       "content": "[UNK]",
-       "lstrip": false,
-       "normalized": false,
-       "rstrip": false,
-       "single_word": false,
-       "special": true
-     }
-   },
-   "clean_up_tokenization_spaces": true,
-   "cls_token": "[CLS]",
-   "do_basic_tokenize": true,
-   "do_lower_case": true,
-   "mask_token": "[MASK]",
-   "model_max_length": 1000000000000000019884624838656,
-   "never_split": null,
-   "pad_token": "[PAD]",
-   "sep_token": "[SEP]",
-   "strip_accents": null,
-   "tokenize_chinese_chars": true,
-   "tokenizer_class": "BertTokenizer",
-   "unk_token": "[UNK]"
- }

+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "[PAD]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "3": {
+       "content": "[MASK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "39979": {
+       "content": "[UNK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "do_basic_tokenize": true,
+   "do_lower_case": true,
+   "mask_token": "[MASK]",
+   "model_max_length": 1000000000000000019884624838656,
+   "never_split": null,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "BertTokenizer",
+   "unk_token": "[UNK]"
+ }
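
For completeness, a short sketch (same placeholder path) confirming that the slow-tokenizer config above round-trips the declared special tokens and their ids:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("./uie-base-pytorch")  # placeholder path

# Expected ids from added_tokens_decoder: [PAD]=0, [CLS]=1, [SEP]=2, [MASK]=3, [UNK]=39979
tokens = ["[PAD]", "[CLS]", "[SEP]", "[MASK]", "[UNK]"]
print(dict(zip(tokens, tokenizer.convert_tokens_to_ids(tokens))))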