import json

import torch
from transformers import RobertaForTokenClassification, AutoTokenizer

# Load the fine-tuned token-classification model and its tokenizer
model = RobertaForTokenClassification.from_pretrained('jiangchengchengNLP/Chinese_resume_extract')
tokenizer = AutoTokenizer.from_pretrained('jiangchengchengNLP/Chinese_resume_extract', do_lower_case=True)

# Run on GPU when available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.eval()
model.to(device)

# Mapping from class id to resume field name (Chinese labels, English in comments)
label_list = {
    0: '其他',       # other
    1: '电话',       # phone number
    2: '毕业时间',   # graduation date
    3: '出生日期',   # date of birth
    4: '项目名称',   # project name
    5: '毕业院校',   # graduation school
    6: '职务',       # job title
    7: '籍贯',       # native place
    8: '学位',       # degree
    9: '性别',       # gender
    10: '姓名',      # name
    11: '工作时间',  # employment period
    12: '落户市县',  # city/county of household registration
    13: '项目时间',  # project period
    14: '最高学历',  # highest education
    15: '工作单位',  # employer
    16: '政治面貌',  # political status
    17: '工作内容',  # job description
    18: '项目责任',  # project responsibility
}


def get_info(text):
    # Normalize whitespace: turn newlines, tabs and spaces into commas,
    # then collapse runs of commas
    text = text.strip()
    text = text.replace('\n', ',')
    text = text.replace('\r', ',')
    text = text.replace('\t', ',')
    text = text.replace(' ', ',')
    while ',,' in text:
        text = text.replace(',,', ',')

    # Split long resumes into blocks of at most 300 characters,
    # cutting only at comma boundaries so fields are not split mid-sentence
    block_list = []
    if len(text) > 300:
        sentence_list = text.split(',')
        boundary = 300
        block = sentence_list[0]
        for i in range(1, len(sentence_list)):
            if len(block) + len(sentence_list[i]) <= boundary:
                block += sentence_list[i]
            else:
                block_list.append(block)
                block = sentence_list[i]
        block_list.append(block)
    else:
        block_list.append(text)

    # Tokenize all blocks as one padded batch and run inference
    _input = tokenizer(block_list, return_tensors='pt', padding=True, truncation=True)
    input_ids = _input['input_ids'].to(device)
    attention_mask = _input['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]

    # Most likely label per token, then flatten the batch dimension
    ids = torch.argmax(logits, dim=-1)
    input_ids = input_ids.reshape(-1)
    ids = ids.reshape(-1)

    # Group consecutive tokens that share the same non-zero label into one entity
    extracted_info = {}
    word_list = []
    flag = None
    for idx, label_id in enumerate(ids):
        label_id = label_id.item()
        if label_id != 0 and (flag is None or flag == label_id):
            # Current entity starts or continues
            if flag is None:
                flag = label_id
            label = label_list[label_id]
            word_list.append(input_ids[idx].item())
            if label not in extracted_info:
                extracted_info[label] = []
        else:
            # Entity ended: decode the collected token ids and store the span
            if word_list:
                sentence = tokenizer.decode(word_list)
                extracted_info[label].append(sentence)
            flag = None
            word_list = []
            if label_id != 0:
                # A different entity begins on this token
                flag = label_id
                label = label_list[label_id]
                word_list.append(input_ids[idx].item())
                if label not in extracted_info:
                    extracted_info[label] = []

    # Flush the final entity if the sequence ended inside one
    if word_list:
        extracted_info[label].append(tokenizer.decode(word_list))

    return extracted_info
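
# Example usage (a minimal sketch: the resume text below is invented for
# illustration and is not taken from the model card)
if __name__ == '__main__':
    resume = "姓名：张三 性别：男 电话：13800000000 毕业院校：北京大学 学位：硕士 工作单位：某科技有限公司"
    info = get_info(resume)
    # json is imported above; ensure_ascii=False keeps the Chinese keys readable
    print(json.dumps(info, ensure_ascii=False, indent=2))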