# wiki-vae / start_training.py
# Last commit caac576 by Fraser: "start sagemaker code"
import sagemaker
from sagemaker.huggingface import HuggingFace
# IAM role that SageMaker assumes to run the training job.
# get_execution_role() resolves it when this script runs inside SageMaker
# (notebook instance / Studio); when launching from elsewhere, replace this
# with the ARN string of a role that has SageMaker permissions.
ROLE = sagemaker.get_execution_role()

# Session used only to look up the default S3 bucket for checkpoints.
# It is deliberately NOT passed to the estimator, so the estimator remains
# free to create its own session for instance_type='local'.
sess = sagemaker.Session()

# hyperparameters, which are passed into the training job
hyperparameters = {
    'epochs': 1,
    'per_device_train_batch_size': 32,
    'do_train': True,
    'model_name_or_path': 'distilbert-base-uncased',
    # Same path as SageMaker's default local checkpoint directory, which is
    # synced to checkpoint_s3_uri below.
    'output_dir': '/opt/ml/checkpoints',
}

# create the Estimator
huggingface_estimator = HuggingFace(
    entry_point='train.py',
    source_dir='.',
    # 'local' runs the job in local mode; use e.g. 'ml.p3.2xlarge' for a
    # managed GPU instance.
    # NOTE(review): use_spot_instances / max_wait / checkpoint_s3_uri target
    # managed spot training -- confirm they are intended together with 'local'.
    instance_type='local',
    instance_count=1,
    checkpoint_s3_uri=f's3://{sess.default_bucket()}/checkpoints',
    use_spot_instances=True,
    max_wait=3600,  # should be equal to or greater than max_run, in seconds
    max_run=1000,
    role=ROLE,
    transformers_version='4.4',
    pytorch_version='1.6',
    py_version='py36',
    hyperparameters=hyperparameters,
)

# Start training; each key names an input channel backed by an S3 prefix.
huggingface_estimator.fit(
    {
        'train': 's3://sagemaker-us-east-1-558105141721/samples/datasets/imdb/train',
        'test': 's3://sagemaker-us-east-1-558105141721/samples/datasets/imdb/test',
    }
)