import json

import torch
from transformers import AutoTokenizer, RobertaForTokenClassification

# Identifier handed to ``from_pretrained`` for both the weights and the
# tokenizer, kept in one place so the two can never drift apart.
MODEL_NAME = 'jiangchengchengNLP/Chinese_resume_extract'

model = RobertaForTokenClassification.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, do_lower_case=True)

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The model is only used for inference, so switch it to evaluation mode
# before moving it to the selected device.
model.eval()
model.to(device)
# Maps the class id predicted for each token to the name of the field it
# belongs to.  Despite the name this is a dict keyed by integer id.
# Id 0 means "outside any field": get_info() never reports it.
label_list = {
    0: '其他',       # other / not part of any field
    1: '电话',       # phone number
    2: '毕业时间',   # graduation date
    3: '出生日期',   # date of birth
    4: '项目名称',   # project name
    5: '毕业院校',   # school graduated from
    6: '职务',       # job title
    7: '籍贯',       # native place
    8: '学位',       # academic degree
    9: '性别',       # gender
    10: '姓名',      # name
    11: '工作时间',  # employment period
    12: '落户市县',  # city/county of household registration
    13: '项目时间',  # project period
    14: '最高学历',  # highest education level
    15: '工作单位',  # employer
    16: '政治面貌',  # political affiliation
    17: '工作内容',  # job description
    18: '项目责任',  # project responsibilities
}
# Label id that marks tokens outside every field (label_list[0] == '其他').
_OTHER_LABEL_ID = 0
# Maximum number of characters packed into a single tokenizer input.
_BLOCK_CHAR_LIMIT = 300


def _normalize_text(text):
    """Replace whitespace with commas and collapse repeated commas."""
    text = text.strip()
    for separator in ('\n', '\r', '\t', ' '):
        text = text.replace(separator, ',')
    while ',,' in text:
        text = text.replace(',,', ',')
    return text


def _split_into_blocks(text, boundary=_BLOCK_CHAR_LIMIT):
    """Pack the comma-separated segments of *text* into blocks of at most
    *boundary* characters, splitting only at commas."""
    if len(text) <= boundary:
        return [text]

    blocks = []
    segments = []
    length = 0
    for segment in text.split(','):
        # One extra character for the comma that re-joins the segments.
        needed = len(segment) + (1 if segments else 0)
        if segments and length + needed > boundary:
            blocks.append(','.join(segments))
            segments = [segment]
            length = len(segment)
        else:
            # A first segment longer than the boundary still becomes a block
            # of its own; the tokenizer truncates it later.
            segments.append(segment)
            length += needed
    # Keep the comma between segments so long texts reach the model in the
    # same form as short texts, which are passed through with their commas.
    blocks.append(','.join(segments))
    return blocks


def _collect_spans(token_ids, label_ids, mask, extracted_info):
    """Group one sequence's tokens into runs of equal label and append each
    decoded run to ``extracted_info`` under its label name."""
    current = _OTHER_LABEL_ID
    span = []
    for token_id, label_id, attended in zip(token_ids, label_ids, mask):
        if not attended:
            # Padding position: the prediction there is meaningless, so it
            # only acts as a span boundary.
            label_id = _OTHER_LABEL_ID
        if label_id != current:
            # Any label change closes the running span, including a direct
            # switch from one field to another.
            if span:
                decoded = tokenizer.decode(span)
                extracted_info.setdefault(label_list[current], []).append(decoded)
                span = []
            current = label_id
        if label_id != _OTHER_LABEL_ID:
            span.append(token_id)
    # A span that reaches the end of the sequence must not be dropped.
    if span:
        decoded = tokenizer.decode(span)
        extracted_info.setdefault(label_list[current], []).append(decoded)


def get_info(text: str) -> dict:
    """Extract labelled fields from a piece of resume text.

    The text is normalised (whitespace becomes commas), split into blocks of
    at most 300 characters, and run through the token-classification model as
    one batch.  Consecutive tokens that share the same non-zero predicted
    label form one span; each span is decoded with the tokenizer.

    Args:
        text: Raw input text.

    Returns:
        A dict mapping a field name (a value of ``label_list``) to the list of
        decoded spans found for that field, in order of appearance.  Fields
        with no span are absent.

    Raises:
        KeyError: If the model predicts a label id missing from ``label_list``.
    """
    blocks = _split_into_blocks(_normalize_text(text))

    encoded = tokenizer(blocks, return_tensors='pt', padding=True, truncation=True)
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]
    predictions = torch.argmax(logits, dim=-1)

    extracted_info = {}
    # Walk the batch row by row so a span can never run across padding or
    # from one block into the next.
    rows = zip(
        encoded['input_ids'].tolist(),
        predictions.tolist(),
        encoded['attention_mask'].tolist(),
    )
    for token_ids, label_ids, mask in rows:
        _collect_spans(token_ids, label_ids, mask, extracted_info)
    return extracted_info