Polos-Demo / polos /models /__init__.py
yuwd's picture
update
a005919
raw
history blame
No virus
3.84 kB
# -*- coding: utf-8 -*-
import os
import click
import pandas as pd
import yaml
from torchnlp.download import download_file_maybe_extract
from .estimators import PolosEstimator, QualityEstimator
from .model_base import ModelBase
from .ranking import PolosRanker
# Registry of model classes, keyed by the model name that load_checkpoint
# reads from a checkpoint's hyper-parameters.
str2model = dict(
    PolosEstimator=PolosEstimator,
    PolosRanker=PolosRanker,
    # Model that uses the source only:
    QualityEstimator=QualityEstimator,
)
def get_cache_folder() -> str:
    """Return the default cache folder, creating it when it does not exist.

    Return:
        - The relative path "./.cache/" (resolved against the current
          working directory), including the trailing slash.
    """
    cache_directory = "./.cache/"
    # exist_ok avoids the check-then-create race when two processes start
    # at the same time and both find the folder missing.
    os.makedirs(cache_directory, exist_ok=True)
    return cache_directory
def download_model(model: str, saving_directory: str = None) -> str:
    """Function that fetches a pretrained model from AWS and locates its checkpoint.

    :param model: Name of the model to be loaded.
    :param saving_directory: RELATIVE path to the saving folder. A trailing "/"
        is appended when it is missing. Defaults to the folder returned by
        get_cache_folder().

    Return:
        - Path to the checkpoint file (a string, not a model instance); pass
          it to load_checkpoint to build the model.

    Raises:
        Exception: if `model` is neither cached nor a known model name, or if
            no ".ckpt" file is found in the extracted folder.
    """
    if saving_directory is None:
        saving_directory = get_cache_folder()

    # Every path below is built by string concatenation, so make sure the
    # folder ends with a separator even when the caller left it out.
    if saving_directory and not saving_directory.endswith("/"):
        saving_directory += "/"
    os.makedirs(saving_directory, exist_ok=True)

    experiment_folder = saving_directory + "reprod"

    # Fast path: the expected checkpoint is already on disk.
    cached_checkpoint = experiment_folder + "/reprod.ckpt"
    if os.path.exists(cached_checkpoint):
        return cached_checkpoint

    models = {"polos": "https://polos-polaris.s3.ap-northeast-1.amazonaws.com/reprod.zip"}

    if os.path.isdir(saving_directory + model):
        click.secho(f"{model} is already in cache.", fg="yellow")
    elif model not in models:
        raise Exception(f"{model} is not a valid Polos model!")
    elif models[model].startswith("https://"):
        download_file_maybe_extract(models[model], directory=saving_directory)
    else:
        raise Exception("Something went wrong while dowloading the model!")

    # Remove leftover archives. The download is named after the URL
    # ("reprod.zip"), not after the model key, so both candidates are checked.
    leftover_archives = [saving_directory + model.rstrip("/") + ".zip"]
    if model in models:
        leftover_archives.append(saving_directory + os.path.basename(models[model]))
    for archive in leftover_archives:
        if os.path.isfile(archive):
            os.remove(archive)

    click.secho("Download succeeded. Loading model...", fg="yellow")

    # os.listdir returns entries in arbitrary order; sort so that the chosen
    # checkpoint is the same on every run.
    checkpoints = sorted(
        file for file in os.listdir(experiment_folder) if file.endswith(".ckpt")
    )
    if not checkpoints:
        raise Exception(f"No .ckpt file found in {experiment_folder}!")
    return experiment_folder + "/" + checkpoints[-1]
def load_checkpoint(checkpoint: str) -> "ModelBase":
    """Function that loads a model from a checkpoint file.

    The hyper-parameters are read from a `meta_tags.csv` file (older
    Lightning layout) or, failing that, an `hparams.yaml` file located in the
    same folder as the checkpoint. Their "model" entry selects the class in
    `str2model` used to restore the checkpoint.

    :param checkpoint: Path to the checkpoint file.

    Returns:
        - Polos Model, after `eval()` and `freeze()` have been called on it.

    Raises:
        Exception: if the checkpoint file does not exist, or if neither
            hyper-parameter file is present in the checkpoint folder.
    """
    if not os.path.exists(checkpoint):
        raise Exception(f"{checkpoint} file not found!")

    checkpoint_folder = os.path.dirname(checkpoint)
    tags_csv_file = os.path.join(checkpoint_folder, "meta_tags.csv")
    hparam_yaml_file = os.path.join(checkpoint_folder, "hparams.yaml")

    if os.path.exists(tags_csv_file):
        # Ugly conversion from older Lightning checkpoints.
        # The `squeeze` keyword of read_csv was removed in pandas 2.0, so the
        # single value column is squeezed into a Series explicitly.
        tags = (
            pd.read_csv(tags_csv_file, header=None, index_col=0)
            .squeeze("columns")
            .to_dict()
        )
        hparams = {}
        for key, value in tags.items():
            # Plain non-negative numbers stored as text become int/float;
            # everything else is passed through untouched.
            if isinstance(value, str) and value.replace(".", "", 1).isdigit():
                value = float(value) if "." in value else int(value)
            hparams[key] = value
        model_name = tags["model"]
    elif os.path.exists(hparam_yaml_file):
        with open(hparam_yaml_file) as yaml_file:
            # NOTE(review): FullLoader is not meant for untrusted input; only
            # load hparams.yaml files that come from a trusted checkpoint.
            hparams = yaml.load(yaml_file.read(), Loader=yaml.FullLoader)
        model_name = hparams["model"]
    else:
        raise Exception(
            "[meta_tags.csv|hparams.yaml] is missing from the checkpoint folder"
            f" ({checkpoint_folder})."
            " Please clean your cache folder and try to download the model again."
        )

    model = str2model[model_name].load_from_checkpoint(checkpoint, hparams=hparams)
    model.eval()
    model.freeze()
    return model