import urllib.request, urllib.error, urllib.parse
import json
import pandas as pd
import ssl
import torch
import re
from pprint import pprint
from captum.attr import visualization
# Base URL of the REST service queried by get_json() (its "/annotator" endpoint).
REST_URL = "http://data.bioontology.org"
# Token sent by get_json() in the "Authorization: apikey token=..." header.
# NOTE(review): this credential is committed in source; consider loading it
# from configuration or the environment instead.
API_KEY = "604a90bc-ef14-4c26-a347-f4928fa086ea"
# Process-wide override: default HTTPS contexts created by the standard library
# will no longer verify certificates.
# NOTE(review): this disables TLS verification for every HTTPS request made
# through the default context in this process — confirm that is intended.
ssl._create_default_https_context = ssl._create_unverified_context
class PyTMinMaxScalerVectorized(object):
    """Min-max scale a tensor in place along dim 0 so its values lie in [0, 1].

    From https://discuss.pytorch.org/t/using-scikit-learns-scalers-for-torchvision/53455
    """

    def __call__(self, tensor):
        """Rescale ``tensor`` in place and return the same tensor object.

        The minimum and maximum are taken along ``dim=0`` with
        ``keepdim=True``, so each slice along the remaining dimensions is
        scaled independently.

        Args:
            tensor: tensor to normalise; it is modified in place. It is
                multiplied in place by a float scale, so a floating-point
                dtype is expected.

        Returns:
            The input ``tensor`` (same object), rescaled.
        """
        lowest = tensor.min(dim=0, keepdim=True)[0]
        span = tensor.max(dim=0, keepdim=True)[0] - lowest
        # A constant slice has zero range; dividing by it would turn the
        # output into NaN/inf. Use a scale of 1 there so the slice maps to 0.
        span = span.masked_fill(span == 0, 1.0)
        tensor.mul_(1.0 / span)
        # The minimum is re-read after scaling so the shift matches the
        # already-scaled values.
        tensor.sub_(tensor.min(dim=0, keepdim=True)[0])
        return tensor
# Markers that signal the end of the report body; compiled once at import
# time instead of on every call.
_END_PATTERNS = (
    re.compile(r'\n {3,}DR.', re.I),
    re.compile(r'[ ]{1,}RADLINE ', re.I),
    re.compile(r'.*electronically signed on', re.I),
)


def find_end(text):
    """Find the end of the report.

    Each end-marker pattern is searched in ``text``; the start offset of
    every pattern's first match is a candidate end position.

    Args:
        text: report text to scan.

    Returns:
        The smallest candidate offset, or ``len(text)`` when no pattern
        matches.
    """
    ends = [len(text)]
    for pattern in _END_PATTERNS:
        matchobj = pattern.search(text)
        if matchobj:
            ends.append(matchobj.start())
    return min(ends)
def pattern_repl(matchobj):
    """Return a replacement string to be used for match object.

    The replacement consists only of spaces and is as long as the matched
    text, so substituting it keeps every other character at its original
    offset. (For a zero-length match a single space is still returned.)

    Args:
        matchobj: a regular-expression match object.

    Returns:
        A string of spaces with the length of ``matchobj.group(0)``.
    """
    return ' '.rjust(len(matchobj.group(0)))
def clean_text(text):
    """Clean text while keeping its length unchanged.

    ``[**...**]`` placeholders and underscores are replaced by spaces, and
    everything from the end position reported by ``find_end`` onwards is
    blanked out with spaces.

    Args:
        text: raw report text.

    Returns:
        The cleaned text, with the same number of characters as ``text``.
    """
    # Replace [**Patterns**] with spaces.
    text = re.sub(r'\[\*\*.*?\*\*\]', pattern_repl, text)
    # Replace `_` with spaces.
    text = text.replace('_', ' ')
    end = find_end(text)
    # Pad with spaces so the new text has the same length as the old text.
    return text[:end] + ' ' * (len(text) - end)
def get_drg_link(drg_code):
    """Build the findacode.com lookup URL for a DRG code.

    Args:
        drg_code: the DRG code; converted with ``str()`` and left-padded
            with zeros to at least three characters.

    Returns:
        The URL string with the padded code as the ``c`` query parameter.
    """
    padded_code = str(drg_code).rjust(3, '0')
    return f'https://www.findacode.com/code.php?set=DRG&c={padded_code}'
def prettify(dict_list, k):
    """Join the value stored under ``k`` in every dict, one per line.

    Args:
        dict_list: iterable of dicts that all contain key ``k`` with a
            string value.
        k: key to read from each dict.

    Returns:
        The values joined with newline characters (empty string for an
        empty ``dict_list``).
    """
    return "\n".join(entry[k] for entry in dict_list)
def get_json(text_to_annotate):
    """Query the annotator endpoint for ``text_to_annotate``.

    The text is URL-quoted and sent as the ``text`` query parameter to
    ``REST_URL + "/annotator"`` with the ICD9CM ontology selected; the API
    key is sent in the ``Authorization`` header.

    Args:
        text_to_annotate: text to annotate.

    Returns:
        The decoded JSON response, or an empty list when the request fails
        or the response is not valid JSON (best-effort behaviour).
    """
    import http.client  # local import: only needed for the exception type below

    url = (
        REST_URL
        + "/annotator?text=" + urllib.parse.quote(text_to_annotate)
        + "&ontologies=ICD9CM"
        + "&longest_only=false"
        + "&exclude_numbers=false"
        + "&whole_word_only=true"
        + "&exclude_synonyms=false"
    )
    opener = urllib.request.build_opener()
    opener.addheaders = [('Authorization', 'apikey token=' + API_KEY)]
    try:
        # The context manager closes the HTTP response once it has been read.
        with opener.open(url) as response:
            return json.loads(response.read())
    except (OSError, ValueError, http.client.HTTPException):
        # OSError covers urllib.error.URLError/HTTPError and socket errors;
        # ValueError covers json.JSONDecodeError.
        return []
def parse_results(results):
    """Flatten annotator results into a list of annotation dicts.

    Args:
        results: iterable of result dicts, each with an ``'annotations'``
            list (items with ``'from'``, ``'to'`` and ``'text'``) and an
            ``'annotatedClass'`` dict that has an ``'@id'`` entry.

    Returns:
        A list with one dict per annotation, holding ``'start'`` and
        ``'end'`` (the ``'from'``/``'to'`` values minus one), ``'text'``,
        and ``'link'`` (the ``'@id'`` of the result's annotated class).
        Empty when ``results`` is empty.
    """
    if not results:
        return []
    rlist = []
    for result in results:
        link = result['annotatedClass']['@id']
        for annotation in result['annotations']:
            rlist.append({
                # Shift the reported positions down by one.
                'start': annotation['from'] - 1,
                'end': annotation['to'] - 1,
                'text': annotation['text'],
                'link': link,
            })
    return rlist
def get_icd_annotations(text):
    """Fetch annotations for ``text`` and return them as a flat list.

    Args:
        text: text to send to ``get_json``.

    Returns:
        The list produced by ``parse_results`` for the ``get_json``
        response.
    """
    response = get_json(text)
    return parse_results(response)
def subfinder(mylist, pattern):
    """Return the items of ``mylist`` that also occur in ``pattern``.

    Args:
        mylist: object with a ``tolist()`` method (e.g. a tensor or array).
        pattern: object with a ``tolist()`` method holding the allowed values.

    Returns:
        A list of the elements of ``mylist`` found in ``pattern``, in their
        original order and with duplicates kept.
    """
    allowed = pattern.tolist()
    return [item for item in mylist.tolist() if item in allowed]
def tokenize_icds(tokenizer, annotations, token_ids):
    """Mark which positions of ``token_ids`` belong to annotated ICD terms.

    Args:
        tokenizer: callable invoked as
            ``tokenizer(text, add_special_tokens=False, return_tensors='pt')``
            whose result exposes ``input_ids``.
        annotations: iterable of dicts with a ``'text'`` entry.
        token_ids: tensor of token ids for the document.
            NOTE(review): the slicing below assumes a 1-D tensor — confirm
            against the caller.

    Returns:
        A tensor of zeros shaped like ``token_ids`` with ones at the
        positions covered by an annotation.
    """
    icd_tokens = torch.zeros(token_ids.shape)
    for annotation in annotations:
        icd = annotation['text']
        icd_token_ids = tokenizer(icd, add_special_tokens=False, return_tensors='pt').input_ids[0]
        num_icd_tokens = icd_token_ids.shape[0]
        if num_icd_tokens == 0:
            # Nothing to locate for an annotation that tokenises to nothing.
            continue
        # find index of the beginning icd token
        starting_indices = (token_ids == icd_token_ids[0]).nonzero(as_tuple=False)
        if num_icd_tokens == 1:
            # single-token annotation: mark every occurrence
            icd_tokens[starting_indices] = 1
            continue
        if starting_indices.shape[0] == 1:
            # only one candidate start: accept it without further checks
            starting_index = starting_indices.item()
            icd_tokens[starting_index: starting_index + num_icd_tokens] = 1
            continue
        # several candidate starts: keep those whose whole window matches
        for starting_index in starting_indices.flatten().tolist():
            stop = starting_index + num_icd_tokens
            window = token_ids[starting_index:stop]
            # A window cut short by the end of the sequence cannot match.
            if window.shape[0] == num_icd_tokens and bool((window == icd_token_ids).all()):
                icd_tokens[starting_index:stop] = 1
    return icd_tokens
def get_attribution(text, tokenizer, model_outputs, inputs, k=7):
    """Aggregate per-token attention into per-word attention for one input.

    Args:
        text: unused; kept for interface compatibility.
        tokenizer: object providing ``convert_ids_to_tokens``; also passed
            on to ``reconstruct_text``.
        model_outputs: model output sequence; ``model_outputs[-1][0]`` is
            used as the attention values.
            NOTE(review): presumably the attention for the first batch
            item — confirm against the model.
        inputs: object whose ``input_ids[0]`` holds the token ids.
        k: unused; kept for interface compatibility.

    Returns:
        The ``(aggregated_attention, words)`` pair produced by
        ``reconstruct_text``.
    """
    tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
    try:
        padding_idx = tokens.index('[PAD]')
    except ValueError:
        # Input fills the whole sequence: there is no padding to strip.
        padding_idx = len(tokens)
    # Drop the padding, then the first and last remaining tokens.
    tokens = tokens[:padding_idx][1:-1]
    attn = model_outputs[-1][0]
    agg_attn, final_text = reconstruct_text(tokenizer=tokenizer, tokens=tokens, attn=attn)
    return agg_attn, final_text
def reconstruct_text(tokenizer, tokens, attn):
    """Merge sub-tokens into words and average their attention.

    Find a word -> token index mapping that allows an aggregation of the
    sub-tokens' attention: a token starting with ``'#'`` is treated as a
    continuation of the previous token, and every other token starts a new
    word.

    Args:
        tokenizer: unused; kept for interface compatibility.
        tokens: list of token strings.
        attn: tensor indexable by a list of token positions along its
            first dimension.

    Returns:
        A tuple ``(aggregated_attn, reconstructed_tokens)``:
        ``aggregated_attn`` is a 1-D tensor with the mean attention of each
        word's tokens, and ``reconstructed_tokens`` is the list of words
        with ``'#'`` characters removed.
    """
    if not tokens:
        return torch.zeros(0), []

    # Group token positions by word; the first token always opens a word,
    # even if it happens to start with '#'.
    groups = [[0]]
    for position, token in enumerate(tokens[1:], start=1):
        if token.startswith('#'):
            groups[-1].append(position)
        else:
            groups.append([position])

    # Size the output by the number of words actually formed.
    aggregated_attn = torch.zeros(len(groups))
    reconstructed_tokens = []
    for word_idx, token_indices in enumerate(groups):
        reconstructed_tokens.append(
            ''.join(tokens[j].replace('#', '') for j in token_indices)
        )
        aggregated_attn[word_idx] = torch.mean(attn[token_indices])
    return aggregated_attn, reconstructed_tokens
def load_rule(path):
    """Load a DRG rule table and build its lookup dictionaries.

    Args:
        path: path of the CSV file. It must contain ``'MS'`` or ``'APR'``,
            which selects the code column (``'MS-DRG'`` or ``'APR-DRG'``)
            and the codes excluded from the label space.

    Returns:
        A tuple ``(rule_df, drg2idx, i2d, d2mdc, d2w)``: the full data
        frame, DRG code -> index over the sorted filtered codes, its
        inverse, DRG code -> MDC, and DRG code -> weight (the last two
        cover every row of the table).

    Raises:
        ValueError: if ``path`` contains neither ``'MS'`` nor ``'APR'``.
    """
    rule_df = pd.read_csv(path)
    path_text = str(path)
    if 'MS' in path_text:
        code_column, excluded = 'MS-DRG', [945, 946, 949, 950, 998, 999]
    elif 'APR' in path_text:
        code_column, excluded = 'APR-DRG', [860, 863]
    else:
        raise ValueError(
            f"cannot tell the DRG system from path {path_text!r}: "
            "expected it to contain 'MS' or 'APR'"
        )

    # remove MDC 15 - neonate and couple other codes related to postcare
    # NOTE(review): MDC is compared with the string '15'; this only filters
    # rows if the column is read as strings — confirm against the data.
    msk = (rule_df['MDC'] != '15') & (~rule_df[code_column].isin(excluded))
    space = sorted(rule_df[msk]['DRG_CODE'].unique())

    drg2idx = {drg: idx for idx, drg in enumerate(space)}
    i2d = {idx: drg for drg, idx in drg2idx.items()}
    # Later rows win when a DRG code appears more than once.
    d2mdc = dict(zip(rule_df['DRG_CODE'], rule_df['MDC']))
    d2w = dict(zip(rule_df['DRG_CODE'], rule_df['WEIGHT']))
    return rule_df, drg2idx, i2d, d2mdc, d2w
def visualize_attn(model_results):
    """Render the attention of one prediction as an HTML table.

    Args:
        model_results: dict with the entries ``'class_dsc'``, ``'prob'``,
            ``'attn'``, ``'tokens'``, ``'drg_link'`` and ``'icd_results'``.
            The ``'attn'`` tensor is normalised in place.

    Returns:
        The HTML string produced by ``visualize_text``.
    """
    class_id = model_results['class_dsc']
    prob = model_results['prob']
    attn = model_results['attn']
    tokens = model_results['tokens']
    scaler = PyTMinMaxScalerVectorized()
    normalized_attn = scaler(attn)
    # NOTE(review): the constructor arguments were missing from the source
    # this block was restored from. word_attributions, pred_prob,
    # pred_class and raw_input_ids follow from the locals above and from
    # the fields visualize_text() reads; true_class, attr_class, attr_score
    # and convergence_score are placeholders that visualize_text() does not
    # read — confirm them against the original implementation.
    viz_record = visualization.VisualizationDataRecord(
        word_attributions=normalized_attn,
        pred_prob=prob,
        pred_class=class_id,
        true_class=class_id,
        attr_class=class_id,
        attr_score=normalized_attn.sum(),
        raw_input_ids=tokens,
        convergence_score=0.0,
    )
    return visualize_text(viz_record, drg_link=model_results['drg_link'], icd_annotations=model_results['icd_results'])
def modify_attn_html(attn_html, link="https://espn.com"):
    """Wrap every ``<mark ...>...</mark>`` element of ``attn_html`` in a link.

    Args:
        attn_html: HTML string containing ``<mark`` elements.
        link: target of the generated anchors. The default is the
            placeholder URL that was hard-coded in this function.

    Returns:
        ``attn_html`` with each ``<mark>`` element enclosed in
        ``<a href="...">...</a>``; text outside the marks is unchanged.
    """
    leading, *marked_fragments = attn_html.split('<mark')
    htmls = [leading]
    for fragment in marked_fragments:
        # wrap around href tag: close the anchor right after the mark so
        # trailing markup (e.g. a closing </td>) stays outside the link
        body, closing, rest = fragment.partition('</mark>')
        htmls.append(f'<a href="{link}"><mark{body}{closing}</a>{rest}')
    return "".join(htmls)
def modify_code_html(html, link, icd=False):
    """Turn the content of a ``<td>`` cell into a hyperlink.

    Args:
        html: HTML of one table cell (``<td>...</td>``); the content
            between the first ``<td>`` and the following ``</td>`` is
            linked. Input without ``<td>`` is linked as a whole.
        link: URL for the anchor's ``href`` attribute.
        icd: when true, return the bare anchor without the surrounding
            ``<td>`` tags (so several anchors can share one cell).

    Returns:
        The anchor HTML, wrapped in ``<td>...</td>`` unless ``icd`` is true.
    """
    # Local import under an alias-free name: the ``html`` parameter shadows
    # the standard-library module of the same name inside this function.
    from html import escape

    cells = html.split('<td>')
    inner = cells[1].split('</td>')[0] if len(cells) > 1 else html
    # The opening tag is closed with '>' and the link is escaped so it
    # cannot break out of the attribute.
    anchor = f'<a href="{escape(str(link), quote=True)}">{inner}</a>'
    if icd:
        return anchor
    return f'<td>{anchor}</td>'
def modify_drg_html(html, drg_link):
    """Link the DRG cell ``html`` to ``drg_link``, keeping its ``<td>`` wrapper.

    Args:
        html: HTML of the DRG table cell.
        drg_link: URL the cell content should point to.

    Returns:
        The result of ``modify_code_html`` with ``icd=False``.
    """
    return modify_code_html(html=html, link=drg_link, icd=False)
def get_icd_html(icd_list):
    """Build the table cell that lists the linked ICD annotations.

    Args:
        icd_list: list of dicts with ``'text'`` and ``'link'`` entries.

    Returns:
        A single ``<td>...</td>`` string containing one anchor per
        annotation, or a cell showing ``N/A`` when ``icd_list`` is empty.
    """
    if not icd_list:
        return '<td><text style="padding-right:2em"><b>N/A</b></text></td>'
    anchors = [
        modify_code_html(
            html=visualization.format_classname(classname=icd_dict['text']),
            link=icd_dict['link'],
            icd=True,
        )
        for icd_dict in icd_list
    ]
    return '<td>' + ''.join(anchors) + '</td>'
# copied out of captum because we need raw html instead of a jupyter widget
def visualize_text(datarecord, drg_link, icd_annotations):
    """Render one visualization record as a raw HTML table.

    Args:
        datarecord: record exposing ``pred_class``, ``raw_input_ids`` and
            ``word_attributions``.
        drg_link: URL the predicted DRG cell links to.
        icd_annotations: list of annotation dicts passed to ``get_icd_html``.

    Returns:
        An HTML string with a header and one row holding the predicted
        DRG, the word importances and the ICD codes.
    """
    dom = ["<table width: 100%>"]
    rows = [
        "<th style='text-align: left'>Predicted DRG</th>"
        "<th style='text-align: left'>Word Importance</th>"
        "<th style='text-align: left'>ICD Codes</th>"
    ]
    pred_class_html = visualization.format_classname(datarecord.pred_class)
    icd_class_html = get_icd_html(icd_annotations)
    pred_class_html = modify_drg_html(html=pred_class_html, drg_link=drg_link)
    word_attn_html = visualization.format_word_importances(
        datarecord.raw_input_ids, datarecord.word_attributions
    )
    word_attn_html = modify_attn_html(word_attn_html)
    # NOTE(review): the row/table assembly was missing from the source this
    # block was restored from; the cells are emitted in header order.
    rows.append(
        "".join(["<tr>", pred_class_html, word_attn_html, icd_class_html, "</tr>"])
    )
    dom.append("".join(rows))
    dom.append("</table>")
    html = "".join(dom)
    return html