|
"""Script to run sagemaker training jobs for whisper finetuning jobs.""" |
|
|
|
import logging |
|
import os |
|
from pprint import pprint |
|
|
|
import boto3 |
|
import sagemaker |
|
from sagemaker.huggingface import HuggingFace |
|
|
|
|
|
# When True, use the test instance map and test_run.sh; when False, use the
# full instance map and run.sh.
TEST = True

# SageMaker instance types, tried in dict order by the launch loop. Each entry
# holds the instance count to launch and a GPU count.
test_sm_instances = {
    "ml.g4dn.xlarge": {
        "num_instances": 1,
        "num_gpus": 1,
    },
}

# NOTE(review): currently identical to test_sm_instances — presumably meant to
# list larger/alternative instance types for the full run; confirm.
full_sm_instances = {
    "ml.g4dn.xlarge": {
        "num_instances": 1,
        "num_gpus": 1,
    },
}

sm_instances = test_sm_instances if TEST else full_sm_instances

ENTRY_POINT = "run_sm.py"
RUN_SCRIPT = "test_run.sh" if TEST else "run.sh"

IMAGE_URI = "116817510867.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-training:whisper-finetuning-0223e276db78adf4ea4dc5f874793cb2"
# IMAGE_URI is a hard-coded literal, so an `is None` check could never trigger;
# test for an empty value instead so a blanked-out URI is caught here.
if not IMAGE_URI:
    raise ValueError("IMAGE_URI variable not set, please update script.")
|
|
|
# Pin the region before creating any boto3 client, so every client built
# below (and by the SageMaker SDK) resolves the same region.
os.environ["AWS_DEFAULT_REGION"] = "eu-west-1"

iam = boto3.client("iam")
# ARN of the IAM role passed to the training job.
role = iam.get_role(RoleName="whisper-sagemaker-role")["Role"]["Arn"]

# The session handle is discarded; NOTE(review): presumably created to fail
# fast on credential/region problems — confirm it is still needed.
_ = sagemaker.Session()
# Low-level SageMaker client; its modeled exception classes are used when
# catching job-launch errors.
sm_client = boto3.client("sagemaker")
|
|
|
|
|
def set_creds(path="creds.txt"):
    """Load ``KEY=value`` lines from a credentials file into ``os.environ``.

    Args:
        path: Credentials file to read. Defaults to ``creds.txt`` in the
            current working directory.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If a non-blank line contains no ``=``.
    """
    with open(path, encoding="utf-8") as f:
        for line in f:
            # Blank lines (e.g. a trailing empty line) carry no credential.
            if not line.strip():
                continue
            # Split on the first "=" only: tokens and passwords may contain "=".
            key, sep, value = line.partition("=")
            if not sep:
                # Do not echo the line itself: it may contain a secret.
                raise ValueError(f"Malformed line in {path!r}: expected KEY=value")
            os.environ[key.strip()] = value.rstrip("\r\n")
|
|
|
|
|
def parse_run_script(path=None):
    """Parse the run script to get the hyperparameters.

    The script is expected to contain a ``python ...`` invocation followed by
    one argument per line, each optionally ending with a shell
    line-continuation backslash. Every ``--key=value`` argument becomes
    ``{"key": "value"}``; an argument without ``=`` maps to the string
    ``"True"``. Double quotes are removed from values.

    Args:
        path: Script to parse. Defaults to the module-level ``RUN_SCRIPT``.

    Returns:
        Dict mapping argument names to string values. If a
        ``model_index_name`` argument is present, its value is wrapped in
        double quotes.
    """
    if path is None:
        path = RUN_SCRIPT

    hyperparameters = {}
    with open(path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            # The interpreter invocation, blank lines and shell comments
            # (including a shebang) are not arguments.
            if not line or line.startswith(("python", "#")):
                continue
            # Drop only the trailing line-continuation backslash, not any
            # backslash that is part of a value.
            if line.endswith("\\"):
                line = line[:-1].rstrip()
            # Remove only the leading "--", so values containing "--" survive.
            if line.startswith("--"):
                line = line[2:]
            if not line:
                continue
            # Split on the first "=" only, so values may contain "=".
            key, sep, value = line.partition("=")
            hyperparameters[key.strip()] = (
                value.strip().replace('"', "") if sep else "True"
            )

    # Re-quote model_index_name (only when present, so scripts without it
    # do not raise KeyError).
    if "model_index_name" in hyperparameters:
        hyperparameters["model_index_name"] = f'"{hyperparameters["model_index_name"]}"'

    return hyperparameters
|
|
|
|
|
set_creds()

# HF_TOKEN is mandatory; the other credentials below are optional.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise ValueError("HF_TOKEN environment variable not set")

# Environment forwarded to the training container. SageMaker's Environment
# map only accepts string values (botocore rejects None), so optional
# variables that are unset are dropped rather than sent as None.
env_vars = {
    name: value
    for name, value in (
        ("HF_TOKEN", hf_token),
        ("EMAIL_ADDRESS", os.environ.get("EMAIL_ADDRESS")),
        ("EMAIL_PASSWORD", os.environ.get("EMAIL_PASSWORD")),
        ("WANDB_TOKEN", os.environ.get("WANDB_TOKEN")),
    )
    if value is not None
}

# Show which variables are forwarded without printing the secrets themselves.
pprint({name: "***" for name in env_vars})

# Hugging Face repo named after the current working directory.
repo = f"https://huggingface.co/marinone94/{os.path.basename(os.getcwd())}"

hyperparameters = {
    "repo": repo,
    "entrypoint": RUN_SCRIPT,
}

# Try each configured instance type in order, moving on to the next one when
# SageMaker reports ResourceLimitExceeded for the current type.
for sm_instance_name, sm_instance_values in sm_instances.items():
    num_instances = int(sm_instance_values["num_instances"])

    hf_estimator = HuggingFace(
        entry_point=ENTRY_POINT,
        instance_type=sm_instance_name,
        instance_count=num_instances,
        role=role,
        py_version="py38",
        image_uri=IMAGE_URI,
        hyperparameters=hyperparameters,
        environment=env_vars,
        git_config={"repo": repo, "branch": "main"},
    )

    try:
        hf_estimator.fit()
    except sm_client.exceptions.ResourceLimitExceeded as e_0:
        logging.warning("Instance error %s\nRetrying with new instance", e_0)
        continue
    break
else:
    # Reached only when no fit() call succeeded (or sm_instances is empty):
    # fail loudly instead of exiting as if a job had been launched.
    raise RuntimeError(
        "Could not start a training job on any configured SageMaker instance type"
    )
|
|