jiangchengchengNLP
commited on
Update function.py
Browse files- function.py +97 -97
function.py
CHANGED
@@ -1,98 +1,98 @@
|
|
1 |
-
import torch
|
2 |
-
from transformers import RobertaForTokenClassification, AutoTokenizer
|
3 |
-
model=RobertaForTokenClassification.from_pretrained('jiangchengchengNLP/Chinese_resume_extract')
|
4 |
-
tokenizer = AutoTokenizer.from_pretrained('jiangchengchengNLP/Chinese_resume_extract',do_lower_case=True)
|
5 |
-
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
6 |
-
model.eval()
|
7 |
-
model.to(device)
|
8 |
-
import json
|
9 |
-
label_list={
|
10 |
-
0:'其他',
|
11 |
-
1:'电话',
|
12 |
-
2:'毕业时间', #毕业时间
|
13 |
-
3:'出生日期', #出生日期
|
14 |
-
4:'项目名称', #项目名称
|
15 |
-
5:'毕业院校', #毕业院校
|
16 |
-
6:'职务', #职务
|
17 |
-
7:'籍贯', #籍贯
|
18 |
-
8:'学位', #学位
|
19 |
-
9:'性别', #性别
|
20 |
-
10:'姓名', #姓名
|
21 |
-
11:'工作时间', #工作时间
|
22 |
-
12:'落户市县', #落户市县
|
23 |
-
13:'项目时间', #项目时间
|
24 |
-
14:'最高学历', #最高学历
|
25 |
-
15:'工作单位', #工作单位
|
26 |
-
16:'政治面貌', #政治面貌
|
27 |
-
17:'工作内容', #工作内容
|
28 |
-
18:'项目责任', #项目责任
|
29 |
-
}
|
30 |
-
|
31 |
-
def get_info(text):
|
32 |
-
#文本处理
|
33 |
-
text=text.strip()
|
34 |
-
text=text.replace('\n',',') # 将换行符替换为逗号
|
35 |
-
text=text.replace('\r',',') # 将回车符替换为逗号
|
36 |
-
text=text.replace('\t',',') # 将制表符替换为逗号
|
37 |
-
text=text.replace(' ',',') # 将空格替换为逗号
|
38 |
-
#将连续的逗号合并成一个逗号
|
39 |
-
while ',,' in text:
|
40 |
-
text=text.replace(',,',',')
|
41 |
-
block_list=[]
|
42 |
-
if len(text)>300:
|
43 |
-
#切块策略
|
44 |
-
#先切分成句
|
45 |
-
sentence_list=text.split(',')
|
46 |
-
#然后拼接句子长度不超过300,一旦超过300,当前句子放到下一个块中
|
47 |
-
boundary=300
|
48 |
-
block_list=[]
|
49 |
-
block=sentence_list[0]
|
50 |
-
for i in range(1,len(sentence_list)):
|
51 |
-
if len(block)+len(sentence_list[i])<=boundary:
|
52 |
-
block+=sentence_list[i]
|
53 |
-
else:
|
54 |
-
block_list.append(block)
|
55 |
-
block=sentence_list[i]
|
56 |
-
block_list.append(block)
|
57 |
-
else:
|
58 |
-
block_list.append(text)
|
59 |
-
_input = tokenizer(block_list, return_tensors='pt',padding=True,truncation=True)
|
60 |
-
#如果有GPU,将输入数据移到GPU
|
61 |
-
input_ids = _input['input_ids'].to(device)
|
62 |
-
attention_mask = _input['attention_mask'].to(device)
|
63 |
-
# 模型推理
|
64 |
-
with torch.no_grad():
|
65 |
-
logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]
|
66 |
-
|
67 |
-
# 获取预测的标签ID
|
68 |
-
print(logits.shape)
|
69 |
-
ids = torch.argmax(logits, dim=-1)
|
70 |
-
input_ids=input_ids.reshape(-1)
|
71 |
-
#将张量在最后一个维度拼接,并以0为分界,拼接成句
|
72 |
-
ids =ids.reshape(-1)
|
73 |
-
# 按标签组合成提取内容
|
74 |
-
extracted_info = {}
|
75 |
-
word_list=[]
|
76 |
-
flag=None
|
77 |
-
for idx, label_id in enumerate(ids):
|
78 |
-
label_id = label_id.item()
|
79 |
-
if label_id!= 0 and (flag==None or flag==label_id): #不等于零时
|
80 |
-
if flag==None:
|
81 |
-
flag=label_id
|
82 |
-
label = label_list[label_id] # 获取对应的标签
|
83 |
-
word_list.append(input_ids[idx].item())
|
84 |
-
if label not in extracted_info:
|
85 |
-
extracted_info[label] = []
|
86 |
-
else:
|
87 |
-
if word_list:
|
88 |
-
sentence=''.join(tokenizer.decode(word_list))
|
89 |
-
extracted_info[label].append(sentence)
|
90 |
-
flag=None
|
91 |
-
word_list=[]
|
92 |
-
if label_id!= 0:
|
93 |
-
label = label_list[label_id] # 获取对应的标签
|
94 |
-
word_list.append(input_ids[idx].item())
|
95 |
-
if label not in extracted_info:
|
96 |
-
extracted_info[label] = []
|
97 |
-
# 返回JSON格式的提取内容
|
98 |
return extracted_info
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import RobertaForTokenClassification, AutoTokenizer
|
3 |
+
model=RobertaForTokenClassification.from_pretrained('jiangchengchengNLP/Chinese_resume_extract')
|
4 |
+
tokenizer = AutoTokenizer.from_pretrained('jiangchengchengNLP/Chinese_resume_extract',do_lower_case=True)
|
5 |
+
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
6 |
+
model.eval()
|
7 |
+
model.to(device)
|
8 |
+
import json
|
9 |
+
label_list={
|
10 |
+
0:'其他',
|
11 |
+
1:'电话',
|
12 |
+
2:'毕业时间', #毕业时间
|
13 |
+
3:'出生日期', #出生日期
|
14 |
+
4:'项目名称', #项目名称
|
15 |
+
5:'毕业院校', #毕业院校
|
16 |
+
6:'职务', #职务
|
17 |
+
7:'籍贯', #籍贯
|
18 |
+
8:'学位', #学位
|
19 |
+
9:'性别', #性别
|
20 |
+
10:'姓名', #姓名
|
21 |
+
11:'工作时间', #工作时间
|
22 |
+
12:'落户市县', #落户市县
|
23 |
+
13:'项目时间', #项目时间
|
24 |
+
14:'最高学历', #最高学历
|
25 |
+
15:'工作单位', #工作单位
|
26 |
+
16:'政治面貌', #政治面貌
|
27 |
+
17:'工作内容', #工作内容
|
28 |
+
18:'项目责任', #项目责任
|
29 |
+
}
|
30 |
+
|
31 |
+
def get_info(text):
|
32 |
+
#文本处理
|
33 |
+
text=text.strip()
|
34 |
+
text=text.replace('\n',',') # 将换行符替换为逗号
|
35 |
+
text=text.replace('\r',',') # 将回车符替换为逗号
|
36 |
+
text=text.replace('\t',',') # 将制表符替换为逗号
|
37 |
+
text=text.replace(' ',',') # 将空格替换为逗号
|
38 |
+
#将连续的逗号合并成一个逗号
|
39 |
+
while ',,' in text:
|
40 |
+
text=text.replace(',,',',')
|
41 |
+
block_list=[]
|
42 |
+
if len(text)>300:
|
43 |
+
#切块策略
|
44 |
+
#先切分成句
|
45 |
+
sentence_list=text.split(',')
|
46 |
+
#然后拼接句子长度不超过300,一旦超过300,当前句子放到下一个块中
|
47 |
+
boundary=300
|
48 |
+
block_list=[]
|
49 |
+
block=sentence_list[0]
|
50 |
+
for i in range(1,len(sentence_list)):
|
51 |
+
if len(block)+len(sentence_list[i])<=boundary:
|
52 |
+
block+=sentence_list[i]
|
53 |
+
else:
|
54 |
+
block_list.append(block)
|
55 |
+
block=sentence_list[i]
|
56 |
+
block_list.append(block)
|
57 |
+
else:
|
58 |
+
block_list.append(text)
|
59 |
+
_input = tokenizer(block_list, return_tensors='pt',padding=True,truncation=True)
|
60 |
+
#如果有GPU,将输入数据移到GPU
|
61 |
+
input_ids = _input['input_ids'].to(device)
|
62 |
+
attention_mask = _input['attention_mask'].to(device)
|
63 |
+
# 模型推理
|
64 |
+
with torch.no_grad():
|
65 |
+
logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]
|
66 |
+
|
67 |
+
# 获取预测的标签ID
|
68 |
+
#print(logits.shape)
|
69 |
+
ids = torch.argmax(logits, dim=-1)
|
70 |
+
input_ids=input_ids.reshape(-1)
|
71 |
+
#将张量在最后一个维度拼接,并以0为分界,拼接成句
|
72 |
+
ids =ids.reshape(-1)
|
73 |
+
# 按标签组合成提取内容
|
74 |
+
extracted_info = {}
|
75 |
+
word_list=[]
|
76 |
+
flag=None
|
77 |
+
for idx, label_id in enumerate(ids):
|
78 |
+
label_id = label_id.item()
|
79 |
+
if label_id!= 0 and (flag==None or flag==label_id): #不等于零时
|
80 |
+
if flag==None:
|
81 |
+
flag=label_id
|
82 |
+
label = label_list[label_id] # 获取对应的标签
|
83 |
+
word_list.append(input_ids[idx].item())
|
84 |
+
if label not in extracted_info:
|
85 |
+
extracted_info[label] = []
|
86 |
+
else:
|
87 |
+
if word_list:
|
88 |
+
sentence=''.join(tokenizer.decode(word_list))
|
89 |
+
extracted_info[label].append(sentence)
|
90 |
+
flag=None
|
91 |
+
word_list=[]
|
92 |
+
if label_id!= 0:
|
93 |
+
label = label_list[label_id] # 获取对应的标签
|
94 |
+
word_list.append(input_ids[idx].item())
|
95 |
+
if label not in extracted_info:
|
96 |
+
extracted_info[label] = []
|
97 |
+
# 返回JSON格式的提取内容
|
98 |
return extracted_info
|