Julien Simon commited on
Commit
a5a9972
1 Parent(s): 8eb4464

Initial version

Browse files
code/Sentiment analysis with Hugging Face and SageMaker.ipynb ADDED
@@ -0,0 +1,602 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Training and deploying Hugging Face models on Amazon SageMaker\n",
8
+ "\n",
9
+ "* https://huggingface.co/distilbert-base-uncased\n",
10
+ "* https://huggingface.co/transformers/model_doc/distilbert.html\n",
11
+ "* https://huggingface.co/datasets/generated_reviews_enth"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "markdown",
16
+ "metadata": {},
17
+ "source": [
18
+ "# 1 - Setup"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "metadata": {
25
+ "scrolled": true
26
+ },
27
+ "outputs": [],
28
+ "source": [
29
+ "!pip -q install sagemaker \"transformers>=4.4.2\" \"datasets[s3]==1.5.0\" widgetsnbextension ipywidgets huggingface_hub --upgrade"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash\n",
39
+ "!apt-get install git-lfs"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": null,
45
+ "metadata": {},
46
+ "outputs": [],
47
+ "source": [
48
+ "import sagemaker\n",
49
+ "import transformers\n",
50
+ "import datasets\n",
51
+ "\n",
52
+ "print(sagemaker.__version__)\n",
53
+ "print(transformers.__version__)\n",
54
+ "print(datasets.__version__)"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "markdown",
59
+ "metadata": {},
60
+ "source": [
61
+ "# 2 - Preprocessing"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": null,
67
+ "metadata": {},
68
+ "outputs": [],
69
+ "source": [
70
+ "from datasets import load_dataset\n",
71
+ "\n",
72
+ "train_dataset, valid_dataset = load_dataset('generated_reviews_enth', split=['train', 'validation'])\n",
73
+ "\n",
74
+ "print(train_dataset.shape)\n",
75
+ "print(valid_dataset.shape)"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": null,
81
+ "metadata": {},
82
+ "outputs": [],
83
+ "source": [
84
+ "train_dataset[0]"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": null,
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": [
93
+ "def map_stars_to_sentiment(row):\n",
94
+ " return {\n",
95
+ " 'labels': 1 if row['review_star'] >= 4 else 0\n",
96
+ " }"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": null,
102
+ "metadata": {},
103
+ "outputs": [],
104
+ "source": [
105
+ "train_dataset = train_dataset.map(map_stars_to_sentiment)\n",
106
+ "valid_dataset = valid_dataset.map(map_stars_to_sentiment)"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": null,
112
+ "metadata": {},
113
+ "outputs": [],
114
+ "source": [
115
+ "train_dataset[0]"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": null,
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": [
124
+ "train_dataset = train_dataset.flatten()\n",
125
+ "valid_dataset = valid_dataset.flatten()"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": null,
131
+ "metadata": {},
132
+ "outputs": [],
133
+ "source": [
134
+ "train_dataset[0]"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": null,
140
+ "metadata": {},
141
+ "outputs": [],
142
+ "source": [
143
+ "train_dataset = train_dataset.remove_columns(['correct', 'translation.th', 'review_star'])\n",
144
+ "valid_dataset = valid_dataset.remove_columns(['correct', 'translation.th', 'review_star'])"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "execution_count": null,
150
+ "metadata": {},
151
+ "outputs": [],
152
+ "source": [
153
+ "train_dataset = train_dataset.rename_column('translation.en', 'text')\n",
154
+ "valid_dataset = valid_dataset.rename_column('translation.en', 'text')"
155
+ ]
156
+ },
157
+ {
158
+ "cell_type": "code",
159
+ "execution_count": null,
160
+ "metadata": {},
161
+ "outputs": [],
162
+ "source": [
163
+ "train_dataset[0]"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "markdown",
168
+ "metadata": {},
169
+ "source": [
170
+ "## Tokenize"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "code",
175
+ "execution_count": null,
176
+ "metadata": {},
177
+ "outputs": [],
178
+ "source": [
179
+ "from transformers import AutoTokenizer\n",
180
+ "\n",
181
+ "tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')\n",
182
+ "\n",
183
+ "def tokenize(batch):\n",
184
+ " return tokenizer(batch['text'], padding='max_length', truncation=True)"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": null,
190
+ "metadata": {},
191
+ "outputs": [],
192
+ "source": [
193
+ "train_dataset = train_dataset.map(tokenize, batched=True, batch_size=len(train_dataset))"
194
+ ]
195
+ },
196
+ {
197
+ "cell_type": "code",
198
+ "execution_count": null,
199
+ "metadata": {},
200
+ "outputs": [],
201
+ "source": [
202
+ "valid_dataset = valid_dataset.map(tokenize, batched=True, batch_size=len(valid_dataset))"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": null,
208
+ "metadata": {},
209
+ "outputs": [],
210
+ "source": [
211
+ "import json\n",
212
+ "\n",
213
+ "json.dumps(train_dataset[0])"
214
+ ]
215
+ },
216
+ {
217
+ "cell_type": "code",
218
+ "execution_count": null,
219
+ "metadata": {},
220
+ "outputs": [],
221
+ "source": [
222
+ "train_dataset = train_dataset.remove_columns(['text'])\n",
223
+ "valid_dataset = valid_dataset.remove_columns(['text'])"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "markdown",
228
+ "metadata": {},
229
+ "source": [
230
+ "# 3 - Upload data to S3"
231
+ ]
232
+ },
233
+ {
234
+ "cell_type": "code",
235
+ "execution_count": null,
236
+ "metadata": {},
237
+ "outputs": [],
238
+ "source": [
239
+ "from datasets.filesystems import S3FileSystem\n",
240
+ "\n",
241
+ "s3 = S3FileSystem() \n",
242
+ "\n",
243
+ "s3_prefix = 'hugging-face/sentiment-analysis'\n",
244
+ "bucket = sagemaker.Session().default_bucket()\n",
245
+ "\n",
246
+ "train_input_path = 's3://{}/{}/training'.format(bucket, s3_prefix)\n",
247
+ "train_dataset.save_to_disk(train_input_path, fs=s3)\n",
248
+ "\n",
249
+ "valid_input_path = 's3://{}/{}/validation'.format(bucket, s3_prefix)\n",
250
+ "valid_dataset.save_to_disk(valid_input_path, fs=s3)"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "execution_count": null,
256
+ "metadata": {},
257
+ "outputs": [],
258
+ "source": [
259
+ "print(train_input_path)\n",
260
+ "print(valid_input_path)"
261
+ ]
262
+ },
263
+ {
264
+ "cell_type": "markdown",
265
+ "metadata": {},
266
+ "source": [
267
+ "# 4 - Fine-tune a Hugging Face model on SageMaker"
268
+ ]
269
+ },
270
+ {
271
+ "cell_type": "code",
272
+ "execution_count": null,
273
+ "metadata": {},
274
+ "outputs": [],
275
+ "source": [
276
+ "!pygmentize train.py"
277
+ ]
278
+ },
279
+ {
280
+ "cell_type": "code",
281
+ "execution_count": null,
282
+ "metadata": {},
283
+ "outputs": [],
284
+ "source": [
285
+ "hyperparameters={\n",
286
+ " 'epochs': 1,\n",
287
+ " 'train-batch_size': 32,\n",
288
+ " 'model-name':'distilbert-base-uncased'\n",
289
+ "}"
290
+ ]
291
+ },
292
+ {
293
+ "cell_type": "code",
294
+ "execution_count": null,
295
+ "metadata": {},
296
+ "outputs": [],
297
+ "source": [
298
+ "from sagemaker.huggingface import HuggingFace\n",
299
+ "\n",
300
+ "huggingface_estimator = HuggingFace(\n",
301
+ " role=sagemaker.get_execution_role(),\n",
302
+ " # Fine-tuning script\n",
303
+ " entry_point='train.py',\n",
304
+ " hyperparameters=hyperparameters,\n",
305
+ " # Infrastructure\n",
306
+ " transformers_version='4.6.1',\n",
307
+ " pytorch_version='1.7.1',\n",
308
+ " py_version='py36',\n",
309
+ " instance_type='ml.p3.2xlarge', # 1 GPUs, $4.131/hour in eu-west-1\n",
310
+ " instance_count=1,\n",
311
+ " # Enable spot instances\n",
312
+ " use_spot_instances=True, # 70% discount is typical\n",
313
+ " max_run = 3600,\n",
314
+ " max_wait = 7200\n",
315
+ ")"
316
+ ]
317
+ },
318
+ {
319
+ "cell_type": "code",
320
+ "execution_count": null,
321
+ "metadata": {
322
+ "scrolled": true
323
+ },
324
+ "outputs": [],
325
+ "source": [
326
+ "huggingface_estimator.fit({'train': train_input_path, 'valid': valid_input_path})"
327
+ ]
328
+ },
329
+ {
330
+ "cell_type": "markdown",
331
+ "metadata": {},
332
+ "source": [
333
+ "# 5 - Deploy the model on SageMaker"
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "code",
338
+ "execution_count": null,
339
+ "metadata": {},
340
+ "outputs": [],
341
+ "source": [
342
+ "huggingface_predictor = huggingface_estimator.deploy(\n",
343
+ " initial_instance_count=1,\n",
344
+ " instance_type='ml.m5.xlarge')"
345
+ ]
346
+ },
347
+ {
348
+ "cell_type": "code",
349
+ "execution_count": null,
350
+ "metadata": {},
351
+ "outputs": [],
352
+ "source": [
353
+ "test_data = {\n",
354
+ " \"inputs\": \"This is a very nice camera, I'm super happy with it.\"\n",
355
+ "}"
356
+ ]
357
+ },
358
+ {
359
+ "cell_type": "code",
360
+ "execution_count": null,
361
+ "metadata": {},
362
+ "outputs": [],
363
+ "source": [
364
+ "prediction = huggingface_predictor.predict(test_data)\n",
365
+ "print(prediction)"
366
+ ]
367
+ },
368
+ {
369
+ "cell_type": "code",
370
+ "execution_count": null,
371
+ "metadata": {},
372
+ "outputs": [],
373
+ "source": [
374
+ "test_data = {\n",
375
+ " \"inputs\": \"Terrible purchase, I want my money back!\"\n",
376
+ "}"
377
+ ]
378
+ },
379
+ {
380
+ "cell_type": "code",
381
+ "execution_count": null,
382
+ "metadata": {},
383
+ "outputs": [],
384
+ "source": [
385
+ "prediction = huggingface_predictor.predict(test_data)\n",
386
+ "print(prediction)"
387
+ ]
388
+ },
389
+ {
390
+ "cell_type": "code",
391
+ "execution_count": null,
392
+ "metadata": {},
393
+ "outputs": [],
394
+ "source": [
395
+ "huggingface_predictor.delete_endpoint()"
396
+ ]
397
+ },
398
+ {
399
+ "cell_type": "markdown",
400
+ "metadata": {},
401
+ "source": [
402
+ "# 6 - Push our model to the Hugging Face hub"
403
+ ]
404
+ },
405
+ {
406
+ "cell_type": "code",
407
+ "execution_count": null,
408
+ "metadata": {},
409
+ "outputs": [],
410
+ "source": [
411
+ "# In a terminal, login to the Hub with 'huggingface-cli login' and your hub credentials"
412
+ ]
413
+ },
414
+ {
415
+ "cell_type": "markdown",
416
+ "metadata": {},
417
+ "source": [
418
+ "## Create a new repo on the Hugging Face hub"
419
+ ]
420
+ },
421
+ {
422
+ "cell_type": "code",
423
+ "execution_count": null,
424
+ "metadata": {},
425
+ "outputs": [],
426
+ "source": [
427
+ "repo_name='reviews-sentiment-analysis'"
428
+ ]
429
+ },
430
+ {
431
+ "cell_type": "code",
432
+ "execution_count": null,
433
+ "metadata": {},
434
+ "outputs": [],
435
+ "source": [
436
+ "%%sh -s $repo_name\n",
437
+ "huggingface-cli repo create -y $1\n",
438
+ "git clone https://huggingface.co/juliensimon/$1"
439
+ ]
440
+ },
441
+ {
442
+ "cell_type": "markdown",
443
+ "metadata": {},
444
+ "source": [
445
+ "## Extract our model and push files to our hub repo"
446
+ ]
447
+ },
448
+ {
449
+ "cell_type": "code",
450
+ "execution_count": null,
451
+ "metadata": {},
452
+ "outputs": [],
453
+ "source": [
454
+ "%%sh -s $huggingface_estimator.model_data $repo_name\n",
455
+ "aws s3 cp $1 .\n",
456
+ "tar xvz -C $2 -f model.tar.gz"
457
+ ]
458
+ },
459
+ {
460
+ "cell_type": "code",
461
+ "execution_count": null,
462
+ "metadata": {},
463
+ "outputs": [],
464
+ "source": [
465
+ "%%sh -s $repo_name\n",
466
+ "cd $1\n",
467
+ "git add .\n",
468
+ "git commit -m 'Initial version'\n",
469
+ "git push"
470
+ ]
471
+ },
472
+ {
473
+ "cell_type": "markdown",
474
+ "metadata": {},
475
+ "source": [
476
+ "## Grab our model from the hub and work locally"
477
+ ]
478
+ },
479
+ {
480
+ "cell_type": "code",
481
+ "execution_count": null,
482
+ "metadata": {},
483
+ "outputs": [],
484
+ "source": [
485
+ "# With the Auto* API\n",
486
+ "from transformers import AutoTokenizer, AutoModelForSequenceClassification \n",
487
+ "\n",
488
+ "tokenizer = AutoTokenizer.from_pretrained('juliensimon/'+repo_name)\n",
489
+ "model = AutoModelForSequenceClassification.from_pretrained('juliensimon/'+repo_name)\n",
490
+ "\n",
491
+ "# With the pipeline API\n",
492
+ "from transformers import pipeline\n",
493
+ "\n",
494
+ "classifier = pipeline('sentiment-analysis', model='juliensimon/'+repo_name)"
495
+ ]
496
+ },
497
+ {
498
+ "cell_type": "code",
499
+ "execution_count": null,
500
+ "metadata": {},
501
+ "outputs": [],
502
+ "source": [
503
+ "classifier(\"This is a very nice camera, I'm super happy with it.\")"
504
+ ]
505
+ },
506
+ {
507
+ "cell_type": "code",
508
+ "execution_count": null,
509
+ "metadata": {},
510
+ "outputs": [],
511
+ "source": [
512
+ "classifier(\"Terrible purchase, I want my money back!\")"
513
+ ]
514
+ },
515
+ {
516
+ "cell_type": "markdown",
517
+ "metadata": {},
518
+ "source": [
519
+ "## Grab our model from the hub and deploy it on a SageMaker endpoint"
520
+ ]
521
+ },
522
+ {
523
+ "cell_type": "code",
524
+ "execution_count": null,
525
+ "metadata": {},
526
+ "outputs": [],
527
+ "source": [
528
+ "from sagemaker.huggingface.model import HuggingFaceModel\n",
529
+ "\n",
530
+ "hub = {\n",
531
+ " 'HF_MODEL_ID':'juliensimon/'+repo_name, \n",
532
+ " 'HF_TASK':'sentiment-analysis'\n",
533
+ "}\n",
534
+ "\n",
535
+ "huggingface_model = HuggingFaceModel(\n",
536
+ " env=hub, \n",
537
+ " role=sagemaker.get_execution_role(), \n",
538
+ " transformers_version='4.6.1', \n",
539
+ " pytorch_version='1.7.1', \n",
540
+ " py_version='py36' \n",
541
+ ")\n",
542
+ "\n",
543
+ "huggingface_predictor = huggingface_model.deploy(\n",
544
+ " initial_instance_count=1,\n",
545
+ " instance_type='ml.m5.xlarge'\n",
546
+ ")"
547
+ ]
548
+ },
549
+ {
550
+ "cell_type": "code",
551
+ "execution_count": null,
552
+ "metadata": {},
553
+ "outputs": [],
554
+ "source": [
555
+ "test_data = {\n",
556
+ " 'inputs': \"This is a very nice camera, I'm super happy with it.\"\n",
557
+ "}\n",
558
+ "\n",
559
+ "prediction = huggingface_predictor.predict(test_data)\n",
560
+ "print(prediction)"
561
+ ]
562
+ },
563
+ {
564
+ "cell_type": "code",
565
+ "execution_count": null,
566
+ "metadata": {},
567
+ "outputs": [],
568
+ "source": [
569
+ "huggingface_predictor.delete_endpoint()"
570
+ ]
571
+ },
572
+ {
573
+ "cell_type": "code",
574
+ "execution_count": null,
575
+ "metadata": {},
576
+ "outputs": [],
577
+ "source": []
578
+ }
579
+ ],
580
+ "metadata": {
581
+ "instance_type": "ml.m5.4xlarge",
582
+ "kernelspec": {
583
+ "display_name": "Python 3 (PyTorch 1.6 Python 3.6 CPU Optimized)",
584
+ "language": "python",
585
+ "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:eu-west-1:470317259841:image/pytorch-1.6-cpu-py36-ubuntu16.04-v1"
586
+ },
587
+ "language_info": {
588
+ "codemirror_mode": {
589
+ "name": "ipython",
590
+ "version": 3
591
+ },
592
+ "file_extension": ".py",
593
+ "mimetype": "text/x-python",
594
+ "name": "python",
595
+ "nbconvert_exporter": "python",
596
+ "pygments_lexer": "ipython3",
597
+ "version": "3.6.13"
598
+ }
599
+ },
600
+ "nbformat": 4,
601
+ "nbformat_minor": 4
602
+ }
code/train.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random, sys, argparse, os, logging, torch
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
3
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
4
+ from datasets import load_from_disk
5
+
6
+ if __name__ == "__main__":
7
+
8
+ parser = argparse.ArgumentParser()
9
+
10
+ # hyperparameters sent by the client are passed as command-line arguments to the script.
11
+ parser.add_argument("--epochs", type=int, default=3)
12
+ parser.add_argument("--train-batch-size", type=int, default=32)
13
+ parser.add_argument("--eval-batch-size", type=int, default=64)
14
+ parser.add_argument("--save-strategy", type=str, default='no')
15
+ parser.add_argument("--save-steps", type=int, default=500)
16
+ parser.add_argument("--model-name", type=str)
17
+ parser.add_argument("--learning-rate", type=str, default=5e-5)
18
+
19
+ # Data, model, and output directories
20
+ parser.add_argument("--output-data-dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"])
21
+ parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])
22
+ parser.add_argument("--n-gpus", type=str, default=os.environ["SM_NUM_GPUS"])
23
+ parser.add_argument("--train-dir", type=str, default=os.environ["SM_CHANNEL_TRAIN"])
24
+ parser.add_argument("--valid-dir", type=str, default=os.environ["SM_CHANNEL_VALID"])
25
+
26
+ args, _ = parser.parse_known_args()
27
+
28
+ # load datasets
29
+ train_dataset = load_from_disk(args.train_dir)
30
+ valid_dataset = load_from_disk(args.valid_dir)
31
+
32
+ logger = logging.getLogger(__name__)
33
+ logger.info(f" loaded train_dataset length is: {len(train_dataset)}")
34
+ logger.info(f" loaded valid_dataset length is: {len(valid_dataset)}")
35
+
36
+ # compute metrics function for binary classification
37
+ def compute_metrics(pred):
38
+ labels = pred.label_ids
39
+ preds = pred.predictions.argmax(-1)
40
+ precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="binary")
41
+ acc = accuracy_score(labels, preds)
42
+ return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}
43
+
44
+ # download model from model hub
45
+ model = AutoModelForSequenceClassification.from_pretrained(args.model_name)
46
+
47
+ # download the tokenizer too, which will be saved in the model artifact
48
+ # and used at prediction time
49
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
50
+
51
+ # define training args
52
+ training_args = TrainingArguments(
53
+ output_dir=args.model_dir,
54
+ num_train_epochs=args.epochs,
55
+ per_device_train_batch_size=args.train_batch_size,
56
+ per_device_eval_batch_size=args.eval_batch_size,
57
+ save_strategy=args.save_strategy,
58
+ save_steps=args.save_steps,
59
+ evaluation_strategy="epoch",
60
+ logging_dir=f"{args.output_data_dir}/logs",
61
+ learning_rate=float(args.learning_rate),
62
+ )
63
+
64
+ # create Trainer instance
65
+ trainer = Trainer(
66
+ model=model,
67
+ args=training_args,
68
+ tokenizer=tokenizer,
69
+ compute_metrics=compute_metrics,
70
+ train_dataset=train_dataset,
71
+ eval_dataset=valid_dataset,
72
+ )
73
+
74
+ # train model
75
+ trainer.train()
76
+
77
+ # evaluate model
78
+ eval_result = trainer.evaluate(eval_dataset=valid_dataset)
79
+
80
+ # writes eval result to file which can be accessed later in s3 output
81
+ with open(os.path.join(args.output_data_dir, "eval_results.txt"), "w") as writer:
82
+ print(f"***** Eval results *****")
83
+ for key, value in sorted(eval_result.items()):
84
+ writer.write(f"{key} = {value}\n")
85
+
86
+ # Saves the model to s3
87
+ trainer.save_model(args.model_dir)
88
+