# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq import options


def get_reranking_parser(default_task="translation"):
    parser = options.get_parser("Generation and reranking", default_task)
    add_reranking_args(parser)
    return parser


def get_tuning_parser(default_task="translation"):
    parser = options.get_parser("Reranking tuning", default_task)
    add_reranking_args(parser)
    add_tuning_args(parser)
    return parser
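
# A minimal usage sketch (hedged; nothing below is defined in this module):
# the returned parser is a regular fairseq parser, so a driver script would
# typically finish parsing with fairseq.options.parse_args_and_arch and hand
# the result to its own entry point (`rerank` here is hypothetical):
#
#     from fairseq import options
#     from examples.noisychannel import rerank_options
#
#     parser = rerank_options.get_reranking_parser()
#     args = options.parse_args_and_arch(parser)
#     rerank(args)  # hypothetical driver consuming the parsed namespace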


def add_reranking_args(parser):
    group = parser.add_argument_group("Reranking")
    # fmt: off
    group.add_argument('--score-model1', '-s1', type=str, metavar='FILE', required=True,
                       help='path to first model or ensemble of models for rescoring')
    group.add_argument('--score-model2', '-s2', type=str, metavar='FILE', required=False,
                       help='path to second model or ensemble of models for rescoring')
    group.add_argument('--num-rescore', '-n', type=int, metavar='N', default=10,
                       help='the number of candidate hypotheses to rescore')
    group.add_argument('-bz', '--batch-size', type=int, metavar='N', default=128,
                       help='batch size for generating the nbest list')
    group.add_argument('--gen-subset', default='test', metavar='SET', choices=['test', 'train', 'valid'],
                       help='data subset to generate (train, valid, test)')
    group.add_argument('--gen-model', default=None, metavar='FILE',
                       help='the model to generate translations')
    group.add_argument('-b1', '--backwards1', action='store_true',
                       help='whether or not the first model group is backwards')
    group.add_argument('-b2', '--backwards2', action='store_true',
                       help='whether or not the second model group is backwards')
    group.add_argument('-a', '--weight1', default=1, nargs='+', type=float,
                       help='the weight(s) of the first model')
    group.add_argument('-b', '--weight2', default=1, nargs='+', type=float,
                       help='the weight(s) of the second model, or the gen model if using nbest from interactive.py')
    group.add_argument('-c', '--weight3', default=1, nargs='+', type=float,
                       help='the weight(s) of the third model')

    # lm arguments
    group.add_argument('-lm', '--language-model', default=None, metavar='FILE',
                       help='language model for target language to rescore translations')
    group.add_argument('--lm-dict', default=None, metavar='FILE',
                       help='the dict of the language model for the target language')
    group.add_argument('--lm-name', default=None,
                       help='the name of the language model for the target language')
    group.add_argument('--lm-bpe-code', default=None, metavar='FILE',
                       help='the bpe code for the language model for the target language')
    group.add_argument('--data-dir-name', default=None,
                       help='name of data directory')
    group.add_argument('--lenpen', default=1, nargs='+', type=float,
                       help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
    group.add_argument('--score-dict-dir', default=None,
                       help='the directory with dictionaries for the scoring models')
    group.add_argument('--right-to-left1', action='store_true',
                       help='whether the first model group is a right-to-left model')
    group.add_argument('--right-to-left2', action='store_true',
                       help='whether the second model group is a right-to-left model')
    group.add_argument('--post-process', '--remove-bpe', default='@@ ',
                       help='the bpe symbol, used for the bitext and LM')
    group.add_argument('--prefix-len', default=None, type=int,
                       help='the length of the target prefix to use in rescoring (in terms of words without BPE)')
    group.add_argument('--sampling', action='store_true',
                       help='use sampling instead of beam search for generating the nbest list')
    group.add_argument('--diff-bpe', action='store_true',
                       help='the BPE for the rescoring models and the nbest list are not the same')
    group.add_argument('--rescore-bpe-code', default=None,
                       help='bpe code for rescoring models')
    group.add_argument('--nbest-list', default=None,
                       help='use predefined nbest list in interactive.py format')
    group.add_argument('--write-hypos', default=None,
                       help='filename prefix to write hypos to')
    group.add_argument('--ref-translation', default=None,
                       help='reference translation to use with nbest list from interactive.py')
    group.add_argument('--backwards-score-dict-dir', default=None,
                       help='the directory with dictionaries for the backwards model; '
                            'if None, the forward and backwards models are assumed to share dictionaries')

    # extra scaling args
    group.add_argument('--gen-model-name', default=None,
                       help='the name of the models that generated the nbest list')
    group.add_argument('--model1-name', default=None,
                       help='the name of the set for the model1 group')
    group.add_argument('--model2-name', default=None,
                       help='the name of the set for the model2 group')
    group.add_argument('--shard-id', default=0, type=int,
                       help='the id of the shard to generate')
    group.add_argument('--num-shards', default=1, type=int,
                       help='the number of shards to generate across')
    group.add_argument('--all-shards', action='store_true',
                       help='use all shards')
    group.add_argument('--target-prefix-frac', default=None, type=float,
                       help='the fraction of the target prefix to use in rescoring (in terms of words without BPE)')
    group.add_argument('--source-prefix-frac', default=None, type=float,
                       help='the fraction of the source prefix to use in rescoring (in terms of words without BPE)')
    group.add_argument('--normalize', action='store_true',
                       help='whether to normalize the scores by source and target length')
    # fmt: on
    return group
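
# Example invocation sketch (hypothetical script name and paths; the flags are
# the ones defined above, and the leading positional is the data directory
# that fairseq translation tasks add during argument parsing):
#
#     python rerank.py data-bin/wmt \
#         --score-model1 forward_model.pt \
#         --score-model2 backward_model.pt --backwards2 \
#         --language-model lm.pt --lm-dict lm_dict.txt \
#         --num-rescore 50 --lenpen 0.5 \
#         --write-hypos rescored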


def add_tuning_args(parser):
    group = parser.add_argument_group("Tuning")

    group.add_argument(
        "--lower-bound",
        default=[-0.7],
        nargs="+",
        type=float,
        help="lower bound of search space",
    )
    group.add_argument(
        "--upper-bound",
        default=[3],
        nargs="+",
        type=float,
        help="upper bound of search space",
    )
    group.add_argument(
        "--tune-param",
        default=["lenpen"],
        nargs="+",
        choices=["lenpen", "weight1", "weight2", "weight3"],
        help="the parameter(s) to tune",
    )
    group.add_argument(
        "--tune-subset",
        default="valid",
        choices=["valid", "test", "train"],
        help="the subset to tune on ",
    )
    group.add_argument(
        "--num-trials",
        default=1000,
        type=int,
        help="number of trials to do for random search",
    )
    group.add_argument(
        "--share-weights", action="store_true", help="share weight2 and weight3"
    )
    return group
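
# Example tuning invocation sketch (hypothetical script name and paths; with
# several --tune-param values, --lower-bound and --upper-bound presumably
# supply one bound per tuned parameter, matched by position):
#
#     python rerank_tune.py data-bin/wmt \
#         --score-model1 forward_model.pt \
#         --tune-param lenpen weight1 \
#         --lower-bound -0.7 0.0 --upper-bound 3.0 2.0 \
#         --tune-subset valid --num-trials 1000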