File size: 3,841 Bytes
03f6091
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a005919
 
 
 
03f6091
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# -*- coding: utf-8 -*-
import os

import click
import pandas as pd
import yaml

from torchnlp.download import download_file_maybe_extract

from .estimators import PolosEstimator, QualityEstimator
from .model_base import ModelBase
from .ranking import PolosRanker

# Registry mapping the model name stored in a checkpoint's hyper-parameters
# to the class used to restore that checkpoint.
str2model = dict(
    PolosEstimator=PolosEstimator,
    PolosRanker=PolosRanker,
    # Models that use the source only:
    QualityEstimator=QualityEstimator,
)

def get_cache_folder() -> str:
    """Return the default cache directory, creating it when it is missing.

    Return:
        - The relative path ``"./.cache/"`` (relative to the current working
          directory at call time).
    """
    cache_directory = "./.cache/"
    # exist_ok=True replaces the exists()/makedirs() pair, which could fail
    # if another process created the folder between the two calls.
    os.makedirs(cache_directory, exist_ok=True)
    return cache_directory

def download_model(model: str, saving_directory: str = None) -> str:
    """Function that downloads a pretrained model from AWS (when it is not
    cached yet) and returns the path of its checkpoint file.

    :param model: Name of the model to be downloaded. It must either be a key
        of the internal ``models`` table (currently only ``"polos"``) or the
        name of a folder that already exists inside ``saving_directory``.
    :param saving_directory: RELATIVE path to the saving folder. A trailing
        ``/`` is appended when missing. Defaults to ``get_cache_folder()``.

    Return:
        - Path (str) to a ``.ckpt`` file inside ``<saving_directory>reprod/``.
          Pass it to ``load_checkpoint`` to get the model itself.

    Raises:
        Exception: if ``model`` is neither cached nor a known model name, or
            if no ``.ckpt`` file can be found once the download finished.
    """
    if saving_directory is None:
        saving_directory = get_cache_folder()
    elif saving_directory and not saving_directory.endswith("/"):
        # Every path below is built by plain concatenation, so a missing
        # separator would silently point somewhere else.
        saving_directory += "/"

    os.makedirs(saving_directory, exist_ok=True)

    experiment_folder = saving_directory + "reprod"

    # Fast path: the expected checkpoint is already on disk. Note that this
    # check does not depend on ``model``.
    default_checkpoint = experiment_folder + "/reprod.ckpt"
    if os.path.exists(default_checkpoint):
        return default_checkpoint

    models = {"polos": "https://polos-polaris.s3.ap-northeast-1.amazonaws.com/reprod.zip"}

    if os.path.isdir(saving_directory + model):
        click.secho(f"{model} is already in cache.", fg="yellow")

    elif model not in models:
        raise Exception(f"{model} is not a valid Polos model!")

    elif models[model].startswith("https://"):
        download_file_maybe_extract(models[model], directory=saving_directory)

    else:
        raise Exception("Something went wrong while dowloading the model!")

    # The archive is stored under the file name taken from the URL (e.g.
    # "reprod.zip"), not under "<model>.zip", so derive the name from the URL
    # to actually clean it up after extraction.
    archive_url = models.get(model)
    if archive_url is not None:
        archive_path = saving_directory + archive_url.rsplit("/", 1)[-1]
        if os.path.exists(archive_path):
            os.remove(archive_path)

    click.secho("Download succeeded. Loading model...", fg="yellow")

    # os.listdir() returns entries in arbitrary order; sort so that the chosen
    # checkpoint is the same on every run and every file system.
    checkpoints = sorted(
        file for file in os.listdir(experiment_folder) if file.endswith(".ckpt")
    )
    if not checkpoints:
        raise Exception(f"No .ckpt file found in {experiment_folder}!")
    return experiment_folder + "/" + checkpoints[-1]


def load_checkpoint(checkpoint: str) -> "ModelBase":
    """Function that loads a model from a checkpoint file.

    The hyper-parameters are read from a file located in the same folder as
    the checkpoint: ``meta_tags.csv`` when present, otherwise
    ``hparams.yaml``. Their ``"model"`` entry selects the class (through
    ``str2model``) used to restore the checkpoint.

    :param checkpoint: Path to the checkpoint file.

    Returns:
        - Polos Model, switched to eval mode and frozen.

    Raises:
        Exception: if the checkpoint file does not exist, or if neither
            ``meta_tags.csv`` nor ``hparams.yaml`` sits next to it.
        KeyError: if the ``"model"`` entry is absent or is not a key of
            ``str2model``.
    """
    if not os.path.exists(checkpoint):
        raise Exception(f"{checkpoint} file not found!")

    # os.path handles both "/" and the platform separator, and yields a bare
    # file name when the checkpoint path has no folder component.
    checkpoint_folder = os.path.dirname(checkpoint)
    tags_csv_file = os.path.join(checkpoint_folder, "meta_tags.csv")
    hparam_yaml_file = os.path.join(checkpoint_folder, "hparams.yaml")

    if os.path.exists(tags_csv_file):
        # Ugly conversion from older Lightning checkpoints.
        # read_csv(squeeze=True) was removed in pandas 2.0; squeezing the
        # column axis afterwards gives the same key -> value Series.
        tags = (
            pd.read_csv(tags_csv_file, header=None, index_col=0)
            .squeeze("columns")
            .to_dict()
        )
        hparams = {}
        for k, v in tags.items():
            # Turn strings made of digits with at most one dot back into
            # numbers; every other value is kept untouched.
            if isinstance(v, str) and v.replace(".", "", 1).isdigit():
                hparams[k] = float(v) if "." in v else int(v)
            else:
                hparams[k] = v
        model = str2model[tags["model"]].load_from_checkpoint(
            checkpoint, hparams=hparams
        )
    elif os.path.exists(hparam_yaml_file):
        with open(hparam_yaml_file, encoding="utf-8") as yaml_file:
            # NOTE(review): FullLoader can still build some Python objects.
            # If hparams.yaml may come from an untrusted download, consider
            # yaml.safe_load once confirmed the files only hold plain types.
            hparams = yaml.load(yaml_file.read(), Loader=yaml.FullLoader)
        model = str2model[hparams["model"]].load_from_checkpoint(
            checkpoint, hparams=hparams
        )
    else:
        raise Exception(
            "[meta_tags.csv|hparams.yaml] is missing from the checkpoint folder"
            f" ({checkpoint_folder or '.'})."
            " Please clean your cache folder and try to download the model again."
        )

    model.eval()
    model.freeze()
    return model