Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import math | |
import torch | |
import torch.nn.functional as F | |
from fairseq import utils | |
from fairseq.criterions import LegacyFairseqCriterion, register_criterion | |
from fairseq.data import encoders | |
@register_criterion("wsc")
class WSCCriterion(LegacyFairseqCriterion):
    """Criterion for Winograd Schema Challenge (WSC) style examples.

    Each example has one query and a set of candidates, each given as tokens
    plus a mask. A span is scored by replacing its masked positions with the
    task's mask token, running the model, and averaging the log-probability
    of the original tokens at those positions (see ``get_lprobs``). The
    prediction for an example is positive when the query scores at least as
    high as every candidate.

    A margin loss or a cross-entropy loss is used (see ``add_args`` /
    ``get_loss``). Per-example predictions can optionally be written to a file.
    """

    def __init__(self, args, task):
        super().__init__(args, task)
        if self.args.save_predictions is not None:
            # Closed in __del__.
            self.prediction_h = open(self.args.save_predictions, "w", encoding="utf-8")
        else:
            self.prediction_h = None
        self.bpe = encoders.build_bpe(args.bpe)
        self.tokenizer = encoders.build_tokenizer(args.tokenizer)

    def __del__(self):
        # getattr: __del__ still runs when __init__ raised before prediction_h
        # was assigned; don't add an AttributeError on top of the real error.
        prediction_h = getattr(self, "prediction_h", None)
        if prediction_h is not None:
            prediction_h.close()

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # Margin loss: alpha weights the hinge term, beta is added inside the
        # hinge (see get_loss).
        parser.add_argument("--wsc-margin-alpha", type=float, metavar="A", default=1.0)
        parser.add_argument("--wsc-margin-beta", type=float, metavar="B", default=0.0)
        parser.add_argument(
            "--wsc-cross-entropy",
            action="store_true",
            help="use cross entropy formulation instead of margin loss",
        )
        parser.add_argument(
            "--save-predictions", metavar="FILE", help="file to save predictions to"
        )

    def get_masked_input(self, tokens, mask):
        """Return a copy of ``tokens`` with the ``mask`` positions replaced.

        The replacement value is ``self.task.mask`` (presumably the mask-token
        index -- confirm against the task). ``tokens`` itself is not modified.
        """
        masked_tokens = tokens.clone()
        masked_tokens[mask] = self.task.mask
        return masked_tokens

    def get_lprobs(self, model, tokens, mask):
        """Score each row of ``tokens`` at its masked positions.

        Args:
            model: called as ``model(src_tokens=...)``; must return a
                ``(logits, extra)`` pair whose logits are 3-D.
            tokens: 2-D token-index tensor, presumably ``(batch, seq_len)``.
            mask: tensor of the same shape (presumably boolean) selecting the
                positions to mask out and score.

        Returns:
            float32 tensor with one entry per row: the mean log-probability
            of the original tokens over that row's masked positions. A row
            with no masked position yields NaN (0 / 0).
        """
        logits, _ = model(src_tokens=self.get_masked_input(tokens, mask))
        lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
        # Log-probability the model assigns to the original token at each position.
        scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
        mask = mask.type_as(scores)
        scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
        return scores

    def get_loss(self, query_lprobs, cand_lprobs):
        """Loss that pushes the query score above the candidate scores.

        With ``--wsc-cross-entropy``: cross-entropy over the single row
        ``[query_lprobs, cand_lprobs]``, with index 0 (the query) as target.
        Otherwise: the margin loss
        ``sum(-query + alpha * max(0, cand - query + beta))``.
        """
        if self.args.wsc_cross_entropy:
            return F.cross_entropy(
                torch.cat([query_lprobs, cand_lprobs]).unsqueeze(0),
                query_lprobs.new([0]).long(),
            )
        # NOTE(review): -query_lprobs broadcasts against cand_lprobs, so with
        # several candidates the query term is counted once per candidate.
        return (
            -query_lprobs
            + self.args.wsc_margin_alpha
            * (cand_lprobs - query_lprobs + self.args.wsc_margin_beta).clamp(min=0)
        ).sum()

    def forward(self, model, sample, reduce=True):
        """Compute loss and accuracy statistics for one batch.

        Examples are scored one at a time: each query is run as a batch of
        one, and each example's candidates as their own batch. Presumably
        this is because examples have differing numbers of candidates.

        * Examples whose label is ``None`` count toward neither accuracy nor
          loss.
        * Only positive examples contribute to the loss.

        Returns:
            ``(loss, sample_size, logging_output)``. ``sample_size`` is the
            number of labelled examples, or 1 if there are none.
        """
        loss, nloss = 0.0, 0
        ncorrect, nqueries = 0, 0
        for i, label in enumerate(sample["labels"]):
            query_lprobs = self.get_lprobs(
                model,
                sample["query_tokens"][i].unsqueeze(0),
                sample["query_masks"][i].unsqueeze(0),
            )
            cand_lprobs = self.get_lprobs(
                model,
                sample["candidate_tokens"][i],
                sample["candidate_masks"][i],
            )
            # Positive prediction: the query scores at least as well as every candidate.
            pred = (query_lprobs >= cand_lprobs).all().item()

            if label is not None:
                label = 1 if label else 0
                ncorrect += 1 if pred == label else 0
                nqueries += 1

            if label:
                # only compute a loss for positive instances
                nloss += 1
                loss += self.get_loss(query_lprobs, cand_lprobs)

            if self.prediction_h is not None:
                # .item() syncs with the device, so only do it when needed.
                sample_id = sample["id"][i].item()
                print(
                    "{}\t{}\t{}".format(sample_id, pred, label),
                    file=self.prediction_h,
                )

        if nloss == 0:
            # No positive examples: return a differentiable zero so callers
            # can still call backward().
            loss = torch.tensor(0.0, requires_grad=True)

        sample_size = nqueries if nqueries > 0 else 1
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "ncorrect": ncorrect,
            "nqueries": nqueries,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training.

        The summed loss is normalized by the total ``sample_size`` and divided
        by ln(2), so it is reported in base 2. Accuracy is included only when
        at least one labelled query was seen.
        """
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        agg_output = {
            # Guard: an empty list of logging outputs would otherwise divide by zero.
            "loss": loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }

        ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
        nqueries = sum(log.get("nqueries", 0) for log in logging_outputs)
        if nqueries > 0:
            agg_output["accuracy"] = ncorrect / float(nqueries)

        return agg_output
@register_criterion("winogrande")
class WinograndeCriterion(WSCCriterion):
    """Batched variant of ``WSCCriterion``.

    Queries and candidates are each scored in one batched forward pass and
    compared row by row. Each row appears to pair one query with one
    candidate, and the query is always treated as the correct answer:
    accuracy counts the rows where the query scores at least as high.
    """

    def get_loss(self, query_lprobs, cand_lprobs):
        """Loss for a batch of row-aligned (query, candidate) scores.

        * The margin loss is inherited unchanged.
        * With ``--wsc-cross-entropy``, each row is a 2-way classification
          (query vs. its candidate) with the query as target, summed over
          the batch. This matches the summed margin loss and the batch-size
          ``sample_size``.

        The inherited cross-entropy concatenates the whole batch into a
        single ``(1, 2B)`` row with target 0, which rewards only the first
        query when B > 1. For B == 1 the two formulations are identical.
        """
        if self.args.wsc_cross_entropy:
            logits = torch.stack([query_lprobs, cand_lprobs], dim=1)  # (B, 2)
            target = logits.new_zeros(logits.size(0), dtype=torch.long)
            return F.cross_entropy(logits, target, reduction="sum")
        return super().get_loss(query_lprobs, cand_lprobs)

    def forward(self, model, sample, reduce=True):
        """Compute loss and accuracy statistics for one batch.

        Returns:
            ``(loss, sample_size, logging_output)``. ``sample_size`` is the
            number of rows in ``sample["query_tokens"]``.
        """
        query_lprobs = self.get_lprobs(
            model,
            sample["query_tokens"],
            sample["query_masks"],
        )
        cand_lprobs = self.get_lprobs(
            model,
            sample["candidate_tokens"],
            sample["candidate_masks"],
        )
        # Row-wise: correct when the query scores at least as high as its candidate.
        pred = query_lprobs >= cand_lprobs
        loss = self.get_loss(query_lprobs, cand_lprobs)

        sample_size = sample["query_tokens"].size(0)
        ncorrect = pred.sum().item()
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "ncorrect": ncorrect,
            "nqueries": sample_size,
        }
        return loss, sample_size, logging_output