soutrik committed on
Commit
4828471
·
1 Parent(s): f057c2a

Added new changes for ResnetClassifier; tested locally and with Docker

Browse files
configs/experiment/catdog_experiment_resnet.yaml CHANGED
@@ -6,7 +6,7 @@
6
  defaults:
7
  - override /paths: catdog
8
  - override /data: catdog
9
- - override /model: catdog_classifier_convnext
10
  - override /callbacks: default
11
  - override /logger: default
12
  - override /trainer: default
@@ -15,7 +15,7 @@ defaults:
15
  # this allows you to overwrite only specified parameters
16
 
17
  seed: 42
18
- name: "catdog_experiment_convnext"
19
 
20
  # Logger-specific configurations
21
  logger:
@@ -33,7 +33,7 @@ data:
33
  image_size: 160
34
 
35
  model:
36
- base_model: convnext_tiny.fb_in22k_ft_in1k
37
  pretrained: True
38
  lr: 1e-3
39
  weight_decay: 1e-5
@@ -41,11 +41,10 @@ model:
41
  patience: 5
42
  min_lr: 1e-6
43
  num_classes: 2
44
- kernel_sizes: 7
45
 
46
  trainer:
47
  min_epochs: 1
48
- max_epochs: 3
49
 
50
  callbacks:
51
  model_checkpoint:
 
6
  defaults:
7
  - override /paths: catdog
8
  - override /data: catdog
9
+ - override /model: catdog_classifier_resnet
10
  - override /callbacks: default
11
  - override /logger: default
12
  - override /trainer: default
 
15
  # this allows you to overwrite only specified parameters
16
 
17
  seed: 42
18
+ name: "catdog_experiment_resnet"
19
 
20
  # Logger-specific configurations
21
  logger:
 
33
  image_size: 160
34
 
35
  model:
36
+ base_model: efficientnet_b0
37
  pretrained: True
38
  lr: 1e-3
39
  weight_decay: 1e-5
 
41
  patience: 5
42
  min_lr: 1e-6
43
  num_classes: 2
 
44
 
45
  trainer:
46
  min_epochs: 1
47
+ max_epochs: 5
48
 
49
  callbacks:
50
  model_checkpoint:
configs/infer.yaml CHANGED
@@ -5,7 +5,7 @@
5
  defaults:
6
  - _self_
7
  - data: catdog
8
- - model: catdog_model
9
  - callbacks: default
10
  - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
11
  - trainer: default
 
5
  defaults:
6
  - _self_
7
  - data: catdog
8
+ - model: catdog_classifier
9
  - callbacks: default
10
  - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
11
  - trainer: default
configs/model/catdog_classifier_resnet.yaml CHANGED
@@ -1,16 +1,13 @@
1
-
2
- _target_: src.models.catdog_classifier_convnext.ConvNextClassifier
3
 
4
  # model params
5
- base_model: convnext_tiny.in12k_ft_in1k
6
  pretrained: True
7
  num_classes: 2
8
- kernel_sizes: 7
9
  # optimizer params
10
  lr: 1e-3
11
  weight_decay: 1e-5
12
-
13
  # scheduler params
14
  factor: 0.1
15
  patience: 10
16
- min_lr: 1e-6
 
1
+ _target_: src.models.catdog_model_resnet.ResnetClassifier
 
2
 
3
  # model params
4
+ base_model: efficientnet_b0
5
  pretrained: True
6
  num_classes: 2
 
7
  # optimizer params
8
  lr: 1e-3
9
  weight_decay: 1e-5
 
10
  # scheduler params
11
  factor: 0.1
12
  patience: 10
13
+ min_lr: 1e-6
docker-compose.yaml CHANGED
@@ -3,7 +3,7 @@ services:
3
  build:
4
  context: .
5
  command: |
6
- python -m src.train_optuna_callbacks experiment=catdog_experiment ++task_name=train ++train=True ++test=False && \
7
  python -m src.create_artifacts && \
8
  touch ./checkpoints/train_done.flag
9
  volumes:
@@ -31,7 +31,7 @@ services:
31
  build:
32
  context: .
33
  command: |
34
- sh -c 'while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done && python -m src.train_optuna_callbacks experiment=catdog_experiment ++task_name=test ++train=False ++test=True'
35
  volumes:
36
  - ./data:/app/data
37
  - ./checkpoints:/app/checkpoints
@@ -52,13 +52,12 @@ services:
52
  - driver: nvidia
53
  count: 1
54
  capabilities: [gpu]
55
-
56
 
57
- server:
58
  build:
59
  context: .
60
  command: |
61
- sh -c 'while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done && python -m src.server'
62
  volumes:
63
  - ./data:/app/data
64
  - ./checkpoints:/app/checkpoints
@@ -67,14 +66,11 @@ services:
67
  environment:
68
  - PYTHONUNBUFFERED=1
69
  - PYTHONPATH=/app
70
- - SERVER_URL=http://localhost:8080
71
  shm_size: '4g'
72
  networks:
73
  - default
74
  env_file:
75
  - .env
76
- ports:
77
- - "8080:8080"
78
  deploy:
79
  resources:
80
  reservations:
@@ -82,37 +78,8 @@ services:
82
  - driver: nvidia
83
  count: 1
84
  capabilities: [gpu]
85
-
86
- client:
87
- build:
88
- context: .
89
- command: |
90
- sh -c 'until curl -s http://server:8080/health; do echo "Waiting for server to be ready..."; sleep 5; done && \
91
- ./run_client.sh'
92
- volumes:
93
- - ./data:/app/data
94
- - ./checkpoints:/app/checkpoints
95
- - ./artifacts:/app/artifacts
96
- - ./logs:/app/logs
97
- environment:
98
- - PYTHONUNBUFFERED=1
99
- - PYTHONPATH=/app
100
- - SERVER_URL=http://server:8080
101
- shm_size: '4g'
102
- networks:
103
- - default
104
- env_file:
105
- - .env
106
-
107
- deploy:
108
- resources:
109
- reservations:
110
- devices:
111
- - driver: nvidia
112
- count: 1
113
- capabilities: [gpu]
114
-
115
 
 
116
  volumes:
117
  data:
118
  checkpoints:
 
3
  build:
4
  context: .
5
  command: |
6
+ python -m src.train_optuna_callbacks experiment=catdog_experiment_resnet ++task_name=train ++train=True ++test=False && \
7
  python -m src.create_artifacts && \
8
  touch ./checkpoints/train_done.flag
9
  volumes:
 
31
  build:
32
  context: .
33
  command: |
34
+ sh -c 'while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done && python -m src.train_optuna_callbacks experiment=catdog_experiment_resnet ++task_name=test ++train=False ++test=True'
35
  volumes:
36
  - ./data:/app/data
37
  - ./checkpoints:/app/checkpoints
 
52
  - driver: nvidia
53
  count: 1
54
  capabilities: [gpu]
 
55
 
56
+ inference:
57
  build:
58
  context: .
59
  command: |
60
+ sh -c 'while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done && python -m src.infer experiment=catdog_experiment_resnet'
61
  volumes:
62
  - ./data:/app/data
63
  - ./checkpoints:/app/checkpoints
 
66
  environment:
67
  - PYTHONUNBUFFERED=1
68
  - PYTHONPATH=/app
 
69
  shm_size: '4g'
70
  networks:
71
  - default
72
  env_file:
73
  - .env
 
 
74
  deploy:
75
  resources:
76
  reservations:
 
78
  - driver: nvidia
79
  count: 1
80
  capabilities: [gpu]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
+
83
  volumes:
84
  data:
85
  checkpoints:
poetry.lock CHANGED
@@ -3014,70 +3014,70 @@ test = ["objgraph", "psutil"]
3014
 
3015
  [[package]]
3016
  name = "grpcio"
3017
- version = "1.68.0"
3018
  description = "HTTP/2-based RPC framework"
3019
  optional = false
3020
  python-versions = ">=3.8"
3021
  files = [
3022
- {file = "grpcio-1.68.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:619b5d0f29f4f5351440e9343224c3e19912c21aeda44e0c49d0d147a8d01544"},
3023
- {file = "grpcio-1.68.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:a59f5822f9459bed098ffbceb2713abbf7c6fd13f2b9243461da5c338d0cd6c3"},
3024
- {file = "grpcio-1.68.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:c03d89df516128febc5a7e760d675b478ba25802447624edf7aa13b1e7b11e2a"},
3025
- {file = "grpcio-1.68.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:44bcbebb24363d587472089b89e2ea0ab2e2b4df0e4856ba4c0b087c82412121"},
3026
- {file = "grpcio-1.68.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:79f81b7fbfb136247b70465bd836fa1733043fdee539cd6031cb499e9608a110"},
3027
- {file = "grpcio-1.68.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:88fb2925789cfe6daa20900260ef0a1d0a61283dfb2d2fffe6194396a354c618"},
3028
- {file = "grpcio-1.68.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:99f06232b5c9138593ae6f2e355054318717d32a9c09cdc5a2885540835067a1"},
3029
- {file = "grpcio-1.68.0-cp310-cp310-win32.whl", hash = "sha256:a6213d2f7a22c3c30a479fb5e249b6b7e648e17f364598ff64d08a5136fe488b"},
3030
- {file = "grpcio-1.68.0-cp310-cp310-win_amd64.whl", hash = "sha256:15327ab81131ef9b94cb9f45b5bd98803a179c7c61205c8c0ac9aff9d6c4e82a"},
3031
- {file = "grpcio-1.68.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:3b2b559beb2d433129441783e5f42e3be40a9e1a89ec906efabf26591c5cd415"},
3032
- {file = "grpcio-1.68.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e46541de8425a4d6829ac6c5d9b16c03c292105fe9ebf78cb1c31e8d242f9155"},
3033
- {file = "grpcio-1.68.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:c1245651f3c9ea92a2db4f95d37b7597db6b246d5892bca6ee8c0e90d76fb73c"},
3034
- {file = "grpcio-1.68.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4f1931c7aa85be0fa6cea6af388e576f3bf6baee9e5d481c586980c774debcb4"},
3035
- {file = "grpcio-1.68.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b0ff09c81e3aded7a183bc6473639b46b6caa9c1901d6f5e2cba24b95e59e30"},
3036
- {file = "grpcio-1.68.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:8c73f9fbbaee1a132487e31585aa83987ddf626426d703ebcb9a528cf231c9b1"},
3037
- {file = "grpcio-1.68.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6b2f98165ea2790ea159393a2246b56f580d24d7da0d0342c18a085299c40a75"},
3038
- {file = "grpcio-1.68.0-cp311-cp311-win32.whl", hash = "sha256:e1e7ed311afb351ff0d0e583a66fcb39675be112d61e7cfd6c8269884a98afbc"},
3039
- {file = "grpcio-1.68.0-cp311-cp311-win_amd64.whl", hash = "sha256:e0d2f68eaa0a755edd9a47d40e50dba6df2bceda66960dee1218da81a2834d27"},
3040
- {file = "grpcio-1.68.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:8af6137cc4ae8e421690d276e7627cfc726d4293f6607acf9ea7260bd8fc3d7d"},
3041
- {file = "grpcio-1.68.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:4028b8e9a3bff6f377698587d642e24bd221810c06579a18420a17688e421af7"},
3042
- {file = "grpcio-1.68.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:f60fa2adf281fd73ae3a50677572521edca34ba373a45b457b5ebe87c2d01e1d"},
3043
- {file = "grpcio-1.68.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e18589e747c1e70b60fab6767ff99b2d0c359ea1db8a2cb524477f93cdbedf5b"},
3044
- {file = "grpcio-1.68.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0d30f3fee9372796f54d3100b31ee70972eaadcc87314be369360248a3dcffe"},
3045
- {file = "grpcio-1.68.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:7e0a3e72c0e9a1acab77bef14a73a416630b7fd2cbd893c0a873edc47c42c8cd"},
3046
- {file = "grpcio-1.68.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a831dcc343440969aaa812004685ed322cdb526cd197112d0db303b0da1e8659"},
3047
- {file = "grpcio-1.68.0-cp312-cp312-win32.whl", hash = "sha256:5a180328e92b9a0050958ced34dddcb86fec5a8b332f5a229e353dafc16cd332"},
3048
- {file = "grpcio-1.68.0-cp312-cp312-win_amd64.whl", hash = "sha256:2bddd04a790b69f7a7385f6a112f46ea0b34c4746f361ebafe9ca0be567c78e9"},
3049
- {file = "grpcio-1.68.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:fc05759ffbd7875e0ff2bd877be1438dfe97c9312bbc558c8284a9afa1d0f40e"},
3050
- {file = "grpcio-1.68.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:15fa1fe25d365a13bc6d52fcac0e3ee1f9baebdde2c9b3b2425f8a4979fccea1"},
3051
- {file = "grpcio-1.68.0-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:32a9cb4686eb2e89d97022ecb9e1606d132f85c444354c17a7dbde4a455e4a3b"},
3052
- {file = "grpcio-1.68.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dba037ff8d284c8e7ea9a510c8ae0f5b016004f13c3648f72411c464b67ff2fb"},
3053
- {file = "grpcio-1.68.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0efbbd849867e0e569af09e165363ade75cf84f5229b2698d53cf22c7a4f9e21"},
3054
- {file = "grpcio-1.68.0-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:4e300e6978df0b65cc2d100c54e097c10dfc7018b9bd890bbbf08022d47f766d"},
3055
- {file = "grpcio-1.68.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:6f9c7ad1a23e1047f827385f4713b5b8c6c7d325705be1dd3e31fb00dcb2f665"},
3056
- {file = "grpcio-1.68.0-cp313-cp313-win32.whl", hash = "sha256:3ac7f10850fd0487fcce169c3c55509101c3bde2a3b454869639df2176b60a03"},
3057
- {file = "grpcio-1.68.0-cp313-cp313-win_amd64.whl", hash = "sha256:afbf45a62ba85a720491bfe9b2642f8761ff348006f5ef67e4622621f116b04a"},
3058
- {file = "grpcio-1.68.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:f8f695d9576ce836eab27ba7401c60acaf9ef6cf2f70dfe5462055ba3df02cc3"},
3059
- {file = "grpcio-1.68.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:9fe1b141cda52f2ca73e17d2d3c6a9f3f3a0c255c216b50ce616e9dca7e3441d"},
3060
- {file = "grpcio-1.68.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:4df81d78fd1646bf94ced4fb4cd0a7fe2e91608089c522ef17bc7db26e64effd"},
3061
- {file = "grpcio-1.68.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:46a2d74d4dd8993151c6cd585594c082abe74112c8e4175ddda4106f2ceb022f"},
3062
- {file = "grpcio-1.68.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a17278d977746472698460c63abf333e1d806bd41f2224f90dbe9460101c9796"},
3063
- {file = "grpcio-1.68.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:15377bce516b1c861c35e18eaa1c280692bf563264836cece693c0f169b48829"},
3064
- {file = "grpcio-1.68.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cc5f0a4f5904b8c25729a0498886b797feb817d1fd3812554ffa39551112c161"},
3065
- {file = "grpcio-1.68.0-cp38-cp38-win32.whl", hash = "sha256:def1a60a111d24376e4b753db39705adbe9483ef4ca4761f825639d884d5da78"},
3066
- {file = "grpcio-1.68.0-cp38-cp38-win_amd64.whl", hash = "sha256:55d3b52fd41ec5772a953612db4e70ae741a6d6ed640c4c89a64f017a1ac02b5"},
3067
- {file = "grpcio-1.68.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:0d230852ba97654453d290e98d6aa61cb48fa5fafb474fb4c4298d8721809354"},
3068
- {file = "grpcio-1.68.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:50992f214264e207e07222703c17d9cfdcc2c46ed5a1ea86843d440148ebbe10"},
3069
- {file = "grpcio-1.68.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:14331e5c27ed3545360464a139ed279aa09db088f6e9502e95ad4bfa852bb116"},
3070
- {file = "grpcio-1.68.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f84890b205692ea813653ece4ac9afa2139eae136e419231b0eec7c39fdbe4c2"},
3071
- {file = "grpcio-1.68.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b0cf343c6f4f6aa44863e13ec9ddfe299e0be68f87d68e777328bff785897b05"},
3072
- {file = "grpcio-1.68.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:fd2c2d47969daa0e27eadaf15c13b5e92605c5e5953d23c06d0b5239a2f176d3"},
3073
- {file = "grpcio-1.68.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:18668e36e7f4045820f069997834e94e8275910b1f03e078a6020bd464cb2363"},
3074
- {file = "grpcio-1.68.0-cp39-cp39-win32.whl", hash = "sha256:2af76ab7c427aaa26aa9187c3e3c42f38d3771f91a20f99657d992afada2294a"},
3075
- {file = "grpcio-1.68.0-cp39-cp39-win_amd64.whl", hash = "sha256:e694b5928b7b33ca2d3b4d5f9bf8b5888906f181daff6b406f4938f3a997a490"},
3076
- {file = "grpcio-1.68.0.tar.gz", hash = "sha256:7e7483d39b4a4fddb9906671e9ea21aaad4f031cdfc349fec76bdfa1e404543a"},
3077
- ]
3078
-
3079
- [package.extras]
3080
- protobuf = ["grpcio-tools (>=1.68.0)"]
3081
 
3082
  [[package]]
3083
  name = "gto"
@@ -3600,13 +3600,13 @@ files = [
3600
 
3601
  [[package]]
3602
  name = "jsonargparse"
3603
- version = "4.34.0"
3604
  description = "Implement minimal boilerplate CLIs derived from type hints and parse from command line, config files and environment variables."
3605
  optional = false
3606
  python-versions = ">=3.8"
3607
  files = [
3608
- {file = "jsonargparse-4.34.0-py3-none-any.whl", hash = "sha256:a3eb8a9a289332066b1b33463efa49d5d2a8d729b6cb60e9e30231d0c19dfb13"},
3609
- {file = "jsonargparse-4.34.0.tar.gz", hash = "sha256:88b3ff0beaff40909dc69244f0527b054f8be0132086c72aa7e1d99414024b43"},
3610
  ]
3611
 
3612
  [package.dependencies]
@@ -5744,13 +5744,13 @@ tests = ["chardet", "parameterized", "pytest", "pytest-cov", "pytest-xdist[psuti
5744
 
5745
  [[package]]
5746
  name = "pydrive2"
5747
- version = "1.21.2"
5748
  description = "Google Drive API made easy. Maintained fork of PyDrive."
5749
  optional = false
5750
  python-versions = ">=3.8"
5751
  files = [
5752
- {file = "PyDrive2-1.21.2-py3-none-any.whl", hash = "sha256:a3b72e522b8f5ba4e93ab165bbf120544567583c61a6c7904ef1ff47afc005d6"},
5753
- {file = "pydrive2-1.21.2.tar.gz", hash = "sha256:2a21c8319a225943c70e7566eb13a1524d1d7193621de1eb8e5f95e037641508"},
5754
  ]
5755
 
5756
  [package.dependencies]
@@ -5759,7 +5759,7 @@ fsspec = {version = ">=2021.07.0", optional = true, markers = "extra == \"fsspec
5759
  funcy = {version = ">=1.14", optional = true, markers = "extra == \"fsspec\""}
5760
  google-api-python-client = ">=1.12.5"
5761
  oauth2client = ">=4.0.0"
5762
- pyOpenSSL = ">=19.1.0,<=24.2.1"
5763
  PyYAML = ">=3.0"
5764
  tqdm = {version = ">=4.0.0", optional = true, markers = "extra == \"fsspec\""}
5765
 
@@ -5877,21 +5877,21 @@ tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"]
5877
 
5878
  [[package]]
5879
  name = "pyopenssl"
5880
- version = "22.0.0"
5881
  description = "Python wrapper module around the OpenSSL library"
5882
  optional = false
5883
- python-versions = ">=3.6"
5884
  files = [
5885
- {file = "pyOpenSSL-22.0.0-py2.py3-none-any.whl", hash = "sha256:ea252b38c87425b64116f808355e8da644ef9b07e429398bfece610f893ee2e0"},
5886
- {file = "pyOpenSSL-22.0.0.tar.gz", hash = "sha256:660b1b1425aac4a1bea1d94168a85d99f0b3144c869dd4390d27629d0087f1bf"},
5887
  ]
5888
 
5889
  [package.dependencies]
5890
- cryptography = ">=35.0"
5891
 
5892
  [package.extras]
5893
- docs = ["sphinx", "sphinx-rtd-theme"]
5894
- test = ["flaky", "pretend", "pytest (>=3.0.1)"]
5895
 
5896
  [[package]]
5897
  name = "pyparsing"
@@ -8397,4 +8397,4 @@ type = ["pytest-mypy"]
8397
  [metadata]
8398
  lock-version = "2.0"
8399
  python-versions = "3.10.15"
8400
- content-hash = "13c48bb5304783670c78f25e07ef651b8b5ff25ca5f04bced20f9d76327958d3"
 
3014
 
3015
  [[package]]
3016
  name = "grpcio"
3017
+ version = "1.68.1"
3018
  description = "HTTP/2-based RPC framework"
3019
  optional = false
3020
  python-versions = ">=3.8"
3021
  files = [
3022
+ {file = "grpcio-1.68.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:d35740e3f45f60f3c37b1e6f2f4702c23867b9ce21c6410254c9c682237da68d"},
3023
+ {file = "grpcio-1.68.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:d99abcd61760ebb34bdff37e5a3ba333c5cc09feda8c1ad42547bea0416ada78"},
3024
+ {file = "grpcio-1.68.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:f8261fa2a5f679abeb2a0a93ad056d765cdca1c47745eda3f2d87f874ff4b8c9"},
3025
+ {file = "grpcio-1.68.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0feb02205a27caca128627bd1df4ee7212db051019a9afa76f4bb6a1a80ca95e"},
3026
+ {file = "grpcio-1.68.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:919d7f18f63bcad3a0f81146188e90274fde800a94e35d42ffe9eadf6a9a6330"},
3027
+ {file = "grpcio-1.68.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:963cc8d7d79b12c56008aabd8b457f400952dbea8997dd185f155e2f228db079"},
3028
+ {file = "grpcio-1.68.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ccf2ebd2de2d6661e2520dae293298a3803a98ebfc099275f113ce1f6c2a80f1"},
3029
+ {file = "grpcio-1.68.1-cp310-cp310-win32.whl", hash = "sha256:2cc1fd04af8399971bcd4f43bd98c22d01029ea2e56e69c34daf2bf8470e47f5"},
3030
+ {file = "grpcio-1.68.1-cp310-cp310-win_amd64.whl", hash = "sha256:ee2e743e51cb964b4975de572aa8fb95b633f496f9fcb5e257893df3be854746"},
3031
+ {file = "grpcio-1.68.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:55857c71641064f01ff0541a1776bfe04a59db5558e82897d35a7793e525774c"},
3032
+ {file = "grpcio-1.68.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4b177f5547f1b995826ef529d2eef89cca2f830dd8b2c99ffd5fde4da734ba73"},
3033
+ {file = "grpcio-1.68.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:3522c77d7e6606d6665ec8d50e867f13f946a4e00c7df46768f1c85089eae515"},
3034
+ {file = "grpcio-1.68.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9d1fae6bbf0816415b81db1e82fb3bf56f7857273c84dcbe68cbe046e58e1ccd"},
3035
+ {file = "grpcio-1.68.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:298ee7f80e26f9483f0b6f94cc0a046caf54400a11b644713bb5b3d8eb387600"},
3036
+ {file = "grpcio-1.68.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cbb5780e2e740b6b4f2d208e90453591036ff80c02cc605fea1af8e6fc6b1bbe"},
3037
+ {file = "grpcio-1.68.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ddda1aa22495d8acd9dfbafff2866438d12faec4d024ebc2e656784d96328ad0"},
3038
+ {file = "grpcio-1.68.1-cp311-cp311-win32.whl", hash = "sha256:b33bd114fa5a83f03ec6b7b262ef9f5cac549d4126f1dc702078767b10c46ed9"},
3039
+ {file = "grpcio-1.68.1-cp311-cp311-win_amd64.whl", hash = "sha256:7f20ebec257af55694d8f993e162ddf0d36bd82d4e57f74b31c67b3c6d63d8b2"},
3040
+ {file = "grpcio-1.68.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:8829924fffb25386995a31998ccbbeaa7367223e647e0122043dfc485a87c666"},
3041
+ {file = "grpcio-1.68.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:3aed6544e4d523cd6b3119b0916cef3d15ef2da51e088211e4d1eb91a6c7f4f1"},
3042
+ {file = "grpcio-1.68.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:4efac5481c696d5cb124ff1c119a78bddbfdd13fc499e3bc0ca81e95fc573684"},
3043
+ {file = "grpcio-1.68.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ab2d912ca39c51f46baf2a0d92aa265aa96b2443266fc50d234fa88bf877d8e"},
3044
+ {file = "grpcio-1.68.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95c87ce2a97434dffe7327a4071839ab8e8bffd0054cc74cbe971fba98aedd60"},
3045
+ {file = "grpcio-1.68.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:e4842e4872ae4ae0f5497bf60a0498fa778c192cc7a9e87877abd2814aca9475"},
3046
+ {file = "grpcio-1.68.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:255b1635b0ed81e9f91da4fcc8d43b7ea5520090b9a9ad9340d147066d1d3613"},
3047
+ {file = "grpcio-1.68.1-cp312-cp312-win32.whl", hash = "sha256:7dfc914cc31c906297b30463dde0b9be48e36939575eaf2a0a22a8096e69afe5"},
3048
+ {file = "grpcio-1.68.1-cp312-cp312-win_amd64.whl", hash = "sha256:a0c8ddabef9c8f41617f213e527254c41e8b96ea9d387c632af878d05db9229c"},
3049
+ {file = "grpcio-1.68.1-cp313-cp313-linux_armv7l.whl", hash = "sha256:a47faedc9ea2e7a3b6569795c040aae5895a19dde0c728a48d3c5d7995fda385"},
3050
+ {file = "grpcio-1.68.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:390eee4225a661c5cd133c09f5da1ee3c84498dc265fd292a6912b65c421c78c"},
3051
+ {file = "grpcio-1.68.1-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:66a24f3d45c33550703f0abb8b656515b0ab777970fa275693a2f6dc8e35f1c1"},
3052
+ {file = "grpcio-1.68.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c08079b4934b0bf0a8847f42c197b1d12cba6495a3d43febd7e99ecd1cdc8d54"},
3053
+ {file = "grpcio-1.68.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8720c25cd9ac25dd04ee02b69256d0ce35bf8a0f29e20577427355272230965a"},
3054
+ {file = "grpcio-1.68.1-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:04cfd68bf4f38f5bb959ee2361a7546916bd9a50f78617a346b3aeb2b42e2161"},
3055
+ {file = "grpcio-1.68.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:c28848761a6520c5c6071d2904a18d339a796ebe6b800adc8b3f474c5ce3c3ad"},
3056
+ {file = "grpcio-1.68.1-cp313-cp313-win32.whl", hash = "sha256:77d65165fc35cff6e954e7fd4229e05ec76102d4406d4576528d3a3635fc6172"},
3057
+ {file = "grpcio-1.68.1-cp313-cp313-win_amd64.whl", hash = "sha256:a8040f85dcb9830d8bbb033ae66d272614cec6faceee88d37a88a9bd1a7a704e"},
3058
+ {file = "grpcio-1.68.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:eeb38ff04ab6e5756a2aef6ad8d94e89bb4a51ef96e20f45c44ba190fa0bcaad"},
3059
+ {file = "grpcio-1.68.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8a3869a6661ec8f81d93f4597da50336718bde9eb13267a699ac7e0a1d6d0bea"},
3060
+ {file = "grpcio-1.68.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:2c4cec6177bf325eb6faa6bd834d2ff6aa8bb3b29012cceb4937b86f8b74323c"},
3061
+ {file = "grpcio-1.68.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:12941d533f3cd45d46f202e3667be8ebf6bcb3573629c7ec12c3e211d99cfccf"},
3062
+ {file = "grpcio-1.68.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80af6f1e69c5e68a2be529990684abdd31ed6622e988bf18850075c81bb1ad6e"},
3063
+ {file = "grpcio-1.68.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:e8dbe3e00771bfe3d04feed8210fc6617006d06d9a2679b74605b9fed3e8362c"},
3064
+ {file = "grpcio-1.68.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:83bbf5807dc3ee94ce1de2dfe8a356e1d74101e4b9d7aa8c720cc4818a34aded"},
3065
+ {file = "grpcio-1.68.1-cp38-cp38-win32.whl", hash = "sha256:8cb620037a2fd9eeee97b4531880e439ebfcd6d7d78f2e7dcc3726428ab5ef63"},
3066
+ {file = "grpcio-1.68.1-cp38-cp38-win_amd64.whl", hash = "sha256:52fbf85aa71263380d330f4fce9f013c0798242e31ede05fcee7fbe40ccfc20d"},
3067
+ {file = "grpcio-1.68.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:cb400138e73969eb5e0535d1d06cae6a6f7a15f2cc74add320e2130b8179211a"},
3068
+ {file = "grpcio-1.68.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a1b988b40f2fd9de5c820f3a701a43339d8dcf2cb2f1ca137e2c02671cc83ac1"},
3069
+ {file = "grpcio-1.68.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:96f473cdacfdd506008a5d7579c9f6a7ff245a9ade92c3c0265eb76cc591914f"},
3070
+ {file = "grpcio-1.68.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:37ea3be171f3cf3e7b7e412a98b77685eba9d4fd67421f4a34686a63a65d99f9"},
3071
+ {file = "grpcio-1.68.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ceb56c4285754e33bb3c2fa777d055e96e6932351a3082ce3559be47f8024f0"},
3072
+ {file = "grpcio-1.68.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:dffd29a2961f3263a16d73945b57cd44a8fd0b235740cb14056f0612329b345e"},
3073
+ {file = "grpcio-1.68.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:025f790c056815b3bf53da850dd70ebb849fd755a4b1ac822cb65cd631e37d43"},
3074
+ {file = "grpcio-1.68.1-cp39-cp39-win32.whl", hash = "sha256:1098f03dedc3b9810810568060dea4ac0822b4062f537b0f53aa015269be0a76"},
3075
+ {file = "grpcio-1.68.1-cp39-cp39-win_amd64.whl", hash = "sha256:334ab917792904245a028f10e803fcd5b6f36a7b2173a820c0b5b076555825e1"},
3076
+ {file = "grpcio-1.68.1.tar.gz", hash = "sha256:44a8502dd5de653ae6a73e2de50a401d84184f0331d0ac3daeb044e66d5c5054"},
3077
+ ]
3078
+
3079
+ [package.extras]
3080
+ protobuf = ["grpcio-tools (>=1.68.1)"]
3081
 
3082
  [[package]]
3083
  name = "gto"
 
3600
 
3601
  [[package]]
3602
  name = "jsonargparse"
3603
+ version = "4.34.1"
3604
  description = "Implement minimal boilerplate CLIs derived from type hints and parse from command line, config files and environment variables."
3605
  optional = false
3606
  python-versions = ">=3.8"
3607
  files = [
3608
+ {file = "jsonargparse-4.34.1-py3-none-any.whl", hash = "sha256:1d595080e080d4581ef821b8ec71364037052be4da0f9c7bdafe95962672cee6"},
3609
+ {file = "jsonargparse-4.34.1.tar.gz", hash = "sha256:6e0d1ab67b12b1086fe7bbe5ba429fbe9865c36493c9bb6aeb1e243047fdb58c"},
3610
  ]
3611
 
3612
  [package.dependencies]
 
5744
 
5745
  [[package]]
5746
  name = "pydrive2"
5747
+ version = "1.21.1"
5748
  description = "Google Drive API made easy. Maintained fork of PyDrive."
5749
  optional = false
5750
  python-versions = ">=3.8"
5751
  files = [
5752
+ {file = "PyDrive2-1.21.1-py3-none-any.whl", hash = "sha256:d24b3334bc5c242e5ec58ad6ee7efbd2216aa92098c3eed353ce644f27a7e97b"},
5753
+ {file = "pydrive2-1.21.1.tar.gz", hash = "sha256:70da0244a29a6922e28620a32e251ac6ab018449f1bb0485e9a39114a069dde0"},
5754
  ]
5755
 
5756
  [package.dependencies]
 
5759
  funcy = {version = ">=1.14", optional = true, markers = "extra == \"fsspec\""}
5760
  google-api-python-client = ">=1.12.5"
5761
  oauth2client = ">=4.0.0"
5762
+ pyOpenSSL = ">=19.1.0"
5763
  PyYAML = ">=3.0"
5764
  tqdm = {version = ">=4.0.0", optional = true, markers = "extra == \"fsspec\""}
5765
 
 
5877
 
5878
  [[package]]
5879
  name = "pyopenssl"
5880
+ version = "24.3.0"
5881
  description = "Python wrapper module around the OpenSSL library"
5882
  optional = false
5883
+ python-versions = ">=3.7"
5884
  files = [
5885
+ {file = "pyOpenSSL-24.3.0-py3-none-any.whl", hash = "sha256:e474f5a473cd7f92221cc04976e48f4d11502804657a08a989fb3be5514c904a"},
5886
+ {file = "pyopenssl-24.3.0.tar.gz", hash = "sha256:49f7a019577d834746bc55c5fce6ecbcec0f2b4ec5ce1cf43a9a173b8138bb36"},
5887
  ]
5888
 
5889
  [package.dependencies]
5890
+ cryptography = ">=41.0.5,<45"
5891
 
5892
  [package.extras]
5893
+ docs = ["sphinx (!=5.2.0,!=5.2.0.post0,!=7.2.5)", "sphinx_rtd_theme"]
5894
+ test = ["pretend", "pytest (>=3.0.1)", "pytest-rerunfailures"]
5895
 
5896
  [[package]]
5897
  name = "pyparsing"
 
8397
  [metadata]
8398
  lock-version = "2.0"
8399
  python-versions = "3.10.15"
8400
+ content-hash = "78d6a9383381dfba8ea585d06d83d833b85434561a2c3f39d8eda83e5e1697d9"
pyproject.toml CHANGED
@@ -77,9 +77,9 @@ nvitop = "^1.3.2"
77
  gradio = "5.7.1"
78
  gradio-client = "^1.5.0"
79
  accelerate = "^1.1.1"
80
- pyopenssl = "<23.0.0"
81
  cryptography = "^44.0.0"
82
  boto3 = "*"
 
83
 
84
  [tool.poetry.dev-dependencies]
85
  pytest-asyncio = "^0.20.3"
 
77
  gradio = "5.7.1"
78
  gradio-client = "^1.5.0"
79
  accelerate = "^1.1.1"
 
80
  cryptography = "^44.0.0"
81
  boto3 = "*"
82
+ pyopenssl = "^24.3.0"
83
 
84
  [tool.poetry.dev-dependencies]
85
  pytest-asyncio = "^0.20.3"
src/infer.py CHANGED
@@ -5,7 +5,7 @@ import torch.nn.functional as F
5
  import matplotlib.pyplot as plt
6
  from PIL import Image
7
  from torchvision import transforms
8
- from src.models.catdog_model import ViTTinyClassifier
9
  from src.utils.logging_utils import setup_logger, task_wrapper, get_rich_progress
10
  import hydra
11
  from omegaconf import DictConfig, OmegaConf
@@ -13,6 +13,7 @@ from dotenv import load_dotenv, find_dotenv
13
  import rootutils
14
  import time
15
  from loguru import logger
 
16
 
17
  # Load environment variables
18
  load_dotenv(find_dotenv(".env"))
@@ -93,9 +94,15 @@ def main_infer(cfg: DictConfig):
93
  if flag_file.exists():
94
  flag_file.unlink()
95
 
 
 
 
 
 
 
96
  # Load the trained model
97
- model = ViTTinyClassifier.load_from_checkpoint(checkpoint_path=cfg.ckpt_path)
98
- classes = ["dog", "cat"]
99
 
100
  # Download an image for inference
101
  download_image(cfg)
 
5
  import matplotlib.pyplot as plt
6
  from PIL import Image
7
  from torchvision import transforms
8
+ from src.models.catdog_model_resnet import ResnetClassifier
9
  from src.utils.logging_utils import setup_logger, task_wrapper, get_rich_progress
10
  import hydra
11
  from omegaconf import DictConfig, OmegaConf
 
13
  import rootutils
14
  import time
15
  from loguru import logger
16
+ from src.utils.aws_s3_services import S3Handler
17
 
18
  # Load environment variables
19
  load_dotenv(find_dotenv(".env"))
 
94
  if flag_file.exists():
95
  flag_file.unlink()
96
 
97
+ # download the model from S3
98
+ s3_handler = S3Handler(bucket_name="deep-bucket-s3")
99
+ s3_handler.download_folder(
100
+ "checkpoints",
101
+ "checkpoints",
102
+ )
103
  # Load the trained model
104
+ model = ResnetClassifier.load_from_checkpoint(checkpoint_path=cfg.ckpt_path)
105
+ classes = cfg.labels
106
 
107
  # Download an image for inference
108
  download_image(cfg)
src/models/catdog_model_resnet.py CHANGED
@@ -9,6 +9,8 @@ import torch
9
  class ResnetClassifier(L.LightningModule):
10
  def __init__(
11
  self,
 
 
12
  num_classes: int = 2, # Binary classification with two classes
13
  lr: float = 1e-3,
14
  weight_decay: float = 1e-5,
@@ -21,7 +23,7 @@ class ResnetClassifier(L.LightningModule):
21
 
22
  # timm backbone initialization (efficientnet_b0)
23
  self.model = timm.create_model(
24
- "efficientnet_b0", pretrained=True, num_classes=num_classes
25
  )
26
 
27
  # Define accuracy and F1 metrics for binary classification
@@ -90,5 +92,5 @@ class ResnetClassifier(L.LightningModule):
90
 
91
 
92
  if __name__ == "__main__":
93
- model = ViTTinyClassifier()
94
  print(model)
 
9
  class ResnetClassifier(L.LightningModule):
10
  def __init__(
11
  self,
12
+ base_model: str = "efficientnet_b0",
13
+ pretrained: bool = True,
14
  num_classes: int = 2, # Binary classification with two classes
15
  lr: float = 1e-3,
16
  weight_decay: float = 1e-5,
 
23
 
24
  # Vision Transformer model initialization
25
  self.model = timm.create_model(
26
+ base_model, pretrained=pretrained, num_classes=num_classes
27
  )
28
 
29
  # Define accuracy and F1 metrics for binary classification
 
92
 
93
 
94
  if __name__ == "__main__":
95
+ model = ResnetClassifier()
96
  print(model)
src/train_optuna_callbacks.py CHANGED
@@ -19,6 +19,7 @@ from lightning.pytorch.callbacks import Callback
19
  import optuna
20
  from lightning.pytorch import Trainer
21
  import json
 
22
 
23
  # Load environment variables
24
  load_dotenv(find_dotenv(".env"))
@@ -131,10 +132,10 @@ def objective(trial: optuna.trial.Trial, cfg: DictConfig, callbacks: List[Callba
131
  """Objective function for Optuna hyperparameter tuning."""
132
 
133
  # Sample hyperparameters for the model
134
- cfg.model.embed_dim = trial.suggest_categorical("embed_dim", [64, 128, 256])
135
- cfg.model.depth = trial.suggest_int("depth", 2, 6)
136
  cfg.model.lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
137
- cfg.model.mlp_ratio = trial.suggest_float("mlp_ratio", 1.0, 4.0)
138
 
139
  # Initialize data module and model
140
  data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
@@ -207,6 +208,13 @@ def setup_trainer(cfg: DictConfig):
207
  with open(Path(cfg.paths.ckpt_dir) / "train_done.flag", "w") as f:
208
  f.write("Training completed successfully!")
209
 
 
 
 
 
 
 
 
210
  # Testing phase with best hyperparameters
211
  if cfg.get("test", False):
212
  best_hyperparams_path = Path(cfg.paths.ckpt_dir) / "best_hyperparams.json"
 
19
  import optuna
20
  from lightning.pytorch import Trainer
21
  import json
22
+ from src.utils.aws_s3_services import S3Handler
23
 
24
  # Load environment variables
25
  load_dotenv(find_dotenv(".env"))
 
132
  """Objective function for Optuna hyperparameter tuning."""
133
 
134
  # Sample hyperparameters for the model
135
+ # cfg.model.embed_dim = trial.suggest_categorical("embed_dim", [64, 128, 256])
136
+ # cfg.model.depth = trial.suggest_int("depth", 2, 6)
137
  cfg.model.lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
138
+ # cfg.model.mlp_ratio = trial.suggest_float("mlp_ratio", 1.0, 4.0)
139
 
140
  # Initialize data module and model
141
  data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
 
208
  with open(Path(cfg.paths.ckpt_dir) / "train_done.flag", "w") as f:
209
  f.write("Training completed successfully!")
210
 
211
+ # upload the checkpoints to S3
212
+ s3_handler = S3Handler(bucket_name="deep-bucket-s3")
213
+ s3_handler.upload_folder(
214
+ "checkpoints",
215
+ "checkpoints",
216
+ )
217
+
218
  # Testing phase with best hyperparameters
219
  if cfg.get("test", False):
220
  best_hyperparams_path = Path(cfg.paths.ckpt_dir) / "best_hyperparams.json"
src/utils/aws_s3_services.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import boto3
2
+ import os
3
+ from pathlib import Path
4
+ from dotenv import load_dotenv, find_dotenv
5
+
6
+ # Load environment variables from .env file
7
+ load_dotenv(find_dotenv(".env"))
8
+
9
+
10
class S3Handler:
    """Small helper for folder-level uploads/downloads against one S3 bucket.

    The underlying boto3 client is configured from the ``AWS_ACCESS_KEY_ID``,
    ``AWS_SECRET_ACCESS_KEY`` and ``AWS_REGION`` environment variables.
    """

    def __init__(self, bucket_name):
        """
        Create an S3 client bound to a single bucket.

        Args:
            bucket_name (str): Name of the bucket used by every transfer.
        """
        self.bucket_name = bucket_name
        self.s3 = boto3.client(
            "s3",
            aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
            region_name=os.getenv("AWS_REGION"),
        )

    def upload_folder(self, source_folder, dest_folder, filenames=None):
        """
        Upload specified files or all files from a local folder to an S3 folder.

        Args:
            source_folder (str): Local source folder path.
            dest_folder (str): Destination folder path in S3.
            filenames (list): List of filenames to upload (relative to source_folder).
                If None (or empty), uploads all files found recursively.
        """
        source_folder = Path(source_folder)

        if filenames:
            files_to_upload = [source_folder / file for file in filenames]
        else:
            # Only regular files become S3 objects; sub-directories are implied
            # by the key and must not be reported as missing files.
            files_to_upload = [p for p in source_folder.rglob("*") if p.is_file()]

        for file_path in files_to_upload:
            if not file_path.is_file():
                print(f"File not found: {file_path}")
                continue

            # S3 keys always use "/" as separator, regardless of the local OS.
            relative_key = file_path.relative_to(source_folder).as_posix()
            s3_path = f"{dest_folder}/{relative_key}"
            self.s3.upload_file(str(file_path), self.bucket_name, s3_path)
            print(f"Uploaded: {file_path} to {s3_path}")

    def download_folder(self, s3_folder, dest_folder):
        """
        Download all files from an S3 folder to a local folder.

        Args:
            s3_folder (str): Source folder in S3.
            dest_folder (str): Local destination folder path.
        """
        dest_folder = Path(dest_folder)

        # S3 prefixes are plain string prefixes: "checkpoints" would also match
        # "checkpoints_test/...". Terminate the prefix with "/" so that only
        # objects inside the requested folder are listed.
        folder = str(s3_folder).rstrip("/")
        prefix = f"{folder}/" if folder else ""

        paginator = self.s3.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix):
            for obj in page.get("Contents", []):
                s3_path = obj["Key"]
                relative_key = s3_path[len(prefix):]

                # Skip zero-length "folder placeholder" objects (the prefix
                # itself or keys ending in "/"); they have no file to write.
                if not relative_key or relative_key.endswith("/"):
                    continue

                local_path = dest_folder.joinpath(*relative_key.split("/"))
                local_path.parent.mkdir(parents=True, exist_ok=True)
                self.s3.download_file(self.bucket_name, s3_path, str(local_path))
                print(f"Downloaded: {s3_path} to {local_path}")
64
+
65
+
66
+ # Usage Example
67
if __name__ == "__main__":
    # Manual smoke test: round-trip the local "checkpoints" folder through S3.
    s3_handler = S3Handler(bucket_name="deep-bucket-s3")

    # No filenames are given, so every file under the local folder is uploaded.
    s3_handler.upload_folder(
        source_folder="checkpoints",
        dest_folder="checkpoints_test",
    )

    # Fetch the objects that were just uploaded back into the local folder.
    s3_handler.download_folder(
        s3_folder="checkpoints_test",
        dest_folder="checkpoints",
    )