LukeJacob2023 committed on
Commit 9773a47
1 Parent(s): fcb460b

Update cttpunctuator/src/utils/text_post_process.py
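
The only functional change in this commit is that the `typeguard` import and the `check_argument_types()` call are commented out, removing the module's runtime dependency on typeguard; the rest of the file is unchanged.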

cttpunctuator/src/utils/text_post_process.py CHANGED
@@ -1,85 +1,85 @@
- # -*- coding:utf-8 -*-
- # @FileName :text_post_process.py
- # @Time :2023/4/13 15:09
- # @Author :lovemefan
- # @Email :lovemefan@outlook.com
- from pathlib import Path
- from typing import Dict, Iterable, List, Union
-
- import numpy as np
- import yaml
- from typeguard import check_argument_types
-
-
- class TokenIDConverterError(Exception):
-     pass
-
-
- class TokenIDConverter:
-     def __init__(
-         self,
-         token_list: Union[List, str],
-     ):
-         check_argument_types()
-
-         self.token_list = token_list
-         self.unk_symbol = token_list[-1]
-         self.token2id = {v: i for i, v in enumerate(self.token_list)}
-         self.unk_id = self.token2id[self.unk_symbol]
-
-     def get_num_vocabulary_size(self) -> int:
-         return len(self.token_list)
-
-     def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
-         if isinstance(integers, np.ndarray) and integers.ndim != 1:
-             raise TokenIDConverterError(
-                 f"Must be 1 dim ndarray, but got {integers.ndim}"
-             )
-         return [self.token_list[i] for i in integers]
-
-     def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
-         return [self.token2id.get(i, self.unk_id) for i in tokens]
-
-
- def split_to_mini_sentence(words: list, word_limit: int = 20):
-     assert word_limit > 1
-     if len(words) <= word_limit:
-         return [words]
-     sentences = []
-     length = len(words)
-     sentence_len = length // word_limit
-     for i in range(sentence_len):
-         sentences.append(words[i * word_limit : (i + 1) * word_limit])
-     if length % word_limit > 0:
-         sentences.append(words[sentence_len * word_limit :])
-     return sentences
-
-
- def code_mix_split_words(text: str):
-     words = []
-     segs = text.split()
-     for seg in segs:
-         # There is no space in seg.
-         current_word = ""
-         for c in seg:
-             if len(c.encode()) == 1:
-                 # This is an ASCII char.
-                 current_word += c
-             else:
-                 # This is a Chinese char.
-                 if len(current_word) > 0:
-                     words.append(current_word)
-                     current_word = ""
-                 words.append(c)
-         if len(current_word) > 0:
-             words.append(current_word)
-     return words
-
-
- def read_yaml(yaml_path: Union[str, Path]) -> Dict:
-     if not Path(yaml_path).exists():
-         raise FileExistsError(f"The {yaml_path} does not exist.")
-
-     with open(str(yaml_path), "rb") as f:
-         data = yaml.load(f, Loader=yaml.Loader)
-     return data
 
+ # -*- coding:utf-8 -*-
+ # @FileName :text_post_process.py
+ # @Time :2023/4/13 15:09
+ # @Author :lovemefan
+ # @Email :lovemefan@outlook.com
+ from pathlib import Path
+ from typing import Dict, Iterable, List, Union
+
+ import numpy as np
+ import yaml
+ # from typeguard import check_argument_types
+
+
+ class TokenIDConverterError(Exception):
+     pass
+
+
+ class TokenIDConverter:
+     def __init__(
+         self,
+         token_list: Union[List, str],
+     ):
+         # check_argument_types()
+
+         self.token_list = token_list
+         self.unk_symbol = token_list[-1]
+         self.token2id = {v: i for i, v in enumerate(self.token_list)}
+         self.unk_id = self.token2id[self.unk_symbol]
+
+     def get_num_vocabulary_size(self) -> int:
+         return len(self.token_list)
+
+     def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
+         if isinstance(integers, np.ndarray) and integers.ndim != 1:
+             raise TokenIDConverterError(
+                 f"Must be 1 dim ndarray, but got {integers.ndim}"
+             )
+         return [self.token_list[i] for i in integers]
+
+     def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
+         return [self.token2id.get(i, self.unk_id) for i in tokens]
+
+
+ def split_to_mini_sentence(words: list, word_limit: int = 20):
+     assert word_limit > 1
+     if len(words) <= word_limit:
+         return [words]
+     sentences = []
+     length = len(words)
+     sentence_len = length // word_limit
+     for i in range(sentence_len):
+         sentences.append(words[i * word_limit : (i + 1) * word_limit])
+     if length % word_limit > 0:
+         sentences.append(words[sentence_len * word_limit :])
+     return sentences
+
+
+ def code_mix_split_words(text: str):
+     words = []
+     segs = text.split()
+     for seg in segs:
+         # There is no space in seg.
+         current_word = ""
+         for c in seg:
+             if len(c.encode()) == 1:
+                 # This is an ASCII char.
+                 current_word += c
+             else:
+                 # This is a Chinese char.
+                 if len(current_word) > 0:
+                     words.append(current_word)
+                     current_word = ""
+                 words.append(c)
+         if len(current_word) > 0:
+             words.append(current_word)
+     return words
+
+
+ def read_yaml(yaml_path: Union[str, Path]) -> Dict:
+     if not Path(yaml_path).exists():
+         raise FileExistsError(f"The {yaml_path} does not exist.")
+
+     with open(str(yaml_path), "rb") as f:
+         data = yaml.load(f, Loader=yaml.Loader)
+     return data
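
For context, here is a minimal usage sketch of the module as it stands after this commit. The token list, input text, and expected outputs are illustrative placeholders, not values from the repository, and the import path assumes the repository root is on sys.path:

# Usage sketch -- token list and text below are hypothetical examples.
from cttpunctuator.src.utils.text_post_process import (
    TokenIDConverter,
    code_mix_split_words,
    split_to_mini_sentence,
)

# The converter treats the last entry of the token list as the unknown symbol.
converter = TokenIDConverter(["hello", "world", "你", "好", "<unk>"])
print(converter.tokens2ids(["hello", "你", "missing"]))  # [0, 2, 4]; "missing" maps to unk_id
print(converter.ids2tokens([0, 1]))  # ['hello', 'world']

# Mixed Chinese/ASCII text: contiguous ASCII runs stay together as one word,
# while each Chinese (multi-byte) character becomes its own word.
words = code_mix_split_words("hello你好world")
print(words)  # ['hello', '你', '好', 'world']

# Chunk the word list into mini-sentences of at most 3 words each.
print(split_to_mini_sentence(words, word_limit=3))  # [['hello', '你', '好'], ['world']]

One side note on the unchanged code: read_yaml raises FileExistsError when the path does not exist; FileNotFoundError would be the more conventional exception for a missing file, but this commit leaves that behavior untouched.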