# +
from tokenizers import Tokenizer
import re
from utils import extract_variables, clean_text

# Start from a blank WordPiece tokenizer and fill it with the template
# phrases, one phrase per line of template_vocab.txt; each phrase is added
# as a single token.
template_tokenizer = Tokenizer.from_file("blank_wordpiece.tokenizer")
with open("template_vocab.txt", "r") as f:
    template_vocab = [line.rstrip("\n") for line in f]

template_tokenizer.add_tokens(template_vocab)

# A second tokenizer for the extracted "variables": integers 0-100, severity
# grades, decimal fractions (".0"-".9"), zero-padded digits ("00"-"09"), and
# the "‡" character used to separate variables when they are joined.
variable_tokenizer = Tokenizer.from_file("blank_wordpiece.tokenizer")
variable_vocab = ["[UNK]"]
variable_vocab.extend([str(i) for i in range(101)])
variable_vocab.extend(
    ["‡", "MILD", "MODERATE", "SEVERE", "MILD/MODERATE", "MODERATE/SEVERE"]
)
variable_vocab.extend(["." + str(i) for i in range(10)])
variable_vocab.extend(["0" + str(i) for i in range(10)])
variable_tokenizer.add_tokens(variable_vocab)

# Variable token ids are offset by the template vocab size (see
# template_vocab_len below), so the two tokenizers share one id space;
# BOS and EOS sit just above it.
max_token_id = (
    len(template_tokenizer.get_vocab()) + len(variable_tokenizer.get_vocab()) - 1
)
BOS_token_id = max_token_id + 1
EOS_token_id = max_token_id + 2
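# Hypothetical layout: with a 1,000-phrase template vocab and a 128-token
# variable vocab, ids 0-999 would be template phrases, ids 1000-1127 shifted
# variable tokens, 1128 BOS and 1129 EOS.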

var_symbol = re.compile(r"<#>")

# Maps each template token id to the number of variable placeholders (<#>)
# that appear in that phrase.
var_counts = {}

for token, token_id in template_tokenizer.get_vocab().items():
    var_count = len(var_symbol.findall(token))
    var_counts[token_id] = var_count
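
# Hypothetical example: if the phrase "THE EJECTION FRACTION IS <#> %." has
# id 42 in the template vocab, then var_counts[42] == 1.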

template_vocab_len = len(template_tokenizer.get_vocab())

# Some final text cleaning replacements.
replacements = [
    (
        "THE INFERIOR VENA CAVA IS NORMAL IN SIZE AND SHOWS A NORMAL RESPIRATORY COLLAPSE, CONSISTENT WITH NORMAL RIGHT ATRIAL PRESSURE (<#>MMHG).",
        "THE INFERIOR VENA CAVA SHOWS A NORMAL RESPIRATORY COLLAPSE CONSISTENT WITH NORMAL RIGHT ATRIAL PRESSURE (<#>MMHG).",
    ),
    (
        "RESTING SEGMENTAL WALL MOTION ANALYSIS.:",
        "RESTING SEGMENTAL WALL MOTION ANALYSIS.",
    ),
]


def simple_replacement(text):
    for r in replacements:
        text = text.replace(r[0], r[1])
    return text


def pad_or_trunc(tokens, length):
    # EOS has the largest id in the sequence, so max(tokens) recovers it; when
    # truncating, EOS is kept as the final token. Short sequences are
    # right-padded with 0.
    if len(tokens) > length:
        eos_token = max(tokens)
        tokens = tokens[:length]
        tokens[-1] = eos_token
    else:
        tokens = tokens + [0] * (length - len(tokens))
    return tokens
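
# For example, with hypothetical ids where 901 is EOS:
#     pad_or_trunc([900, 5, 12, 901], 6)    -> [900, 5, 12, 901, 0, 0]
#     pad_or_trunc([900, 5, 12, 7, 901], 4) -> [900, 5, 12, 901]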


def template_tokenize(report):
    report = clean_text(report)

    # The "variables" (numbers, severity words) are removed from the report
    # and returned in a list. The report text has all variables replaced
    # with a placeholder symbol: <#>
    # The template tokenizer's vocabulary is made up of phrases with this
    # placeholder symbol.
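    # Hypothetical example: a sentence like
    #     "THE EJECTION FRACTION IS 55 %."
    # might come back as
    #     variables = ["55"]
    #     report    = "THE EJECTION FRACTION IS <#> %."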
    variables, report = extract_variables(report)
    report = simple_replacement(report)
    toks = template_tokenizer.encode(report)

    # Each encoded token is either a known template phrase or [UNK]. Build a
    # mask over the extracted variables: placeholders inside recognized
    # phrases are marked True, placeholders inside [UNK] spans are marked
    # False (those variables are dropped). The unk list just records the
    # unrecognized spans and is not used further here.
    var_mask = []
    unk = []
    for (start, end), tok, tok_id in zip(toks.offsets, toks.tokens, toks.ids):
        if tok != "[UNK]":
            var_mask.extend([True] * var_counts[tok_id])
        else:
            source = report[start:end]
            unk.append((source, start))
            var_count = len(var_symbol.findall(source))
            var_mask.extend([False] * var_count)
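
    # Hypothetical example: a report that splits into one recognized phrase
    # containing two placeholders followed by an [UNK] span containing one
    # placeholder yields var_mask == [True, True, False].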

    # Keep only recognized template tokens (id 0 is assumed to be [UNK]) and
    # only the variables whose placeholders fell inside recognized phrases.
    tok_ids = [t for t in toks.ids if t != 0]
    matched_vars = [v for v, mask in zip(variables, var_mask) if mask]

    new_tok_ids = []
    for tok_id in tok_ids:
        var_count = var_counts[tok_id]
        recognized_vars = []
        for _ in range(var_count):
            recognized_vars.append(matched_vars.pop(0))

        # Variables are joined with the "‡" separator before being tokenized
        # so the model can tell where one variable ends and the next begins.
        # Their token ids are then shifted past the template vocab so the two
        # id ranges do not collide.
        var_string = "‡".join(recognized_vars)
        var_toks = variable_tokenizer.encode(var_string).ids
        var_toks = [v + template_vocab_len for v in var_toks]
        new_tok_ids.extend([tok_id, *var_toks])
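        # Hypothetical example: recognized_vars ["55", "MILD"] is joined into
        # "55‡MILD", which the variable tokenizer should split back into the
        # tokens "55", "‡" and "MILD" before their ids are shifted and
        # appended after the template token.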

    new_tok_ids = [BOS_token_id, *new_tok_ids, EOS_token_id]
    return pad_or_trunc(new_tok_ids, 77)


# Map token ids back to readable strings: template phrases in [brackets],
# variable tokens in <angle brackets>.
template_detokenizer = {
    token_id: f"[{token}]"
    for token, token_id in template_tokenizer.get_vocab().items()
}

for token, token_id in variable_tokenizer.get_vocab().items():
    template_detokenizer[token_id + template_vocab_len] = f"<{token}>"
template_detokenizer[BOS_token_id] = "[BOS]"
template_detokenizer[EOS_token_id] = "[EOS]"


def template_detokenize(ids):
    return [template_detokenizer[i] for i in ids]
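

if __name__ == "__main__":
    # Minimal usage sketch. The report text is hypothetical and the exact
    # output depends on the phrases present in template_vocab.txt; phrases
    # that are not in the template vocab come out as [UNK] and are dropped
    # from the token sequence.
    example_report = (
        "RESTING SEGMENTAL WALL MOTION ANALYSIS. "
        "THE EJECTION FRACTION IS 55 %."
    )
    token_ids = template_tokenize(example_report)
    print(token_ids)
    print(" ".join(template_detokenize(token_ids)))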