File size: 2,025 Bytes
b29fd2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import os
import re
import tempfile
from urllib.parse import urlparse
import joblib
import numpy as np
from datasets import load_dataset
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
# Populate the process environment from a local .env file (if any) before the
# lookup below; the order of these two statements matters.
load_dotenv()
# Dataset identifier that `evaluate_model` passes to `load_dataset`.
# It is None when the TEST_DATA_ID environment variable is not set.
TEST_DATA_ID = os.environ.get("TEST_DATA_ID", None)
def relative_error_loss(predicted_age, true_age):
    """Return the mean relative error between predicted and true ages.

    Each error is ``|true - predicted| / true``. Where the true age is
    exactly 0 the denominator is replaced by 0.1 so the ratio stays finite.

    Args:
        predicted_age: Predicted ages; any array-like (list, tuple, ndarray)
            or a scalar.
        true_age: Ground-truth ages; any array-like or a scalar.

    Returns:
        The mean of the element-wise relative errors as a NumPy float.
    """
    predicted = np.asarray(predicted_age, dtype=np.float64)
    actual = np.asarray(true_age, dtype=np.float64)

    # A column vector of predictions, shape (n, 1), against targets of shape
    # (n,) would broadcast into an (n, n) matrix and silently yield a wrong
    # mean. When the element counts agree, align the shapes first.
    if predicted.ndim != actual.ndim and predicted.size == actual.size:
        predicted = predicted.reshape(actual.shape)

    # Guard against division by zero for samples whose true age is 0.
    denominator = np.where(actual == 0, 0.1, actual)
    return np.mean(np.abs((actual - predicted) / denominator))
def parse_model_url(model_url: str):
    """Split a Hugging Face Hub file URL into its repo, revision and file path.

    The URL path is expected to look like
    ``/<owner>/<repo>/<resolve|blob>/<revision>/<path/to/file.joblib>``.
    Any query string (e.g. ``?download=true``) is ignored.

    Args:
        model_url: URL of a ``.joblib`` file hosted on the Hub.

    Returns:
        A ``(repo_id, revision, filename)`` tuple, where ``repo_id`` is
        ``"<owner>/<repo>"`` and ``filename`` is the file path inside the
        repository (including any sub-directories).

    Raises:
        ValueError: If the path has fewer than five segments or the file
            does not end in ``.joblib``.
    """
    # Local import: the file's top-level imports only bring in `urlparse`.
    from urllib.parse import unquote

    parsed = urlparse(model_url)
    path_parts = parsed.path.strip("/").split("/")
    if len(path_parts) < 5:
        raise ValueError("Unexpected URL format. Make sure it's a Hub URL with /resolve/main/ or /blob/main/")

    repo_id = "/".join(path_parts[:2])
    # Revisions containing slashes are percent-encoded in Hub URLs
    # (e.g. "refs%2Fpr%2F1"), so decode them back to their real name.
    revision = unquote(path_parts[3])
    # Everything after the revision is the file path; keeping only the first
    # segment would break files stored in sub-directories.
    filename = "/".join(unquote(part) for part in path_parts[4:])

    if not filename.endswith(".joblib"):
        raise ValueError("The file must be a .joblib file.")
    return repo_id, revision, filename
def evaluate_model(model_url: str) -> float:
    """Score a Hub-hosted ``.joblib`` model on the test dataset.

    The model file is downloaded into a temporary directory, loaded with
    joblib, run on the test features, and scored with
    ``relative_error_loss`` against the ``Age`` column of the test metadata.

    Args:
        model_url: URL of a ``.joblib`` file; must start with
            ``https://huggingface.co/``.

    Returns:
        The mean relative error of the model's predictions as a Python float.

    Raises:
        ValueError: If the URL is not a Hugging Face URL, cannot be parsed
            by ``parse_model_url``, or the downloaded file fails to load.
        RuntimeError: If the ``TEST_DATA_ID`` setting is missing.
    """
    if not model_url.startswith("https://huggingface.co/"):
        raise ValueError("Invalid model URL. Must start with https://huggingface.co/")
    repo_id, revision, filename = parse_model_url(model_url)

    # Fail with a clear message instead of handing None to load_dataset.
    if not TEST_DATA_ID:
        raise RuntimeError("TEST_DATA_ID environment variable is not set.")

    ds_test_meta = load_dataset(TEST_DATA_ID, "meta")
    ds_test_main = load_dataset(TEST_DATA_ID, "main")

    # NOTE(review): this assumes the rows of the "main" and "meta" test
    # splits are in the same order, since SampleID is dropped rather than
    # used to join them - confirm against the dataset.
    X_test = ds_test_main["test"].to_pandas().drop(columns=["SampleID"])
    X_test = X_test.values.astype(np.float32)
    y_test = np.array(ds_test_meta["test"]["Age"])

    with tempfile.TemporaryDirectory() as tmpdir:
        local_model_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            revision=revision,
            cache_dir=tmpdir,
        )
        try:
            # SECURITY: joblib.load unpickles the file, which can execute
            # arbitrary code. The URL is user-supplied, so this should only
            # run in a sandboxed environment.
            model = joblib.load(local_model_path)
        except Exception as exc:
            # Chain the cause so the original traceback is kept.
            raise ValueError(
                f"Failed to load the model. Please check the .joblib file. Error: {exc}"
            ) from exc

    # Flatten so that (n, 1) predictions do not broadcast against (n,)
    # targets into an (n, n) error matrix.
    predicted_age = np.asarray(model.predict(X_test)).reshape(-1)
    score = relative_error_loss(predicted_age, y_test)
    # Return a plain float, as the annotation promises.
    return float(score)
|