Spaces:
Runtime error
Runtime error
# Copyright 2022 The OFA-Sys Team. | |
# All rights reserved. | |
# This source code is licensed under the Apache 2.0 license | |
# found in the LICENSE file in the root directory. | |
from dataclasses import dataclass, field | |
import logging | |
import os | |
import math | |
import torch | |
from typing import Dict, Optional | |
from fairseq import search | |
from fairseq.data import FairseqDataset, iterators | |
from fairseq.optim.amp_optimizer import AMPOptimizer | |
from fairseq.dataclass import FairseqDataclass | |
from fairseq.tasks import FairseqTask, register_task | |
from omegaconf import DictConfig | |
# Logger for this module; named after the module so its records can be
# filtered/configured via the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
@dataclass
class OFAConfig(FairseqDataclass):
    """Configuration options shared by OFA tasks.

    Every option is declared with ``dataclasses.field``; the ``help`` entry of
    its metadata describes the option. The class must carry ``@dataclass`` so
    these ``field(...)`` declarations become real dataclass fields with the
    given defaults, instead of plain class attributes holding ``Field`` objects.
    """

    data: Optional[str] = field(
        default=None,
        metadata={
            "help": "comma separated path to data list, will be iterated upon during epochs "
            "in round-robin manner; valid data are always in the last"
        },
    )
    # NOTE(review): presumably a comma-separated list of column indices read by
    # the dataset reader -- confirm against the dataset code.
    selected_cols: Optional[str] = field(
        default=None,
        metadata={"help": "selected cols"},
    )
    # OFATask.build_model builds a BERT word-piece encoder when this is 'bert'
    # and a GPT-2 BPE encoder for any other value.
    bpe: Optional[str] = field(
        default='gpt2',
        metadata={"help": "which bpe to use"},
    )
    # Directory holding dict.txt plus the BPE files (vocab.txt for BERT,
    # encoder.json / vocab.bpe for GPT-2).
    bpe_dir: Optional[str] = field(
        default=None,
        metadata={"help": "bpe dir"},
    )
    # This pair is what OFATask.max_positions() reports.
    max_source_positions: int = field(
        default=1024, metadata={"help": "max number of tokens in the source sequence"}
    )
    max_target_positions: int = field(
        default=1024, metadata={"help": "max number of tokens in the target sequence"}
    )
    max_src_length: int = field(
        default=128, metadata={"help": "the maximum src sequence length"}
    )
    max_tgt_length: int = field(
        default=30, metadata={"help": "the maximum target sequence length"}
    )
    # Number of "<code_i>" tokens appended to both dictionaries in setup_task.
    code_dict_size: int = field(
        default=8192, metadata={"help": "code dict size"}
    )
    patch_image_size: int = field(
        default=480, metadata={"help": "patch image size"}
    )
    # NOTE(review): the help text duplicates patch_image_size's; what "orig"
    # refers to is not visible in this file -- the help string should say.
    orig_patch_image_size: int = field(
        default=256, metadata={"help": "patch image size"}
    )
    # Number of "<bin_i>" quantization tokens appended to both dictionaries.
    num_bins: int = field(
        default=1000, metadata={"help": "number of quantization bins"}
    )
    # NOTE(review): presumably switches image normalization to the ImageNet
    # mean/std in the data pipeline -- confirm.
    imagenet_default_mean_and_std: bool = field(
        default=False,
        metadata={"help": "imagenet normalize"},
    )
    # Forwarded to the OFA sequence generator as ``constraint_range``.
    # NOTE(review): the expected string format is defined by that generator.
    constraint_range: Optional[str] = field(
        default=None,
        metadata={"help": "constraint range"}
    )
class OFATask(FairseqTask):
    """fairseq task holding the OFA source/target dictionaries.

    Dictionaries are loaded from ``<cfg.bpe_dir>/dict.txt`` and extended with
    ``<mask>``, image-code tokens and quantization-bin tokens. The task also
    builds the matching BPE encoder (``self.bpe``) alongside the model, and
    provides custom batching and generation helpers. No ``load_dataset`` is
    defined in this class; presumably subclasses supply it.
    """

    def __init__(self, cfg: OFAConfig, src_dict, tgt_dict):
        super().__init__(cfg)
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

    @classmethod
    def setup_task(cls, cfg: DictConfig, **kwargs):
        """Load and extend the dictionaries, then construct the task.

        Both dictionaries are read from the same ``<cfg.bpe_dir>/dict.txt``.
        They are extended with the same symbols in the same order, so their
        token ids line up:

        - ``<mask>``;
        - ``cfg.code_dict_size`` image-code tokens ``<code_i>``;
        - ``cfg.num_bins`` quantization tokens ``<bin_i>``.
        """
        dict_path = os.path.join(cfg.bpe_dir, "dict.txt")
        src_dict = cls.load_dictionary(dict_path)
        tgt_dict = cls.load_dictionary(dict_path)

        extra_symbols = (
            ["<mask>"]
            + ["<code_{}>".format(i) for i in range(cfg.code_dict_size)]
            # quantization bins
            + ["<bin_{}>".format(i) for i in range(cfg.num_bins)]
        )
        for symbol in extra_symbols:
            src_dict.add_symbol(symbol)
            tgt_dict.add_symbol(symbol)

        logger.info("source dictionary: {} types".format(len(src_dict)))
        logger.info("target dictionary: {} types".format(len(tgt_dict)))
        return cls(cfg, src_dict, tgt_dict)

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):
        """Return an epoch iterator of fixed-size, consecutive-index batches.

        Batches are consecutive index ranges of ``max_sentences`` examples.
        The following arguments are accepted for interface compatibility but
        ignored: ``max_tokens``, ``max_positions``, ``ignore_invalid_inputs``,
        ``required_batch_size_multiple``, ``shard_id`` and
        ``disable_iterator_cache``.

        The iterator itself is built with a single shard. ``num_shards`` is
        only used to compute how many batches every shard should yield, based
        on ``dataset.dataset.get_total_row_count()``. When this dataset falls
        short, it is padded with empty batches so all workers run the same
        number of steps.

        NOTE(review): this assumes the underlying reader already yields only
        this worker's shard and reports the global row count -- confirm.

        Raises:
            ValueError: if ``max_sentences`` is None or not positive.
        """
        assert isinstance(dataset, FairseqDataset)
        if max_sentences is None or max_sentences <= 0:
            raise ValueError(
                "get_batch_iterator requires a positive max_sentences (batch size), "
                "got {!r}".format(max_sentences)
            )

        # initialize the dataset with the correct starting epoch
        dataset.set_epoch(epoch)

        # create mini-batches of consecutive indices
        num_examples = len(dataset)
        batch_sampler = [
            list(range(start, min(start + max_sentences, num_examples)))
            for start in range(0, num_examples, max_sentences)
        ]

        # Batches each shard must produce: ceil(ceil(total / shards) / bsz),
        # computed with exact integer arithmetic.
        total_row_count = dataset.dataset.get_total_row_count()
        rows_per_shard = (total_row_count + num_shards - 1) // num_shards
        num_batches = (rows_per_shard + max_sentences - 1) // max_sentences
        # Pad with empty batches so every worker steps the same number of times.
        while len(batch_sampler) < num_batches:
            batch_sampler.append([])

        # return a reusable iterator; sharding is handled upstream (see above)
        epoch_iter = iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=1,
            shard_id=0,
            num_workers=num_workers,
            epoch=epoch,
            buffer_size=data_buffer_size
        )
        return epoch_iter

    def build_model(self, cfg: FairseqDataclass):
        """Build the model and, as a side effect, the BPE encoder ``self.bpe``.

        When ``self.cfg.bpe == 'bert'``, an uncased BERT encoder is built from
        ``vocab.txt``. Any other value builds a GPT-2 BPE encoder from
        ``encoder.json`` and ``vocab.bpe``. All files are read from
        ``self.cfg.bpe_dir``.
        """
        model = super().build_model(cfg)
        bpe_dir = self.cfg.bpe_dir
        if self.cfg.bpe == 'bert':
            bpe_cfg = {
                "_name": "bert",
                "bpe_vocab_file": os.path.join(bpe_dir, "vocab.txt"),
                "bpe_cased": False,
            }
        else:
            bpe_cfg = {
                "_name": "gpt2",
                "gpt2_encoder_json": os.path.join(bpe_dir, "encoder.json"),
                "gpt2_vocab_bpe": os.path.join(bpe_dir, "vocab.bpe"),
            }
        self.bpe = self.build_bpe(DictConfig(bpe_cfg))
        return model

    def _build_search_strategy(self, args, prefix_allowed_tokens_fn=None):
        """Select the fairseq search strategy requested by ``args`` (beam search by default).

        Raises:
            ValueError: if more than one of the options sampling,
                diverse_beam_groups > 0, match_source_len and
                diversity_rate > 0 is set.
        """
        sampling = getattr(args, "sampling", False)
        sampling_topk = getattr(args, "sampling_topk", -1)
        sampling_topp = getattr(args, "sampling_topp", -1.0)
        diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
        diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
        match_source_len = getattr(args, "match_source_len", False)
        diversity_rate = getattr(args, "diversity_rate", -1)
        constrained = getattr(args, "constraints", False)
        if prefix_allowed_tokens_fn is None:
            prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)

        exclusive_options = [
            sampling,
            diverse_beam_groups > 0,
            match_source_len,
            diversity_rate > 0,
        ]
        if sum(int(cond) for cond in exclusive_options) > 1:
            raise ValueError("Provided Search parameters are mutually exclusive.")
        assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
        assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

        tgt_dict = self.target_dictionary
        if sampling:
            return search.Sampling(tgt_dict, sampling_topk, sampling_topp)
        if diverse_beam_groups > 0:
            return search.DiverseBeamSearch(
                tgt_dict, diverse_beam_groups, diverse_beam_strength
            )
        if match_source_len:
            # Useful for tagging applications where the output length should
            # match the input length, so the length constraints are hardcoded.
            return search.LengthConstrainedBeamSearch(
                tgt_dict,
                min_len_a=1,
                min_len_b=0,
                max_len_a=1,
                max_len_b=0,
            )
        # -1 means "disabled"; any larger value (including 0) selects this search.
        if diversity_rate > -1:
            return search.DiverseSiblingsSearch(tgt_dict, diversity_rate)
        if constrained:
            return search.LexicallyConstrainedBeamSearch(tgt_dict, args.constraints)
        if prefix_allowed_tokens_fn:
            return search.PrefixConstrainedBeamSearch(tgt_dict, prefix_allowed_tokens_fn)
        return search.BeamSearch(tgt_dict)

    def build_generator(
        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None,
    ):
        """
        Build a :class:`~fairseq.SequenceGenerator` instance for this
        task.

        By default, OFA's ``models.sequence_generator.SequenceGenerator`` is
        used. It receives ``constraint_range`` from the task config.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            args (fairseq.dataclass.configs.GenerationConfig):
                configuration object (dataclass) for generation
            seq_gen_cls (type, optional): generator class to use instead of
                the default. It receives ``constraint_range`` (unless the
                caller already put one in ``extra_gen_cls_kwargs``), except
                for fairseq's ``SequenceGeneratorWithAlignment``, which
                forwards its kwargs to fairseq's own ``SequenceGenerator``
                and that class has no ``constraint_range`` parameter.
            extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass
                through to SequenceGenerator (copied, never mutated)
            prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]):
                If provided, this function constrains the beam search to
                allowed tokens only at each step. The provided function
                should take 2 arguments: the batch ID (`batch_id: int`)
                and a unidimensional tensor of token ids (`inputs_ids:
                torch.Tensor`). It has to return a `List[int]` with the
                allowed tokens for the next generation step conditioned
                on the previously generated tokens (`inputs_ids`) and
                the batch ID (`batch_id`). This argument is useful for
                constrained generation conditioned on the prefix, as
                described in "Autoregressive Entity Retrieval"
                (https://arxiv.org/abs/2010.00904) and
                https://github.com/facebookresearch/GENRE.

        Returns:
            A ``SequenceScorer`` when ``args.score_reference`` is set,
            otherwise an instance of the chosen generator class.
        """
        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer

            return SequenceScorer(
                self.target_dictionary,
                compute_alignment=getattr(args, "print_alignment", False),
            )

        from fairseq.sequence_generator import SequenceGeneratorWithAlignment
        from models.sequence_generator import SequenceGenerator

        search_strategy = self._build_search_strategy(args, prefix_allowed_tokens_fn)

        # Work on a copy so the caller's dict is never mutated.
        gen_kwargs = dict(extra_gen_cls_kwargs or {})
        if seq_gen_cls is None:
            if getattr(args, "print_alignment", False):
                seq_gen_cls = SequenceGeneratorWithAlignment
                gen_kwargs["print_alignment"] = args.print_alignment
            else:
                seq_gen_cls = SequenceGenerator
        if seq_gen_cls is not SequenceGeneratorWithAlignment:
            gen_kwargs.setdefault("constraint_range", self.cfg.constraint_range)

        return seq_gen_cls(
            models,
            self.target_dictionary,
            beam_size=getattr(args, "beam", 5),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            search_strategy=search_strategy,
            **gen_kwargs,
        )

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False, **extra_kwargs
    ):
        """
        Do forward and backward, and return the loss as computed by *criterion*
        for the given *model* and *sample*.

        The forward pass runs under CUDA autocast only when *optimizer* is an
        ``AMPOptimizer``. The criterion also receives ``update_num``.
        ``extra_kwargs`` is accepted and ignored.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~fairseq.data.FairseqDataset`.
            model (~fairseq.models.BaseFairseqModel): the model
            criterion (~fairseq.criterions.FairseqCriterion): the criterion
            optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
            update_num (int): the current update
            ignore_grad (bool): multiply loss by 0 if this is set to True

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        model.train()
        model.set_num_updates(update_num)
        with torch.autograd.profiler.record_function("forward"):
            with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))):
                loss, sample_size, logging_output = criterion(model, sample, update_num=update_num)
        if ignore_grad:
            loss *= 0
        with torch.autograd.profiler.record_function("backward"):
            optimizer.backward(loss)
        return loss, sample_size, logging_output

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.cfg.max_source_positions, self.cfg.max_target_positions)

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.src_dict

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary`."""
        return self.tgt_dict