|
|
|
""" |
|
Created on Nov 03 20:08:30 2022 |
|
|
|
@author: luol2 |
|
""" |
|
|
|
|
|
import logging
import regex
|
|
|
""" |
|
A Python 3 refactoring of Vincent Van Asch's Python 2 code at |
|
|
|
http://www.cnts.ua.ac.be/~vincent/scripts/abbreviations.py |
|
|
|
Based on |
|
|
|
A Simple Algorithm for Identifying Abbreviations Definitions in Biomedical Text |
|
A. Schwartz and M. Hearst |
|
Biocomputing, 2003, pp 451-462. |
|
|
|
""" |
|
|
|
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO) |
|
log = logging.getLogger('Abbre') |
|
|
|
|
|
class Candidate(str):
    """A string that also records its character offsets within the sentence."""

    def __init__(self, value):
        super().__init__()
        self.start = 0
        self.stop = 0

    def set_position(self, start, stop):
        self.start = start
        self.stop = stop
|
|
|
|
|
def yield_lines_from_file(file_path):
    with open(file_path, 'rb') as f:
        for line in f:
            try:
                line = line.decode('utf-8')
            except UnicodeDecodeError:
                # Fall back to latin-1 for lines that are not valid UTF-8
                line = line.decode('latin-1')
            yield line.strip()
|
|
|
|
|
def yield_lines_from_doc(doc_text): |
|
for line in doc_text.split("\n"): |
|
yield line.strip() |
|
|
|
|
|
def best_candidates(sentence):
    """
    :param sentence: line read from input file
    :return: a Candidate iterator
    """

    if '(' in sentence:
        # Sanity checks on the parentheses
        if sentence.count('(') != sentence.count(')'):
            raise ValueError("Unbalanced parentheses: {}".format(sentence))

        if sentence.find('(') > sentence.find(')'):
            raise ValueError("First parenthesis is a closing one: {}".format(sentence))

        closeindex = -1
        while True:
            # Look for the next opening parenthesis
            openindex = sentence.find('(', closeindex + 1)
            if openindex == -1:
                break

            # Scan for the matching closing parenthesis (or ';'/':'), allowing nesting
            closeindex = openindex + 1
            open_count = 1
            skip = False
            while open_count:
                try:
                    char = sentence[closeindex]
                except IndexError:
                    # An opening bracket without a matching closing bracket: skip it
                    skip = True
                    break
                if char == '(':
                    open_count += 1
                elif char in [')', ';', ':']:
                    open_count -= 1
                closeindex += 1

            if skip:
                closeindex = openindex + 1
                continue

            # The candidate is the text between the parentheses
            start = openindex + 1
            stop = closeindex - 1
            candidate = sentence[start:stop]

            # Strip leading/trailing whitespace and adjust the offsets accordingly
            start = start + len(candidate) - len(candidate.lstrip())
            stop = stop - len(candidate) + len(candidate.rstrip())
            candidate = sentence[start:stop]

            if conditions(candidate):
                new_candidate = Candidate(candidate)
                new_candidate.set_position(start, stop)
                yield new_candidate
|
|
|
|
|
def conditions(candidate):
    r"""
    Based on Schwartz & Hearst:

    2 <= len(str) <= 10
    len(tokens) <= 2
    regex.search(r'\p{L}', str)
    str[0].isalnum()

    and extra:
    if it matches r'(\p{L}\.?\s?){2,}', it is a good candidate.

    :param candidate: candidate abbreviation
    :return: True if this is a good candidate
    """
    viable = True
    # NOTE: as in the original code, this match does not short-circuit;
    # the checks below can still reject a matching candidate.
    if regex.match(r'(\p{L}\.?\s?){2,}', candidate.lstrip()):
        viable = True
    if len(candidate) < 2 or len(candidate) > 10:
        viable = False
    if len(candidate.split()) > 2:
        viable = False
    if not regex.search(r'\p{L}', candidate):
        viable = False
    if not candidate[0].isalnum():
        viable = False

    return viable
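# A few hedged examples of the filter above (checked by hand, not a test suite):
#   conditions('HSP')                 -> True   (2-10 chars, <=2 tokens, has a letter)
#   conditions('heat shock protein')  -> False  (longer than 10 characters)
#   conditions('123')                 -> False  (contains no letter)
#   conditions('-ab')                 -> False  (first character is not alphanumeric)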
|
|
|
|
|
def get_definition(candidate, sentence):
    """
    Takes a candidate and a sentence and returns the definition candidate.

    The definition candidate is the set of tokens (in front of the candidate)
    that starts with a token starting with the first character of the candidate.

    :param candidate: candidate abbreviation
    :param sentence: current sentence (single line from input file)
    :return: candidate definition for this abbreviation
    """
    # Take the tokens in front of the candidate
    tokens = regex.split(r'[\s\-]+', sentence[:candidate.start - 2].lower())

    # The abbreviation's first character
    key = candidate[0].lower()

    # The first character of every token in front of the candidate
    firstchars = [t[0] for t in tokens]

    # How often the key appears in front of the candidate vs. inside it
    definition_freq = firstchars.count(key)
    candidate_freq = candidate.lower().count(key)

    # Search backwards for the token that starts the definition
    if candidate_freq <= definition_freq:
        count = 0
        start = 0
        startindex = len(firstchars) - 1

        while count < candidate_freq:
            if abs(start) > len(firstchars):
                raise ValueError("candidate {} not found".format(candidate))
            start -= 1
            # Look up the key among the token-initial characters
            try:
                startindex = firstchars.index(key, len(firstchars) + start)
            except ValueError:
                pass

            # Count the occurrences of the key between startindex and the candidate
            count = firstchars[startindex:].count(key)

        # Enough keys found: take the tokens from startindex up to the candidate
        start = len(' '.join(tokens[:startindex]))
        stop = candidate.start - 1
        candidate = sentence[start:stop]

        # Strip leading/trailing whitespace and adjust the offsets accordingly
        start = start + len(candidate) - len(candidate.lstrip())
        stop = stop - len(candidate) + len(candidate.rstrip())
        candidate = sentence[start:stop]

        new_candidate = Candidate(candidate)
        new_candidate.set_position(start, stop)

        return new_candidate

    else:
        raise ValueError(
            'There are fewer keys in the tokens in front of the candidate than in the candidate itself')
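# Worked example for get_definition (hedged, offsets checked by hand): for
#   sentence  = "The heat shock protein (HSP) family is conserved."
#   candidate = 'HSP' at positions (24, 27):
#     tokens     = ['the', 'heat', 'shock', 'protein']
#     firstchars = ['t', 'h', 's', 'p'], key = 'h'
# the backwards search stops at the token 'heat' and returns the definition
# candidate 'heat shock protein' at positions (4, 22).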
|
|
|
|
|
def select_definition(definition, abbrev):
    """
    Takes a definition candidate and an abbreviation candidate and, if the
    characters of the abbreviation occur in the definition, returns the
    selected definition with its offsets.

    Based on
    A Simple Algorithm for Identifying Abbreviation Definitions in Biomedical Text, Schwartz & Hearst

    :param definition: candidate definition
    :param abbrev: candidate abbreviation
    :return: dict with keys 'definition', 'start' and 'stop'
    """

    if len(definition) < len(abbrev):
        raise ValueError('Abbreviation is longer than definition')

    if abbrev in definition.split():
        raise ValueError('Abbreviation is full word of definition')

    sindex = -1
    lindex = -1

    # Match the abbreviation characters against the definition, right to left
    while True:
        longchar = definition[lindex].lower()  # IndexError propagates to the caller
        shortchar = abbrev[sindex].lower()

        # Skip non-alphanumeric characters of the abbreviation
        if not shortchar.isalnum():
            sindex -= 1

        if sindex == -1 * len(abbrev):
            # The first character of the abbreviation must start a definition token
            if shortchar == longchar:
                if lindex == -1 * len(definition) or not definition[lindex - 1].isalnum():
                    break
                else:
                    lindex -= 1
            else:
                lindex -= 1
                if lindex == -1 * (len(definition) + 1):
                    raise ValueError("definition {} was not found in {}".format(abbrev, definition))

        else:
            if shortchar == longchar:
                sindex -= 1
                lindex -= 1
            else:
                lindex -= 1

    # Truncate the definition to the matched span and adjust its offsets
    new_candidate = Candidate(definition[lindex:len(definition)])
    new_candidate.set_position(definition.start + lindex + len(definition), definition.stop)
    definition = new_candidate

    tokens = len(definition.split())
    length = len(abbrev)

    if tokens > min([length + 5, length * 2]):
        raise ValueError("did not meet min(|A|+5, |A|*2) constraint")

    # Do not return definitions that contain unbalanced parentheses
    if definition.count('(') != definition.count(')'):
        raise ValueError("Unbalanced parentheses not allowed in a definition")

    new_definition_dict = {'definition': definition, 'start': definition.start, 'stop': definition.stop}
    return new_definition_dict
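# Continuing the example above (hedged): select_definition matches the
# abbreviation right-to-left against the definition candidate:
#   'P' -> 'p' of 'protein', 'S' -> 's' of 'shock', 'H' -> 'h' of 'heat'
# and returns {'definition': 'heat shock protein', 'start': 4, 'stop': 22}.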
|
|
|
|
|
def extract_abbreviation_definition_pairs(file_path=None, doc_text=None):
    abbrev_map = []       # list of {'definition', 'start', 'stop', 'abbre'} dicts
    abbr_full_dict = {}   # abbreviation -> (definition start, definition stop)
    fullloc_abbr_dict = {}  # "start stop" of the definition -> abbreviation
    omit = 0
    written = 0
    if file_path:
        sentence_iterator = enumerate(yield_lines_from_file(file_path))
    elif doc_text:
        sentence_iterator = enumerate(yield_lines_from_doc(doc_text))
    else:
        # No input: return the same (empty) triple as the normal path
        return abbrev_map, abbr_full_dict, fullloc_abbr_dict

    for i, sentence in sentence_iterator:
        try:
            for candidate in best_candidates(sentence):
                try:
                    definition = get_definition(candidate, sentence)
                except (ValueError, IndexError) as e:
                    omit += 1
                else:
                    try:
                        definition_dict = select_definition(definition, candidate)
                    except (ValueError, IndexError) as e:
                        omit += 1
                    else:
                        definition_dict['abbre'] = candidate
                        abbrev_map.append(definition_dict)
                        abbr_full_dict[definition_dict['abbre']] = (definition_dict['start'], definition_dict['stop'])
                        fullloc_abbr_dict[str(definition_dict['start']) + ' ' + str(definition_dict['stop'])] = definition_dict['abbre']
                        written += 1
        except (ValueError, IndexError) as e:
            log.debug("{} Error processing sentence {}: {}".format(i, sentence, e.args[0]))
    log.debug("{} abbreviations detected and kept ({} omitted)".format(written, omit))
    return abbrev_map, abbr_full_dict, fullloc_abbr_dict
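# Minimal usage sketch (hedged; the sentence is illustrative):
#   pairs, abbr_full, fullloc_abbr = extract_abbreviation_definition_pairs(
#       doc_text="The heat shock protein (HSP) family is conserved.")
#   pairs        -> [{'definition': 'heat shock protein', 'start': 4,
#                     'stop': 22, 'abbre': 'HSP'}]
#   abbr_full    -> {'HSP': (4, 22)}
#   fullloc_abbr -> {'4 22': 'HSP'}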
|
|
|
|
|
def postprocess_abbr(ner_result, ori_text):
    """
    Make abbreviations and their full names consistent in the NER results.

    :param ner_result: dict mapping "start end" -> [start, end, mention text, ..., entity type]
    :param ori_text: original document text (title + abstract)
    :return: list of [start, end, mention text, entity type] entries
    """
    final_result = []
    if len(ner_result) == 0:
        return []

    # Find abbreviation/full-name pairs in the text
    abbr_list, abbr_full_dict, fullloc_abbr_dict = extract_abbreviation_definition_pairs(doc_text=ori_text)

    # Index the NER results by the end offset of each entity
    ner_loc_result = {}
    for ele in ner_result.keys():
        ner_loc_result[ner_result[ele][1]] = ner_result[ele]

    for entity_loc in ner_result.keys():
        # Case 1: the entity is an abbreviation whose full name was also tagged;
        # it inherits the entity type of the full name (CellLine excluded)
        if (ner_result[entity_loc][-1] != 'CellLine') and (ner_result[entity_loc][2] in abbr_full_dict.keys()):
            fullname_loc_e = str(abbr_full_dict[ner_result[entity_loc][2]][1])
            if fullname_loc_e in ner_loc_result.keys():
                final_result.append([ner_result[entity_loc][0], ner_result[entity_loc][1], ner_result[entity_loc][2], ner_loc_result[fullname_loc_e][-1]])

        # Case 2: the entity is a full name with a known abbreviation;
        # keep it and also tag the abbreviation if it was not tagged already
        elif entity_loc in fullloc_abbr_dict.keys():
            abbr_loc_s = ori_text.find(fullloc_abbr_dict[entity_loc], int(ner_result[entity_loc][1]))
            final_result.append(ner_result[entity_loc])
            if abbr_loc_s >= 0:
                abbr_loc_e = abbr_loc_s + len(fullloc_abbr_dict[entity_loc])
                abbr_loc = str(abbr_loc_s) + ' ' + str(abbr_loc_e)
                if abbr_loc not in ner_result.keys():
                    final_result.append([str(abbr_loc_s), str(abbr_loc_e), ori_text[abbr_loc_s:abbr_loc_e], ner_result[entity_loc][-1]])

        # Case 3: everything else is kept, except single-character
        # non-alphabetic mentions
        else:
            if len(ner_result[entity_loc][2]) == 1 and (not ner_result[entity_loc][2].isalpha()):
                pass
            else:
                final_result.append(ner_result[entity_loc])

    return final_result
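# Hedged sketch of the input/output of postprocess_abbr above:
#   ori_text   = "The heat shock protein (HSP) family is conserved."
#   ner_result = {'4 22': ['4', '22', 'heat shock protein', 'Gene'],
#                 '24 27': ['24', '27', 'HSP', 'Gene']}
# '4 22' is the tagged full name of the abbreviation 'HSP', so both entries
# are kept and the abbreviation inherits the full name's entity type.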
|
|
|
|
|
def entity_consistency(ner_result, ori_text):
    """
    Tag untagged occurrences of already-recognized mentions with the
    mention's majority entity type, then resolve overlapping mentions.

    :param ner_result: list of [start, end, mention text, ..., entity type] entries
    :param ori_text: original document text
    :return: the consistent entity list
    """
    final_result = {}
    entity_loc_set = set()
    entity_type = {}  # mention text -> {entity type: count}

    for segs in ner_result:
        entity_loc_set.add(segs[0] + ' ' + segs[1])
        final_result['\t'.join(segs)] = [int(segs[0]), int(segs[1])]
        if len(segs[2]) > 1:
            # All-uppercase mentions are kept case-sensitive; others are lowercased
            mention = segs[2] if segs[2].isupper() else segs[2].lower()
            if mention not in entity_type.keys():
                entity_type[mention] = {segs[-1]: 1}
            elif segs[-1] in entity_type[mention]:
                entity_type[mention][segs[-1]] += 1
            else:
                entity_type[mention][segs[-1]] = 1

    # For each mention, keep its most frequent entity type
    entity_type_major = {}
    for ele in entity_type.keys():
        entity_type_major[ele] = max(zip(entity_type[ele].values(), entity_type[ele].keys()))[1]

    # Search the text for untagged occurrences of each mention,
    # accepting only matches at word boundaries
    for entity_text in entity_type_major.keys():
        if entity_text.isupper():
            new_text = ori_text
        else:
            new_text = ori_text.lower()
        ent_eid = 0
        while new_text.find(entity_text, ent_eid) >= 0:
            ent_sid = new_text.find(entity_text, ent_eid)
            ent_eid = ent_sid + len(entity_text)
            entity_loc = str(ent_sid) + ' ' + str(ent_eid)

            if entity_loc not in entity_loc_set:
                if ent_sid > 0 and ent_eid < len(new_text):
                    if not new_text[ent_sid - 1].isalnum() and not new_text[ent_eid].isalnum():
                        final_result[str(ent_sid) + '\t' + str(ent_eid) + '\t' + ori_text[ent_sid:ent_eid] + '\t' + entity_type_major[entity_text]] = [ent_sid, ent_eid]
                        entity_loc_set.add(entity_loc)
                elif ent_sid == 0 and ent_eid < len(new_text):
                    if not new_text[ent_eid].isalnum():
                        final_result[str(ent_sid) + '\t' + str(ent_eid) + '\t' + ori_text[ent_sid:ent_eid] + '\t' + entity_type_major[entity_text]] = [ent_sid, ent_eid]
                        entity_loc_set.add(entity_loc)
                elif ent_sid > 0 and ent_eid == len(new_text):
                    if not new_text[ent_sid - 1].isalnum():
                        final_result[str(ent_sid) + '\t' + str(ent_eid) + '\t' + ori_text[ent_sid:ent_eid] + '\t' + entity_type_major[entity_text]] = [ent_sid, ent_eid]
                        entity_loc_set.add(entity_loc)

    # If new mentions were added, sort by offsets and resolve overlaps
    if len(final_result) != len(ner_result):
        final_result = sorted(final_result.items(), key=lambda kv: (kv[1]), reverse=False)
        mention_list = []
        for ele in final_result:
            mention_list.append(ele[0].split('\t'))
        final_ner_result = combine_overlap(mention_list)
    else:
        final_ner_result = ner_result
    return final_ner_result
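# Hedged example: if "p53" were tagged as Gene twice and as Disease once,
# the majority type Gene wins, and any untagged word-boundary occurrence of
# "p53" elsewhere in the text is added as Gene before overlaps are resolved.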
|
|
|
def combine_overlap(mention_list):
    """
    Resolve overlapping mentions: within each group of overlapping or nested
    mentions, keep only the longest one. As in the original code, lists with
    fewer than three mentions are returned unchanged.

    :param mention_list: [start, end, ...] entries sorted by offsets
    :return: the filtered entity list
    """
    entity_list = []
    if len(mention_list) > 2:
        first_entity = mention_list[0]
        nest_list = [first_entity]
        max_eid = int(first_entity[1])
        for i in range(1, len(mention_list)):
            segs = mention_list[i]
            if int(segs[0]) >= max_eid:
                # No overlap with the current group: flush it and start a new one
                if len(nest_list) == 1:
                    entity_list.append(nest_list[0])
                else:
                    entity_list.append(find_max_entity(nest_list))
                nest_list = [segs]
            else:
                # Overlaps the current group
                nest_list.append(segs)
            if int(segs[1]) > max_eid:
                max_eid = int(segs[1])
        # Flush the last group
        if nest_list != []:
            if len(nest_list) == 1:
                entity_list.append(nest_list[0])
            else:
                entity_list.append(find_max_entity(nest_list))
    else:
        entity_list = mention_list

    return entity_list
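# Hedged example: given sorted mentions
#   [['0', '4', ...], ['2', '10', ...], ['12', '15', ...]]
# the first two overlap, so only the longer ['2', '10', ...] is kept,
# followed by the non-overlapping ['12', '15', ...].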
|
|
|
def find_max_entity(nest_list):
    """Return the longest mention (by character span) in a nested group."""
    max_len = 0
    max_entity = []
    for i in range(0, len(nest_list)):
        length = int(nest_list[i][1]) - int(nest_list[i][0])
        if length > max_len:
            max_len = length
            max_entity = nest_list[i]

    return max_entity
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':

    # Post-process NER predictions in PubTator format: a title line
    # "pmid|t|...", an abstract line "pmid|a|...", then one tab-separated
    # annotation line per entity, with documents separated by blank lines.
    path = '//panfs/pan1/bionlplab/luol2/PubTator3/example/post-out/'
    fin = open(path + 'PubmedBERT-CRF-AIO_ALL.test_preds', 'r', encoding='utf-8')
    all_in = fin.read().strip().split('\n\n')
    fin.close()
    fout = open(path + 'PubmedBERT-CRF-AIO_ALL-post4.test_preds', 'w', encoding='utf-8')
    for doc in all_in:
        lines = doc.split('\n')
        pmid = lines[0].split('|t|')[0]
        ori_text = lines[0].split('|t|')[1] + ' ' + lines[1].split('|a|')[1]
        ner_result = {}
        for i in range(2, len(lines)):
            seg = lines[i].split('\t')
            ner_result[seg[1] + ' ' + seg[2]] = seg[1:]

        # Abbreviation-based type propagation, then mention-level consistency
        final_ner = postprocess_abbr(ner_result, ori_text)
        final_ner = entity_consistency(final_ner, ori_text)

        fout.write(lines[0] + '\n' + lines[1] + '\n')
        for ele in final_ner:
            fout.write(pmid + '\t' + '\t'.join(ele) + '\n')
        fout.write('\n')
    fout.close()
|
|
|
|
|
|
|
|
|
|