jiangchengchengNLP commited on
Commit
3c82d9b
·
verified ·
1 Parent(s): 288b444

Update function.py

Browse files
Files changed (1) hide show
  1. function.py +97 -97
function.py CHANGED
@@ -1,98 +1,98 @@
1
- import torch
2
- from transformers import RobertaForTokenClassification, AutoTokenizer
3
- model=RobertaForTokenClassification.from_pretrained('jiangchengchengNLP/Chinese_resume_extract')
4
- tokenizer = AutoTokenizer.from_pretrained('jiangchengchengNLP/Chinese_resume_extract',do_lower_case=True)
5
- device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6
- model.eval()
7
- model.to(device)
8
- import json
9
- label_list={
10
- 0:'其他',
11
- 1:'电话',
12
- 2:'毕业时间', #毕业时间
13
- 3:'出生日期', #出生日期
14
- 4:'项目名称', #项目名称
15
- 5:'毕业院校', #毕业院校
16
- 6:'职务', #职务
17
- 7:'籍贯', #籍贯
18
- 8:'学位', #学位
19
- 9:'性别', #性别
20
- 10:'姓名', #姓名
21
- 11:'工作时间', #工作时间
22
- 12:'落户市县', #落户市县
23
- 13:'项目时间', #项目时间
24
- 14:'最高学历', #最高学历
25
- 15:'工作单位', #工作单位
26
- 16:'政治面貌', #政治面貌
27
- 17:'工作内容', #工作内容
28
- 18:'项目责任', #项目责任
29
- }
30
-
31
- def get_info(text):
32
- #文本处理
33
- text=text.strip()
34
- text=text.replace('\n',',') # 将换行符替换为逗号
35
- text=text.replace('\r',',') # 将回车符替换为逗号
36
- text=text.replace('\t',',') # 将制表符替换为逗号
37
- text=text.replace(' ',',') # 将空格替换为逗号
38
- #将连续的逗号合并成一个逗号
39
- while ',,' in text:
40
- text=text.replace(',,',',')
41
- block_list=[]
42
- if len(text)>300:
43
- #切块策略
44
- #先切分成句
45
- sentence_list=text.split(',')
46
- #然后拼接句子长度不超过300,一旦超过300,当前句子放到下一个块中
47
- boundary=300
48
- block_list=[]
49
- block=sentence_list[0]
50
- for i in range(1,len(sentence_list)):
51
- if len(block)+len(sentence_list[i])<=boundary:
52
- block+=sentence_list[i]
53
- else:
54
- block_list.append(block)
55
- block=sentence_list[i]
56
- block_list.append(block)
57
- else:
58
- block_list.append(text)
59
- _input = tokenizer(block_list, return_tensors='pt',padding=True,truncation=True)
60
- #如果有GPU,将输入数据移到GPU
61
- input_ids = _input['input_ids'].to(device)
62
- attention_mask = _input['attention_mask'].to(device)
63
- # 模型推理
64
- with torch.no_grad():
65
- logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]
66
-
67
- # 获取预测的标签ID
68
- print(logits.shape)
69
- ids = torch.argmax(logits, dim=-1)
70
- input_ids=input_ids.reshape(-1)
71
- #将张量在最后一个维度拼接,并以0为分界,拼接成句
72
- ids =ids.reshape(-1)
73
- # 按标签组合成提取内容
74
- extracted_info = {}
75
- word_list=[]
76
- flag=None
77
- for idx, label_id in enumerate(ids):
78
- label_id = label_id.item()
79
- if label_id!= 0 and (flag==None or flag==label_id): #不等于零时
80
- if flag==None:
81
- flag=label_id
82
- label = label_list[label_id] # 获取对应的标签
83
- word_list.append(input_ids[idx].item())
84
- if label not in extracted_info:
85
- extracted_info[label] = []
86
- else:
87
- if word_list:
88
- sentence=''.join(tokenizer.decode(word_list))
89
- extracted_info[label].append(sentence)
90
- flag=None
91
- word_list=[]
92
- if label_id!= 0:
93
- label = label_list[label_id] # 获取对应的标签
94
- word_list.append(input_ids[idx].item())
95
- if label not in extracted_info:
96
- extracted_info[label] = []
97
- # 返回JSON格式的提取内容
98
  return extracted_info
 
1
+ import torch
2
+ from transformers import RobertaForTokenClassification, AutoTokenizer
3
+ model=RobertaForTokenClassification.from_pretrained('jiangchengchengNLP/Chinese_resume_extract')
4
+ tokenizer = AutoTokenizer.from_pretrained('jiangchengchengNLP/Chinese_resume_extract',do_lower_case=True)
5
+ device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6
+ model.eval()
7
+ model.to(device)
8
+ import json
9
+ label_list={
10
+ 0:'其他',
11
+ 1:'电话',
12
+ 2:'毕业时间', #毕业时间
13
+ 3:'出生日期', #出生日期
14
+ 4:'项目名称', #项目名称
15
+ 5:'毕业院校', #毕业院校
16
+ 6:'职务', #职务
17
+ 7:'籍贯', #籍贯
18
+ 8:'学位', #学位
19
+ 9:'性别', #性别
20
+ 10:'姓名', #姓名
21
+ 11:'工作时间', #工作时间
22
+ 12:'落户市县', #落户市县
23
+ 13:'项目时间', #项目时间
24
+ 14:'最高学历', #最高学历
25
+ 15:'工作单位', #工作单位
26
+ 16:'政治面貌', #政治面貌
27
+ 17:'工作内容', #工作内容
28
+ 18:'项目责任', #项目责任
29
+ }
30
+
31
+ def get_info(text):
32
+ #文本处理
33
+ text=text.strip()
34
+ text=text.replace('\n',',') # 将换行符替换为逗号
35
+ text=text.replace('\r',',') # 将回车符替换为逗号
36
+ text=text.replace('\t',',') # 将制表符替换为逗号
37
+ text=text.replace(' ',',') # 将空格替换为逗号
38
+ #将连续的逗号合并成一个逗号
39
+ while ',,' in text:
40
+ text=text.replace(',,',',')
41
+ block_list=[]
42
+ if len(text)>300:
43
+ #切块策略
44
+ #先切分成句
45
+ sentence_list=text.split(',')
46
+ #然后拼接句子长度不超过300,一旦超过300,当前句子放到下一个块中
47
+ boundary=300
48
+ block_list=[]
49
+ block=sentence_list[0]
50
+ for i in range(1,len(sentence_list)):
51
+ if len(block)+len(sentence_list[i])<=boundary:
52
+ block+=sentence_list[i]
53
+ else:
54
+ block_list.append(block)
55
+ block=sentence_list[i]
56
+ block_list.append(block)
57
+ else:
58
+ block_list.append(text)
59
+ _input = tokenizer(block_list, return_tensors='pt',padding=True,truncation=True)
60
+ #如果有GPU,将输入数据移到GPU
61
+ input_ids = _input['input_ids'].to(device)
62
+ attention_mask = _input['attention_mask'].to(device)
63
+ # 模型推理
64
+ with torch.no_grad():
65
+ logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]
66
+
67
+ # 获取预测的标签ID
68
+ #print(logits.shape)
69
+ ids = torch.argmax(logits, dim=-1)
70
+ input_ids=input_ids.reshape(-1)
71
+ #将张量在最后一个维度拼接,并以0为分界,拼接成句
72
+ ids =ids.reshape(-1)
73
+ # 按标签组合成提取内容
74
+ extracted_info = {}
75
+ word_list=[]
76
+ flag=None
77
+ for idx, label_id in enumerate(ids):
78
+ label_id = label_id.item()
79
+ if label_id!= 0 and (flag==None or flag==label_id): #不等于零时
80
+ if flag==None:
81
+ flag=label_id
82
+ label = label_list[label_id] # 获取对应的标签
83
+ word_list.append(input_ids[idx].item())
84
+ if label not in extracted_info:
85
+ extracted_info[label] = []
86
+ else:
87
+ if word_list:
88
+ sentence=''.join(tokenizer.decode(word_list))
89
+ extracted_info[label].append(sentence)
90
+ flag=None
91
+ word_list=[]
92
+ if label_id!= 0:
93
+ label = label_list[label_id] # 获取对应的标签
94
+ word_list.append(input_ids[idx].item())
95
+ if label not in extracted_info:
96
+ extracted_info[label] = []
97
+ # 返回JSON格式的提取内容
98
  return extracted_info