|
"""Utils""" |
|
from __future__ import annotations |
|
|
|
import json |
|
from pathlib import Path |
|
from typing import Literal |
|
|
|
from loguru import logger |
|
|
|
|
|
def download_model(
    model_name: str,
    model_stage: Literal["staging", "production"],
    model_dir: str | Path = "model",
) -> Path:
    """Download the latest registered MLflow model version for a stage.

    Asks the MLflow registry for the latest version of ``model_name`` in
    ``model_stage`` and downloads that version's ``source`` artifacts into
    ``model_dir``.

    If ``model_dir / <last component of the source URI>`` already exists, the
    download is skipped and that path is returned as-is. No integrity check is
    performed on the existing directory, and no metadata file is rewritten.

    After a fresh download, ``ModelInfo.metadata`` of the downloaded model is
    written as JSON to ``metadata.json`` inside the downloaded directory.

    Args:
        model_name: Name of the registered model.
        model_stage: Registry stage to look up.
        model_dir: Local directory to download into. Defaults to ``"model"``.

    Returns:
        Local path of the downloaded (or previously downloaded) model.

    Raises:
        ValueError: If the registry returns no version, or more than one
            version, for the requested stage.
    """
    # Imported lazily so that importing this module does not require mlflow.
    import mlflow.artifacts
    import mlflow.models
    from mlflow.client import MlflowClient

    logger.info(f"Looking for model {model_name}/{model_stage}")

    model_dir = Path(model_dir)

    client = MlflowClient()
    model_versions = client.get_latest_versions(model_name, stages=[model_stage])
    if not model_versions:
        raise ValueError(f"No model version for {model_name}/{model_stage}")
    if len(model_versions) > 1:
        # Distinguish "too many" from "none" so the error is not misleading.
        raise ValueError(
            f"Expected exactly one model version for {model_name}/{model_stage}, "
            f"found {len(model_versions)}"
        )

    latest = model_versions[0]
    artifact_uri = latest.source
    logger.info(f"Found version {latest.version} for {model_name}/{model_stage}")

    # Name the local directory after the last URI component. Strip trailing
    # slashes first: ".../model/" would otherwise yield an empty name, making
    # the cache path equal to model_dir itself and wrongly skipping download.
    leaf = artifact_uri.rstrip("/").rsplit("/", 1)[-1]
    cached_path = model_dir / leaf
    if leaf and cached_path.exists():
        logger.info(f"Found model in {cached_path}, skipping download")
        return cached_path

    logger.info(f"Downloading artifacts {artifact_uri} to {model_dir}")
    local_path = mlflow.artifacts.download_artifacts(artifact_uri, dst_path=str(model_dir))
    logger.info(f"Successfully downloaded {model_name}")

    # Persist the model's metadata next to it. If the model has no metadata,
    # this writes JSON "null", as before.
    model_info = mlflow.models.get_model_info(local_path)
    metadata_path = Path(local_path) / "metadata.json"
    logger.info(f"Saving metadata to {metadata_path}")
    with open(metadata_path, "w", encoding="utf-8") as file:
        json.dump(model_info.metadata, file)

    return Path(local_path)
|
|