import torch |
from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel |
from torch import nn |
from itertools import chain |
from torch.nn import MSELoss, CrossEntropyLoss |
from cleantext import clean |
from num2words import num2words |
import re |
import string |
# Characters treated as punctuation: ASCII punctuation plus a few extras.
# NOTE(review): the repeated 'β' entries (and 'β¦') look like mis-decoded
# typographic quotes / ellipsis -- confirm against the upstream source.
punct_chars = sorted(
    set(string.punctuation).union(
        {'β', 'β', 'β', 'β', '~', '|', 'β', 'β', 'β¦', "'", "`", '_'}
    )
)
punctuation = ''.join(punct_chars)
# Character class matching any single punctuation character listed above.
replace = re.compile('[%s]' % re.escape(punctuation))
"sum", |
"arc", |
"mass", |
"digit", |
"graph", |
"liter", |
"gram", |
"add", |
"angle", |
"scale", |
"data", |
"array", |
"ruler", |
"meter", |
"total", |
"unit", |
"prism", |
"median", |
"ratio", |
"area", |
] |
"absolute value", |
"area", |
"average", |
"base of", |
"box plot", |
"categorical", |
"coefficient", |
"common factor", |
"common multiple", |
"compose", |
"coordinate", |
"cubed", |
"decompose", |
"dependent variable", |
"distribution", |
"dot plot", |
"double number line diagram", |
"equivalent", |
"equivalent expression", |
"ratio", |
"exponent", |
"frequency", |
"greatest common factor", |
"gcd", |
"height of", |
"histogram", |
"independent variable", |
"interquartile range", |
"iqr", |
"least common multiple", |
"long division", |
"mean absolute deviation", |
"median", |
"negative number", |
"opposite vertex", |
"parallelogram", |
"percent", |
"polygon", |
"polyhedron", |
"positive number", |
"prism", |
"pyramid", |
"quadrant", |
"quadrilateral", |
"quartile", |
"rational number", |
"reciprocal", |
"equality", |
"inequality", |
"squared", |
"statistic", |
"surface area", |
"identity property", |
"addend", |
"unit", |
"number sentence", |
"make ten", |
"take from ten", |
"number bond", |
"total", |
"estimate", |
"hashmark", |
"meter", |
"number line", |
"ruler", |
"centimeter", |
"base ten", |
"expanded form", |
"hundred", |
"thousand", |
"place value", |
"number disk", |
"standard form", |
"unit form", |
"word form", |
"tens place", |
"algorithm", |
"equation", |
"simplif", |
"addition", |
"subtract", |
"array", |
"even number", |
"odd number", |
"repeated addition", |
"tessellat", |
"whole number", |
"number path", |
"rectangle", |
"square", |
"bar graph", |
"data", |
"degree", |
"line plot", |
"picture graph", |
"scale", |
"survey", |
"thermometer", |
"estimat", |
"tape diagram", |
"value", |
"analog", |
"angle", |
"parallel", |
"partition", |
"pentagon", |
"right angle", |
"cube", |
"digital", |
"quarter of", |
"tangram", |
"circle", |
"hexagon", |
"half circle", |
"half-circle", |
"quarter circle", |
"quarter-circle", |
"semicircle", |
"semi-circle", |
"rectang", |
"rhombus", |
"trapezoid", |
"triangle", |
"commutative", |
"equal group", |
"distributive", |
"divide", |
"division", |
"multipl", |
"parentheses", |
"quotient", |
"rotate", |
"unknown", |
"add", |
"capacity", |
"continuous", |
"endpoint", |
"gram", |
"interval", |
"kilogram", |
"volume", |
"liter", |
"milliliter", |
"approximate", |
"area model", |
"square unit", |
"unit square", |
"geometr", |
"equivalent fraction", |
"fraction form", |
"fractional unit", |
"unit fraction", |
"unit interval", |
"measur", |
"graph", |
"scaled graph", |
"diagonal", |
"perimeter", |
"regular polygon", |
"tessellate", |
"tetromino", |
"heptagon", |
"octagon", |
"digit", |
"expression", |
"sum", |
"kilometer", |
"mass", |
"mixed unit", |
"length", |
"measure", |
"simplify", |
"associative", |
"composite", |
"divisible", |
"divisor", |
"partial product", |
"prime number", |
"remainder", |
"acute", |
"arc", |
"collinear", |
"equilateral", |
"intersect", |
"isosceles", |
"symmetry", |
"line segment", |
"line", |
"obtuse", |
"perpendicular", |
"protractor", |
"scalene", |
"straight angle", |
"supplementary angle", |
"vertex", |
"common denominator", |
"denominator", |
"fraction", |
"mixed number", |
"numerator", |
"whole", |
"decimal expanded form", |
"decimal", |
"hundredth", |
"tenth", |
"customary system of measurement", |
"customary unit", |
"gallon", |
"metric", |
"metric unit", |
"ounce", |
"pint", |
"quart", |
"convert", |
"distance", |
"millimeter", |
"thousandth", |
"hundredths", |
"conversion factor", |
"decimal fraction", |
"multiplier", |
"equivalence", |
"multiple", |
"product", |
"benchmark fraction", |
"cup", |
"pound", |
"yard", |
"whole unit", |
"decimal divisor", |
"factors", |
"bisect", |
"cubic units", |
"hierarchy", |
"unit cube", |
"attribute", |
"kite", |
"bisector", |
"solid figure", |
"square units", |
"dimension", |
"axis", |
"ordered pair", |
"angle measure", |
"horizontal", |
"vertical", |
"categorical data", |
"lcm", |
"measure of center", |
"meters per second", |
"numerical", |
"solution", |
"unit price", |
"unit rate", |
"variability", |
"variable", |
] |
def get_num_words(text):
    """Count the whitespace-separated words in ``text`` once punctuation is removed.

    Args:
        text: The string to measure.

    Returns:
        The number of tokens that remain after every character matched by the
        module-level ``replace`` pattern has been turned into a space.
        Non-string input is reported on stdout and counted as 0 words.
    """
    if not isinstance(text, str):
        # Wrap in a 1-tuple so that tuple inputs do not break %-formatting.
        print("%s is not a string" % (text,))
        # Nothing to count; the regex calls below would raise TypeError.
        return 0
    text = replace.sub(' ', text)
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    # NOTE(review): when ``replace`` covers '[' and ']' (string.punctuation
    # does), the brackets are already gone and this pattern cannot match.
    # Kept in place so existing word counts do not change.
    text = re.sub(r'\[.+\]', " ", text)
    return len(text.split())
def number_to_words(num):
    """Spell out a numeric string in words, best-effort.

    Args:
        num: Text of a number; commas are removed before conversion.

    Returns:
        The result of ``num2words`` on the comma-free text, or ``num``
        unchanged when the conversion fails for any reason.
    """
    try:
        return num2words(num.replace(",", ""))
    except Exception:
        # Deliberate best-effort fallback; unlike a bare ``except`` this
        # still lets KeyboardInterrupt / SystemExit propagate.
        return num
def clean_str(s):
    """Normalize ``s`` with ``clean`` using this module's settings, keeping punctuation.

    The call enables unicode fixing, ASCII transliteration, lowercasing and
    line-break removal, substitutes placeholders for URLs, e-mails and phone
    numbers, and hands every number match to ``number_to_words``.
    """

    def _spell_number(m):
        # ``clean`` receives a callable replacement for numbers.
        return number_to_words(m.group())

    options = dict(
        fix_unicode=True,
        to_ascii=True,
        lower=True,
        no_line_breaks=True,
        no_urls=True,
        no_emails=True,
        no_phone_numbers=True,
        no_numbers=True,
        no_digits=False,
        no_currency_symbols=False,
        no_punct=False,
        replace_with_url="<URL>",
        replace_with_email="<EMAIL>",
        replace_with_phone_number="<PHONE>",
        replace_with_number=_spell_number,
        replace_with_digit="0",
        replace_with_currency_symbol="<CUR>",
        lang="en",
    )
    return clean(s, **options)
def clean_str_nopunct(s):
    """Normalize ``s`` with ``clean`` using this module's settings, with ``no_punct=True``.

    The call enables unicode fixing, ASCII transliteration, lowercasing and
    line-break removal, substitutes placeholders for URLs, e-mails and phone
    numbers, and hands every number match to ``number_to_words``.
    """

    def _spell_number(m):
        # ``clean`` receives a callable replacement for numbers.
        return number_to_words(m.group())

    options = dict(
        fix_unicode=True,
        to_ascii=True,
        lower=True,
        no_line_breaks=True,
        no_urls=True,
        no_emails=True,
        no_phone_numbers=True,
        no_numbers=True,
        no_digits=False,
        no_currency_symbols=False,
        no_punct=True,
        replace_with_url="<URL>",
        replace_with_email="<EMAIL>",
        replace_with_phone_number="<PHONE>",
        replace_with_number=_spell_number,
        replace_with_digit="0",
        replace_with_currency_symbol="<CUR>",
        lang="en",
    )
    return clean(s, **options)
class MultiHeadModel(BertPreTrainedModel):
    """Pre-trained BERT model that uses our loss functions.

    One ``nn.Linear`` head per entry of ``head2size`` is applied to the
    encoder's (dropout-regularised) pooled output. A head with a single
    output is trained with mean-squared error; any other head is trained
    with cross-entropy.
    """

    def __init__(self, config, head2size):
        """Build the encoder, the dropout layer and the task heads.

        Args:
            config: BERT configuration. ``hidden_size`` and
                ``hidden_dropout_prob`` are read from it and its
                ``num_labels`` is overwritten with 1.
            head2size: Mapping from head name to that head's number of
                output logits.
        """
        super(MultiHeadModel, self).__init__(config, head2size)
        config.num_labels = 1
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.heads = nn.ModuleDict({
            head_name: nn.Linear(config.hidden_size, num_labels)
            for head_name, num_labels in head2size.items()
        })

        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                head2labels=None, return_pooler_output=False, head2mask=None,
                nsp_loss_weights=None):
        """Run the encoder and every head, optionally computing per-head losses.

        Args:
            input_ids: Forwarded to the BERT encoder.
            token_type_ids: Forwarded to the BERT encoder.
            attention_mask: Forwarded to the BERT encoder.
            head2labels: Optional mapping from head name to a label tensor.
                Every key is looked up among the heads' logits, so each key
                must name a head.
            return_pooler_output: When true, the encoder's pooled output
                (before dropout) is added to the result.
            head2mask: Optional mapping from a single-output head's name to
                a key of ``head2labels``. The labels stored under that key
                weight each example's squared error, and the weighted sum is
                divided by the sum of those labels.
            nsp_loss_weights: Optional class-weight tensor handed to
                ``CrossEntropyLoss`` for heads with more than one output.
                ``None`` means unweighted.

        Returns:
            dict with ``"<head>_logits"`` for every head, ``"<head>_loss"``
            for every key of ``head2labels``, and ``"pooler_output"`` when
            requested.
        """
        output = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
            output_attentions=False, output_hidden_states=False, return_dict=True)
        # Stay on whatever device the encoder produced its output on. Moving
        # to "cuda" just because CUDA is available breaks a model that lives
        # on the CPU, since the heads would be on a different device.
        pooled_output = self.dropout(output["pooler_output"])

        head2logits = {}
        return_dict = {}
        for head_name, head in self.heads.items():
            logits = head(pooled_output).float()
            head2logits[head_name] = logits
            return_dict[head_name + "_logits"] = logits

        if head2labels is not None:
            for head_name, labels in head2labels.items():
                logits = head2logits[head_name]
                loss_key = head_name + "_loss"
                num_classes = logits.shape[1]

                if num_classes == 1:
                    # Single output: regression with mean-squared error.
                    if head2mask is not None and head_name in head2mask:
                        # Labels of another entry act as per-example weights.
                        mask_labels = head2labels[head2mask[head_name]]
                        num_positives = mask_labels.sum()
                        if num_positives == 0:
                            # Nothing selected by the mask: report a zero loss
                            # on the same device as the logits.
                            return_dict[loss_key] = torch.tensor([0], device=logits.device)
                        else:
                            loss_fct = MSELoss(reduction='none')
                            loss = loss_fct(logits.view(-1), labels.float().view(-1))
                            return_dict[loss_key] = loss.dot(mask_labels.float().view(-1)) / num_positives
                    else:
                        loss_fct = MSELoss()
                        return_dict[loss_key] = loss_fct(logits.view(-1), labels.float().view(-1))
                else:
                    # Several outputs: classification with cross-entropy.
                    # The weights are optional (default None), so only
                    # convert them when they were actually supplied.
                    weight = None if nsp_loss_weights is None else nsp_loss_weights.float()
                    loss_fct = CrossEntropyLoss(weight=weight)
                    return_dict[loss_key] = loss_fct(logits, labels.view(-1))

        if return_pooler_output:
            return_dict["pooler_output"] = output["pooler_output"]

        return return_dict
class InputBuilder(object):
    """Base class for building inputs from segments."""

    def __init__(self, tokenizer):
        """Store the tokenizer and a one-element list holding its mask token id."""
        self.tokenizer = tokenizer
        self.mask = [tokenizer.mask_token_id]

    def build_inputs(self, history, reply, max_length):
        """Build model inputs; subclasses must override this."""
        raise NotImplementedError

    def mask_seq(self, sequence, seq_id):
        """Replace ``sequence[seq_id]`` with the mask segment, in place.

        Args:
            sequence: Mutable list of segments.
            seq_id: Index of the segment to overwrite.

        Returns:
            The same ``sequence`` object, after modification.
        """
        # Store a copy so that a later in-place edit of the masked slot
        # cannot alter ``self.mask`` for every subsequent call.
        sequence[seq_id] = list(self.mask)
        return sequence

    @classmethod
    def _combine_sequence(cls, history, reply, max_length, flipped=False):
        """Truncate every segment to ``max_length`` and join them into one list.

        Args:
            history: List of sliceable segments.
            reply: A single sliceable segment.
            max_length: Maximum number of items kept per segment.
            flipped: When true the reply comes first, otherwise last.

        Returns:
            A new list of truncated segments.
        """
        history = [s[:max_length] for s in history]
        reply = reply[:max_length]
        if flipped:
            return [reply] + history
        return history + [reply]
class BertInputBuilder(InputBuilder):
    """Processor for BERT inputs"""

    def __init__(self, tokenizer):
        """Record the tokenizer's CLS / SEP ids and the names of the model inputs."""
        super().__init__(tokenizer)
        self.cls = [tokenizer.cls_token_id]
        self.sep = [tokenizer.sep_token_id]
        self.model_inputs = ["input_ids", "token_type_ids", "attention_mask"]
        self.padded_inputs = ["input_ids", "token_type_ids"]
        self.flipped = False

    def build_inputs(self, history, reply, max_length, input_str=True):
        """See base class.

        Args:
            history: Earlier segments, as strings when ``input_str`` is true
                and as lists of token ids otherwise.
            reply: The final segment, in the same form as ``history`` items.
            max_length: Per-segment truncation length.
            input_str: Whether the inputs still have to be tokenized.

        Returns:
            dict with ``"input_ids"`` (all segments concatenated, CLS first,
            SEP after each segment) and ``"token_type_ids"`` (0 for segments
            at an odd distance from the end of the sequence, 1 for the rest).
        """
        if input_str:
            tokenize = self.tokenizer.tokenize
            to_ids = self.tokenizer.convert_tokens_to_ids
            history = [to_ids(tokenize(utterance)) for utterance in history]
            reply = to_ids(tokenize(reply))

        combined = self._combine_sequence(history, reply, max_length, self.flipped)
        segments = [segment + self.sep for segment in combined]
        segments[0] = self.cls + segments[0]

        last_speaker, other_speaker = 0, 1
        num_segments = len(segments)
        input_ids = []
        token_type_ids = []
        for position, segment in enumerate(segments):
            # The final segment (distance 1 from the end) is the last speaker;
            # speakers alternate going backwards from there.
            if (num_segments - position) % 2 == 1:
                speaker = last_speaker
            else:
                speaker = other_speaker
            input_ids.extend(segment)
            token_type_ids.extend(speaker for _ in segment)

        return {"input_ids": input_ids, "token_type_ids": token_type_ids}