File size: 869 Bytes
6fd648a
e3c7b5a
 
6fd648a
 
 
 
 
 
64a6414
 
6fd648a
e3c7b5a
6fd648a
e3c7b5a
 
6fd648a
e3c7b5a
6fd648a
 
64a6414
6fd648a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import argparse
from idiomify.models import Idiomifier, Pipeline
from idiomify.fetchers import fetch_config, fetch_idiomifier
from transformers import BartTokenizer


def main():
    """Idiomify one sentence given on the command line and print the result.

    The sentence comes from the ``--src`` option (a sample sentence is used
    when the option is omitted). Parsed options are merged over the
    ``'infer'`` section of the fetched config, whose ``'ver'`` and ``'bart'``
    entries select the idiomifier to fetch and the tokenizer to load.
    The source sentence and the pipeline output are printed to stdout.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--src",
        type=str,
        default=(
            "If there's any good to loosing my job,"
            " it's that I'll now be able to go to school full-time and finish my degree earlier."
        ),
    )
    # Parse first so that a bad command line exits before anything is fetched.
    options = vars(cli.parse_args())

    # Command-line values take precedence over the fetched 'infer' settings.
    config = fetch_config()['infer']
    config.update(options)

    model = fetch_idiomifier(config['ver'])
    # The original author flags this call as crucial; presumably it switches
    # off training-only behaviour before inference — confirm against the model.
    model.eval()

    tokenizer = BartTokenizer.from_pretrained(config['bart'])
    pipeline = Pipeline(model, tokenizer)

    sentence = config['src']
    idiomified = pipeline(src=sentence)
    print(sentence, "\n->", idiomified)


# Run the inference demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()