Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import unittest | |
from argparse import ArgumentParser | |
from dataclasses import dataclass, field | |
from fairseq.dataclass import FairseqDataclass | |
from fairseq.dataclass.utils import gen_parser_from_dataclass | |
@dataclass
class A(FairseqDataclass):
    """Flat config: a positional ``data`` argument plus a ``--num-layers`` option.

    The ``@dataclass`` decorator is required. Without it the annotated
    ``field(...)`` declarations below are never registered as dataclass
    fields; they remain plain class attributes holding ``Field`` objects.
    """

    data: str = field(default="test", metadata={"help": "the data input"})
    num_layers: int = field(default=200, metadata={"help": "more layers is better?"})
@dataclass
class B(FairseqDataclass):
    """Config that nests an ``A`` under ``bar`` alongside its own ``foo`` option.

    When no prefix is requested, the nested ``A`` options (e.g.
    ``--num-layers``) appear un-prefixed on the root parser; see
    ``test_argparse_recursive``.
    """

    # default_factory instead of default=A(): a mutable dataclass instance is
    # unhashable and is rejected as a plain ``default`` on Python 3.11+, and a
    # single shared A() would otherwise be aliased by every B instance.
    bar: A = field(default_factory=A)
    foo: int = field(default=0, metadata={"help": "not a bar"})
@dataclass
class D(FairseqDataclass):
    """Like ``B``, but the nested ``A`` is named ``arch``.

    It is used as ``C.encoder`` to produce two-level option names such as
    ``--encoder-arch-num-layers``.
    """

    # default_factory avoids a shared, unhashable A() default (rejected on 3.11+).
    arch: A = field(default_factory=A)
    foo: int = field(default=0, metadata={"help": "not a bar"})
@dataclass
class C(FairseqDataclass):
    """Root config used to exercise prefixed option names.

    With an empty prefix, the nested fields become options such as
    ``--encoder-arch-data``, ``--encoder-foo`` and ``--decoder-num-layers``.
    The root-level ``data`` stays positional; see
    ``test_argparse_recursive_prefixing``.
    """

    data: str = field(default="test", metadata={"help": "root level data input"})
    # default_factory avoids shared, unhashable dataclass-instance defaults,
    # which Python 3.11+ rejects when passed as a plain ``default``.
    encoder: D = field(default_factory=D)
    decoder: A = field(default_factory=A)
    lr: int = field(default=0, metadata={"help": "learning rate"})
class TestDataclassUtils(unittest.TestCase):
    """Checks that gen_parser_from_dataclass builds usable argparse parsers."""

    @staticmethod
    def _build_parser(config, *prefix):
        """Return a fresh ArgumentParser populated from ``config``.

        ``prefix`` is forwarded verbatim as the optional fourth positional
        argument of gen_parser_from_dataclass (omitted when empty).
        """
        built = ArgumentParser()
        gen_parser_from_dataclass(built, config, True, *prefix)
        return built

    def test_argparse_convert_basic(self):
        parsed = self._build_parser(A()).parse_args(
            ["--num-layers", "10", "the/data/path"]
        )
        self.assertEqual(parsed.num_layers, 10)
        self.assertEqual(parsed.data, "the/data/path")

    def test_argparse_recursive(self):
        parsed = self._build_parser(B()).parse_args(
            ["--num-layers", "10", "--foo", "10", "the/data/path"]
        )
        self.assertEqual(parsed.num_layers, 10)
        self.assertEqual(parsed.foo, 10)
        self.assertEqual(parsed.data, "the/data/path")

    def test_argparse_recursive_prefixing(self):
        self.maxDiff = None
        # (option, value) pairs, flattened into argv in this exact order.
        option_pairs = [
            ("--encoder-arch-data", "ENCODER_ARCH_DATA"),
            ("--encoder-arch-num-layers", "10"),
            ("--encoder-foo", "10"),
            ("--decoder-data", "DECODER_DATA"),
            ("--decoder-num-layers", "10"),
            ("--lr", "10"),
        ]
        argv = [token for pair in option_pairs for token in pair]
        argv.append("the/data/path")
        parsed = self._build_parser(C(), "").parse_args(argv)

        # Checked in insertion order, matching one assertEqual per attribute.
        expected = {
            "encoder_arch_data": "ENCODER_ARCH_DATA",
            "encoder_arch_num_layers": 10,
            "encoder_foo": 10,
            "decoder_data": "DECODER_DATA",
            "decoder_num_layers": 10,
            "lr": 10,
            "data": "the/data/path",
        }
        for attr_name, wanted in expected.items():
            self.assertEqual(getattr(parsed, attr_name), wanted)
if __name__ == "__main__":
    # Run this module's TestCase classes when executed as a script.
    unittest.main()