marinone94's picture
use python script to clone repo and trigger train
e417b0c
raw
history blame
3.59 kB
"""Script to run sagemaker training jobs for whisper finetuning jobs."""
import logging
import os
from pprint import pprint
import boto3
import sagemaker
from sagemaker.huggingface import HuggingFace
# When True, the test instance table and the test run script are selected
# below instead of the full ones.
TEST = True

# Instance types to try, keyed by SageMaker instance name. Each entry gives
# the number of instances to request and the GPU count recorded for it.
test_sm_instances = {
    "ml.g4dn.xlarge": {"num_instances": 1, "num_gpus": 1},
}
full_sm_instances = {
    "ml.g4dn.xlarge": {"num_instances": 1, "num_gpus": 1},
}
if TEST:
    sm_instances = test_sm_instances
else:
    sm_instances = full_sm_instances

# Script handed to the estimator as its entry point, and the shell script
# name forwarded to it as the "entrypoint" hyperparameter.
ENTRY_POINT = "run_sm.py"
if TEST:
    RUN_SCRIPT = "test_run.sh"
else:
    RUN_SCRIPT = "run.sh"

# Training image passed to the estimator as image_uri.
IMAGE_URI = "116817510867.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-training:whisper-finetuning-0223e276db78adf4ea4dc5f874793cb2"
if IMAGE_URI is None:
    # Guard for anyone who blanks the constant above while editing.
    raise ValueError("IMAGE_URI variable not set, please update script.")

# NOTE: the IAM client is created before the default region is exported;
# the order is kept exactly as the script has always run it.
iam = boto3.client("iam")
os.environ["AWS_DEFAULT_REGION"] = "eu-west-1"

# ARN of the execution role handed to the estimator.
_role_response = iam.get_role(RoleName="whisper-sagemaker-role")
role = _role_response["Role"]["Arn"]

# The session object itself is discarded; it is unclear whether creating it
# is required (the original author was not sure either).
_ = sagemaker.Session()

# Used further down only to reach the ResourceLimitExceeded exception class.
sm_client = boto3.client("sagemaker")
def set_creds():
    """Load credentials from ``creds.txt`` into ``os.environ``.

    The file is read from the current working directory and is expected to
    hold one ``KEY=VALUE`` pair per line. Blank lines and lines starting with
    ``#`` are skipped. Only the first ``=`` separates the key from the value,
    so values may themselves contain ``=`` (e.g. base64-padded tokens).
    Surrounding whitespace, including a stray ``\\r``, is stripped from both
    key and value.

    Raises:
        FileNotFoundError: if ``creds.txt`` does not exist.
        ValueError: if a non-blank, non-comment line has no ``=`` or has an
            empty key.
    """
    with open("creds.txt", encoding="utf-8") as f:
        for line_number, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            key, separator, value = line.partition("=")
            key = key.strip()
            if not separator or not key:
                # Report only the line number: the line may hold a secret.
                raise ValueError(
                    f"creds.txt line {line_number}: expected KEY=VALUE"
                )
            os.environ[key] = value.strip()
def parse_run_script(run_script=None):
    """Parse a run script to get the hyperparameters.

    The script is expected to list one ``--key=value`` (or bare ``--flag``)
    option per line, optionally indented and followed by a shell line
    continuation backslash. Blank lines, ``#`` comment/shebang lines and the
    line starting with ``python`` are ignored.

    Args:
        run_script: Path of the shell script to parse. Defaults to the
            module-level ``RUN_SCRIPT``.

    Returns:
        A dict mapping option names (without the leading ``--``) to their
        string values. Options given without ``=`` map to ``"True"``. If a
        ``model_index_name`` option is present its value is wrapped in
        double quotes.
    """
    if run_script is None:
        run_script = RUN_SCRIPT
    hyperparameters = {}
    with open(run_script, "r", encoding="utf-8") as f:
        for raw_line in f:
            # Drop shell residue (continuation backslashes, double quotes)
            # and surrounding whitespace, whether tabs or spaces.
            line = raw_line.replace("\\", "").replace('"', "").strip()
            if not line or line.startswith("#") or line.startswith("python"):
                continue
            # Remove only the leading option marker so that a "--" inside a
            # value is left alone.
            if line.startswith("--"):
                line = line[2:]
            # Split on the first "=" only: values may contain "=" themselves.
            key, separator, value = line.partition("=")
            hyperparameters[key.strip()] = value.strip() if separator else "True"
    # All double quotes were stripped above; put them back around this one
    # value (presumably because the name may contain spaces - confirm).
    if "model_index_name" in hyperparameters:
        hyperparameters["model_index_name"] = (
            f'"{hyperparameters["model_index_name"]}"'
        )
    return hyperparameters
# Export the credentials stored in creds.txt so they can be read back below.
set_creds()

# NOTE: parse_run_script() is intentionally not called here; only the repo
# URL and the run-script name are forwarded as hyperparameters.

hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise ValueError("HF_TOKEN environment variable not set")

# Environment forwarded to the training container. Unset optional variables
# are left out rather than sent as None, since the values must be strings.
env_vars = {"HF_TOKEN": hf_token}
for env_var_name in ("EMAIL_ADDRESS", "EMAIL_PASSWORD", "WANDB_TOKEN"):
    env_var_value = os.environ.get(env_var_name)
    if env_var_value is None:
        logging.warning(
            "%s is not set; it will not be passed to the training job",
            env_var_name,
        )
        continue
    env_vars[env_var_name] = env_var_value

# Show which variables are forwarded without printing the secrets themselves.
pprint({env_var_name: "***" for env_var_name in env_vars})

# The Hugging Face repo name is taken from the current directory name.
repo = f"https://huggingface.co/marinone94/{os.path.basename(os.getcwd())}"
hyperparameters = {
    "repo": repo,
    "entrypoint": RUN_SCRIPT,
}

# Try each configured instance type in turn and stop at the first one for
# which fit() completes without a resource-limit error.
for sm_instance_name, sm_instance_values in sm_instances.items():
    num_instances: int = int(sm_instance_values["num_instances"])
    try:
        # instantiate and fit the sm Estimator
        hf_estimator = HuggingFace(
            entry_point=ENTRY_POINT,
            instance_type=sm_instance_name,
            instance_count=num_instances,
            role=role,
            py_version="py38",
            image_uri=IMAGE_URI,
            hyperparameters=hyperparameters,
            environment=env_vars,
            git_config={"repo": repo, "branch": "main"},
        )
        hf_estimator.fit()
        break
    except sm_client.exceptions.ResourceLimitExceeded as e_0:
        logging.warning("Instance error %s\nRetrying with new instance", e_0)
else:
    # Reached only when no iteration hit `break`: every instance type was
    # rejected (or none was configured), so say so instead of ending silently.
    logging.error(
        "No training job was started: all configured instance types "
        "(%s) were exhausted",
        ", ".join(sm_instances) or "none",
    )