import json
import os
import subprocess
import unittest
from ast import literal_eval

import pytest
from parameterized import parameterized_class

from . import is_sagemaker_available


if is_sagemaker_available():
    from sagemaker import Session, TrainingJobAnalytics
    from sagemaker.huggingface import HuggingFace
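

# NOTE: the attributes used below (framework, script, model_name_or_path,
# instance_type, results) and the `env` object are not defined in this class;
# they are presumably injected by the imported `parameterized_class` decorator
# and a pytest fixture (which would also account for the otherwise-unused
# os/literal_eval/pytest imports, e.g. in a skip-gate decorator not shown
# here). A hypothetical configuration entry might look like:
#
#   @parameterized_class(
#       [
#           {
#               "framework": "pytorch",
#               "script": "run_glue.py",
#               "model_name_or_path": "distilbert-base-cased",  # hypothetical
#               "instance_type": "ml.g4dn.xlarge",  # hypothetical
#               "results": {"train_runtime": 650, "eval_accuracy": 0.6, "eval_loss": 0.9},  # hypothetical
#           },
#       ]
#   )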
class SingleNodeTest(unittest.TestCase):
    def setUp(self):
        if self.framework == "pytorch":
            # copy the example training script into the test source dir so
            # SageMaker packages it with the job
            subprocess.run(
                f"cp ./examples/pytorch/text-classification/run_glue.py {self.env.test_path}/run_glue.py".split(),
                encoding="utf-8",
                check=True,
            )
        assert hasattr(self, "env")

    def create_estimator(self, instance_count=1):
        # configure a single-node HuggingFace estimator for this test's
        # framework, script, and hyperparameters
        return HuggingFace(
            entry_point=self.script,
            source_dir=self.env.test_path,
            role=self.env.role,
            image_uri=self.env.image_uri,
            base_job_name=f"{self.env.base_job_name}-single",
            instance_count=instance_count,
            instance_type=self.instance_type,
            debugger_hook_config=False,
            hyperparameters={**self.env.hyperparameters, "model_name_or_path": self.model_name_or_path},
            metric_definitions=self.env.metric_definitions,
            py_version="py36",
        )
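
    # NOTE: the estimator relies on metric_definitions (regex patterns supplied
    # by the test environment) so that values such as eval_accuracy and
    # eval_loss printed by the training script are captured as CloudWatch
    # metrics and can be read back via TrainingJobAnalytics in test_glue below.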

    def save_results_as_csv(self, job_name):
        # export the job's CloudWatch metrics to a CSV file in the test dir
        TrainingJobAnalytics(job_name).export_csv(f"{self.env.test_path}/{job_name}_metrics.csv")

    def test_glue(self):
        # create estimator
        estimator = self.create_estimator()
        # run training
        estimator.fit()
        # load the job's metrics as a dataframe
        result_metrics_df = TrainingJobAnalytics(estimator.latest_training_job.name).dataframe()
        # extract KPIs
        eval_accuracy = list(result_metrics_df[result_metrics_df.metric_name == "eval_accuracy"]["value"])
        eval_loss = list(result_metrics_df[result_metrics_df.metric_name == "eval_loss"]["value"])
        # get train time from the SageMaker job; this includes start-up,
        # preprocessing, and shutdown time
        train_runtime = (
            Session().describe_training_job(estimator.latest_training_job.name).get("TrainingTimeInSeconds", 999999)
        )
        # assert KPIs against the expected thresholds for this configuration
        assert train_runtime <= self.results["train_runtime"]
        assert all(t >= self.results["eval_accuracy"] for t in eval_accuracy)
        assert all(t <= self.results["eval_loss"] for t in eval_loss)
        # dump test results into a JSON file to share in the PR
        with open(f"{estimator.latest_training_job.name}.json", "w") as outfile:
            json.dump({"train_time": train_runtime, "eval_accuracy": eval_accuracy, "eval_loss": eval_loss}, outfile)
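

# A minimal sketch of how this suite is typically run (the module path and the
# TEST_SAGEMAKER gate are assumptions, not shown in this file):
#
#   TEST_SAGEMAKER=True python -m pytest -s tests/sagemaker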