import sagemaker
from sagemaker.huggingface import HuggingFace

# IAM role ARN handed to the estimator below as its ``role`` argument.
# The original placeholder (``?``) was not valid Python. The role is now
# resolved through the SageMaker SDK; when running outside a
# SageMaker-managed environment (e.g. a laptop using IAM-user credentials),
# replace this call with the explicit ARN string of your execution role.
ROLE = sagemaker.get_execution_role()

# Hyperparameters forwarded to the training script through the estimator
# defined further down in this file.
hyperparameters = {
    'epochs': 1,
    'per_device_train_batch_size': 32,
    'do_train': True,
    'model_name_or_path': 'distilbert-base-uncased',
    # NOTE(review): presumably chosen to match the container directory that
    # SageMaker syncs to ``checkpoint_s3_uri`` -- confirm against train.py.
    'output_dir': '/opt/ml/checkpoints',
}

# Session used only to look up the default S3 bucket for checkpoints.
# The original referenced ``sess`` without defining it anywhere in this file.
sess = sagemaker.Session()

# Estimator describing the Hugging Face training job: which script to run,
# on what instance, with which framework versions and hyperparameters.
# NOTE(review): ``use_spot_instances`` / ``max_wait`` / ``checkpoint_s3_uri``
# are managed-spot-training options; confirm they have the intended effect
# together with ``instance_type='local'``.
huggingface_estimator = HuggingFace(
    entry_point='train.py',
    source_dir='.',
    instance_type='local',
    instance_count=1,
    checkpoint_s3_uri=f's3://{sess.default_bucket()}/checkpoints',
    use_spot_instances=True,
    max_wait=3600,  # seconds; kept larger than max_run
    max_run=1000,  # seconds
    role=ROLE,
    transformers_version='4.4',
    pytorch_version='1.6',
    py_version='py36',
    hyperparameters=hyperparameters,
)

# Start the training job, passing the S3 locations of the IMDB dataset as
# two named input channels ('train' and 'test').
huggingface_estimator.fit(
    {
        'train': 's3://sagemaker-us-east-1-558105141721/samples/datasets/imdb/train',
        'test': 's3://sagemaker-us-east-1-558105141721/samples/datasets/imdb/test',
    }
)