WORD_POS = 1
TAG_POS = 2
MASK_TAG = "__entity__"
INPUT_MASK_TAG = ":__entity__"
RESET_POS_TAG = 'RESET'
noun_tags = ['NFP', 'JJ', 'NN', 'FW', 'NNS', 'NNPS', 'JJS', 'JJR', 'NNP', 'POS', 'CD']
cap_tags = ['NFP', 'JJ', 'NN', 'FW', 'NNS', 'NNPS', 'JJS', 'JJR', 'NNP', 'PRP']


def detect_masked_positions(terms_arr):
    '''
    Return the original words, the masked sentence copies, and the span array
    for a POS-tagged terms array.
    '''
    sentence_arr, span_arr = generate_masked_sentences(terms_arr)
    new_sent_arr = []
    for i in range(len(terms_arr)):
        new_sent_arr.append(terms_arr[i][WORD_POS])
    return new_sent_arr, sentence_arr, span_arr
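
# Illustrative sketch (sample data assumed; the 5-column term layout matches
# set_POS_based_on_entities below, and only WORD_POS/TAG_POS are read here):
#   terms = [['-', 'New', 'NNP', '-', '-'], ['-', 'York', 'NNP', '-', '-'],
#            ['-', 'is', 'RESET', '-', '-'], ['-', 'big', 'RESET', '-', '-']]
#   words, masked, spans = detect_masked_positions(terms)
#   # words  -> ['New', 'York', 'is', 'big']
#   # masked -> [['__entity__', 'is', 'big']]   (the "New York" span collapses to one mask)
#   # spans  -> [1, 1, 0, 0]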

def generate_masked_sentences(terms_arr):
    '''
    Generate one masked copy of the sentence for each contiguous run of
    noun-tagged terms. Returns the masked sentences and a span array holding
    1 for positions inside a noun span and 0 elsewhere.
    '''
    size = len(terms_arr)
    sentence_arr = []
    span_arr = []
    i = 0
    hack_for_no_nouns_case(terms_arr)
    while (i < size):
        term_info = terms_arr[i]
        if (term_info[TAG_POS] in noun_tags):
            skip = gen_sentence(sentence_arr, terms_arr, i)
            i += skip
            for j in range(skip):
                span_arr.append(1)
        else:
            i += 1
            span_arr.append(0)
    return sentence_arr, span_arr
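
# Illustrative sketch (sample tags assumed; format as above): two separate
# noun spans yield two masked sentences, one per span:
#   terms = [['-', 'John', 'NNP', '-', '-'], ['-', 'flew', 'RESET', '-', '-'],
#            ['-', 'to', 'RESET', '-', '-'], ['-', 'Paris', 'NNP', '-', '-']]
#   generate_masked_sentences(terms)
#   # -> ([['__entity__', 'flew', 'to', 'Paris'], ['John', 'flew', 'to', '__entity__']],
#   #     [1, 0, 0, 1])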

def hack_for_no_nouns_case(terms_arr):
    '''
    This is just a hack for the case where the user enters a sentence with no
    entity marked for tagging and the sentence has no nouns at all.
    Happens for odd inputs such as a single word like "eg".
    Just tag the first term as a noun so processing can proceed.
    '''
    size = len(terms_arr)
    i = 0
    found = False
    while (i < size):
        term_info = terms_arr[i]
        if (term_info[TAG_POS] in noun_tags):
            found = True
            break
        else:
            i += 1
    if (not found and len(terms_arr) >= 1):
        term_info = terms_arr[0]
        term_info[TAG_POS] = noun_tags[0]
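
# Illustrative sketch: a noun-free input gets its first term coerced to a noun
# tag so at least one masked sentence is generated downstream:
#   terms = [['-', 'eg', 'RESET', '-', '-']]
#   hack_for_no_nouns_case(terms)
#   # terms -> [['-', 'eg', 'NFP', '-', '-']]   (noun_tags[0] == 'NFP')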

def gen_sentence(sentence_arr, terms_arr, index):
    '''
    Append a copy of the sentence with the noun span starting at index
    collapsed into a single MASK_TAG token. Returns the span length
    (number of terms masked).
    '''
    size = len(terms_arr)
    new_sent = []
    for term in terms_arr[:index]:
        new_sent.append(term[WORD_POS])
    i = index
    skip = 0
    while (i < size):
        if (terms_arr[i][TAG_POS] in noun_tags):
            skip += 1
            i += 1
        else:
            break
    new_sent.append(MASK_TAG)
    i = index + skip
    while (i < size):
        new_sent.append(terms_arr[i][WORD_POS])
        i += 1
    assert skip != 0
    sentence_arr.append(new_sent)
    return skip
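
# Illustrative sketch: masking the two-term span starting at index 0 in place:
#   sentences = []
#   terms = [['-', 'New', 'NNP', '-', '-'], ['-', 'York', 'NNP', '-', '-'],
#            ['-', 'is', 'RESET', '-', '-']]
#   gen_sentence(sentences, terms, 0)
#   # -> 2 (span length); sentences -> [['__entity__', 'is']]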

def capitalize(terms_arr):
    '''Uppercase the first letter of every term whose POS tag is in cap_tags (in place).'''
    for term_tag in terms_arr:
        if (term_tag[TAG_POS] in cap_tags):
            word = term_tag[WORD_POS][0].upper() + term_tag[WORD_POS][1:]
            term_tag[WORD_POS] = word
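
# Illustrative sketch: only terms with tags in cap_tags are capitalized, in place:
#   terms = [['-', 'paris', 'NNP', '-', '-'], ['-', 'is', 'RESET', '-', '-']]
#   capitalize(terms)
#   # terms -> [['-', 'Paris', 'NNP', '-', '-'], ['-', 'is', 'RESET', '-', '-']]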

def set_POS_based_on_entities(sent):
    '''
    Build a terms array directly from user input: words carrying the
    INPUT_MASK_TAG suffix become nouns (entities to tag); every other word
    gets RESET_POS_TAG.
    '''
    terms_arr = []
    sent_arr = sent.split()
    for word in sent_arr:
        term_tag = ['-'] * 5
        if (word.endswith(INPUT_MASK_TAG)):
            term_tag[TAG_POS] = noun_tags[0]
            term_tag[WORD_POS] = word.replace(INPUT_MASK_TAG, "")
        else:
            term_tag[TAG_POS] = RESET_POS_TAG
            term_tag[WORD_POS] = word
        terms_arr.append(term_tag)
    return terms_arr
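
# Illustrative sketch: the ":__entity__" suffix marks which words the user
# wants tagged as entities:
#   set_POS_based_on_entities("John:__entity__ flew to Paris")
#   # -> [['-', 'John', 'NFP', '-', '-'], ['-', 'flew', 'RESET', '-', '-'],
#   #     ['-', 'to', 'RESET', '-', '-'], ['-', 'Paris', 'RESET', '-', '-']]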

def filter_common_noun_spans(span_arr, masked_sent_arr, terms_arr, common_descs):
    '''
    Drop noun spans made up entirely of common descriptor words: zero out
    their span entries and skip the corresponding masked sentence.
    '''
    ret_span_arr = span_arr.copy()
    ret_masked_sent_arr = []
    sent_index = 0
    loop_span_index = 0
    while (loop_span_index < len(span_arr)):
        span_val = span_arr[loop_span_index]
        orig_index = loop_span_index
        if (span_val == 1):
            curr_index = orig_index
            is_all_common = True
            while (curr_index < len(span_arr) and span_arr[curr_index] == 1):
                term = terms_arr[curr_index]
                if (term[WORD_POS].lower() not in common_descs):
                    is_all_common = False
                curr_index += 1
            loop_span_index = curr_index  # note the loop scan index is advanced past the span
            if (is_all_common):
                curr_index = orig_index
                print("Filtering common span: ", end='')
                while (curr_index < len(span_arr) and span_arr[curr_index] == 1):
                    print(terms_arr[curr_index][WORD_POS], ' ', end='')
                    ret_span_arr[curr_index] = 0
                    curr_index += 1
                print()
                sent_index += 1  # we are skipping a span
            else:
                ret_masked_sent_arr.append(masked_sent_arr[sent_index])
                sent_index += 1
        else:
            loop_span_index += 1
    return ret_masked_sent_arr, ret_span_arr
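
# Illustrative sketch (common_descs is assumed to be a set of lowercase
# descriptor words): the all-common "cells" span is filtered out:
#   spans = [1, 0, 1]
#   masked = [['__entity__', 'has', 'cells'], ['Tissue', 'has', '__entity__']]
#   terms = [['-', 'Tissue', 'NN', '-', '-'], ['-', 'has', 'RESET', '-', '-'],
#            ['-', 'cells', 'NNS', '-', '-']]
#   filter_common_noun_spans(spans, masked, terms, {'cells'})
#   # prints "Filtering common span: cells" and returns
#   # ([['__entity__', 'has', 'cells']], [1, 0, 0])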

def normalize_casing(sent):
    '''Lowercase everything after the first character of each word, keeping the first character as typed.'''
    sent_arr = sent.split()
    ret_sent_arr = []
    for word in sent_arr:
        if (len(word) > 1):
            norm_word = word[0] + word[1:].lower()
        else:
            norm_word = word[0]
        ret_sent_arr.append(norm_word)
    return ' '.join(ret_sent_arr)
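

if __name__ == "__main__":
    # Minimal end-to-end sketch (sample input assumed): parse user-marked
    # entities, generate masked sentences, then normalize casing of raw text.
    terms = set_POS_based_on_entities("John:__entity__ flew to Paris:__entity__")
    words, masked, spans = detect_masked_positions(terms)
    print(words)   # ['John', 'flew', 'to', 'Paris']
    print(masked)  # [['__entity__', 'flew', 'to', 'Paris'], ['John', 'flew', 'to', '__entity__']]
    print(spans)   # [1, 0, 0, 1]
    print(normalize_casing("JOHN FLEW TO PARIS"))  # 'John Flew To Paris'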