Chinese_resume_extract / function.py
jiangchengchengNLP's picture
Update function.py
3c82d9b verified
import torch
from transformers import RobertaForTokenClassification, AutoTokenizer
model=RobertaForTokenClassification.from_pretrained('jiangchengchengNLP/Chinese_resume_extract')
tokenizer = AutoTokenizer.from_pretrained('jiangchengchengNLP/Chinese_resume_extract',do_lower_case=True)
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.eval()
model.to(device)
import json
label_list={
0:'其他',
1:'电话',
2:'毕业时间', #毕业时间
3:'出生日期', #出生日期
4:'项目名称', #项目名称
5:'毕业院校', #毕业院校
6:'职务', #职务
7:'籍贯', #籍贯
8:'学位', #学位
9:'性别', #性别
10:'姓名', #姓名
11:'工作时间', #工作时间
12:'落户市县', #落户市县
13:'项目时间', #项目时间
14:'最高学历', #最高学历
15:'工作单位', #工作单位
16:'政治面貌', #政治面貌
17:'工作内容', #工作内容
18:'项目责任', #项目责任
}
def get_info(text):
#文本处理
text=text.strip()
text=text.replace('\n',',') # 将换行符替换为逗号
text=text.replace('\r',',') # 将回车符替换为逗号
text=text.replace('\t',',') # 将制表符替换为逗号
text=text.replace(' ',',') # 将空格替换为逗号
#将连续的逗号合并成一个逗号
while ',,' in text:
text=text.replace(',,',',')
block_list=[]
if len(text)>300:
#切块策略
#先切分成句
sentence_list=text.split(',')
#然后拼接句子长度不超过300,一旦超过300,当前句子放到下一个块中
boundary=300
block_list=[]
block=sentence_list[0]
for i in range(1,len(sentence_list)):
if len(block)+len(sentence_list[i])<=boundary:
block+=sentence_list[i]
else:
block_list.append(block)
block=sentence_list[i]
block_list.append(block)
else:
block_list.append(text)
_input = tokenizer(block_list, return_tensors='pt',padding=True,truncation=True)
#如果有GPU,将输入数据移到GPU
input_ids = _input['input_ids'].to(device)
attention_mask = _input['attention_mask'].to(device)
# 模型推理
with torch.no_grad():
logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]
# 获取预测的标签ID
#print(logits.shape)
ids = torch.argmax(logits, dim=-1)
input_ids=input_ids.reshape(-1)
#将张量在最后一个维度拼接,并以0为分界,拼接成句
ids =ids.reshape(-1)
# 按标签组合成提取内容
extracted_info = {}
word_list=[]
flag=None
for idx, label_id in enumerate(ids):
label_id = label_id.item()
if label_id!= 0 and (flag==None or flag==label_id): #不等于零时
if flag==None:
flag=label_id
label = label_list[label_id] # 获取对应的标签
word_list.append(input_ids[idx].item())
if label not in extracted_info:
extracted_info[label] = []
else:
if word_list:
sentence=''.join(tokenizer.decode(word_list))
extracted_info[label].append(sentence)
flag=None
word_list=[]
if label_id!= 0:
label = label_list[label_id] # 获取对应的标签
word_list.append(input_ids[idx].item())
if label not in extracted_info:
extracted_info[label] = []
# 返回JSON格式的提取内容
return extracted_info