Spaces:
Runtime error
Runtime error
soutrik
committed on
Commit
·
4828471
1
Parent(s):
f057c2a
added new changes as per ResnetClassifier and tested with local and docker
Browse files- configs/experiment/catdog_experiment_resnet.yaml +4 -5
- configs/infer.yaml +1 -1
- configs/model/catdog_classifier_resnet.yaml +3 -6
- docker-compose.yaml +5 -38
- poetry.lock +75 -75
- pyproject.toml +1 -1
- src/infer.py +10 -3
- src/models/catdog_model_resnet.py +4 -2
- src/train_optuna_callbacks.py +11 -3
- src/utils/aws_s3_services.py +78 -0
configs/experiment/catdog_experiment_resnet.yaml
CHANGED
@@ -6,7 +6,7 @@
|
|
6 |
defaults:
|
7 |
- override /paths: catdog
|
8 |
- override /data: catdog
|
9 |
-
- override /model:
|
10 |
- override /callbacks: default
|
11 |
- override /logger: default
|
12 |
- override /trainer: default
|
@@ -15,7 +15,7 @@ defaults:
|
|
15 |
# this allows you to overwrite only specified parameters
|
16 |
|
17 |
seed: 42
|
18 |
-
name: "
|
19 |
|
20 |
# Logger-specific configurations
|
21 |
logger:
|
@@ -33,7 +33,7 @@ data:
|
|
33 |
image_size: 160
|
34 |
|
35 |
model:
|
36 |
-
base_model:
|
37 |
pretrained: True
|
38 |
lr: 1e-3
|
39 |
weight_decay: 1e-5
|
@@ -41,11 +41,10 @@ model:
|
|
41 |
patience: 5
|
42 |
min_lr: 1e-6
|
43 |
num_classes: 2
|
44 |
-
kernel_sizes: 7
|
45 |
|
46 |
trainer:
|
47 |
min_epochs: 1
|
48 |
-
max_epochs:
|
49 |
|
50 |
callbacks:
|
51 |
model_checkpoint:
|
|
|
6 |
defaults:
|
7 |
- override /paths: catdog
|
8 |
- override /data: catdog
|
9 |
+
- override /model: catdog_classifier_resnet
|
10 |
- override /callbacks: default
|
11 |
- override /logger: default
|
12 |
- override /trainer: default
|
|
|
15 |
# this allows you to overwrite only specified parameters
|
16 |
|
17 |
seed: 42
|
18 |
+
name: "catdog_experiment_resnet"
|
19 |
|
20 |
# Logger-specific configurations
|
21 |
logger:
|
|
|
33 |
image_size: 160
|
34 |
|
35 |
model:
|
36 |
+
base_model: efficientnet_b0
|
37 |
pretrained: True
|
38 |
lr: 1e-3
|
39 |
weight_decay: 1e-5
|
|
|
41 |
patience: 5
|
42 |
min_lr: 1e-6
|
43 |
num_classes: 2
|
|
|
44 |
|
45 |
trainer:
|
46 |
min_epochs: 1
|
47 |
+
max_epochs: 5
|
48 |
|
49 |
callbacks:
|
50 |
model_checkpoint:
|
configs/infer.yaml
CHANGED
@@ -5,7 +5,7 @@
|
|
5 |
defaults:
|
6 |
- _self_
|
7 |
- data: catdog
|
8 |
-
- model:
|
9 |
- callbacks: default
|
10 |
- logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
|
11 |
- trainer: default
|
|
|
5 |
defaults:
|
6 |
- _self_
|
7 |
- data: catdog
|
8 |
+
- model: catdog_classifier
|
9 |
- callbacks: default
|
10 |
- logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
|
11 |
- trainer: default
|
configs/model/catdog_classifier_resnet.yaml
CHANGED
@@ -1,16 +1,13 @@
|
|
1 |
-
|
2 |
-
_target_: src.models.catdog_classifier_convnext.ConvNextClassifier
|
3 |
|
4 |
# model params
|
5 |
-
base_model:
|
6 |
pretrained: True
|
7 |
num_classes: 2
|
8 |
-
kernel_sizes: 7
|
9 |
# optimizer params
|
10 |
lr: 1e-3
|
11 |
weight_decay: 1e-5
|
12 |
-
|
13 |
# scheduler params
|
14 |
factor: 0.1
|
15 |
patience: 10
|
16 |
-
min_lr: 1e-6
|
|
|
1 |
+
_target_: src.models.catdog_model_resnet.ResnetClassifier
|
|
|
2 |
|
3 |
# model params
|
4 |
+
base_model: efficientnet_b0
|
5 |
pretrained: True
|
6 |
num_classes: 2
|
|
|
7 |
# optimizer params
|
8 |
lr: 1e-3
|
9 |
weight_decay: 1e-5
|
|
|
10 |
# scheduler params
|
11 |
factor: 0.1
|
12 |
patience: 10
|
13 |
+
min_lr: 1e-6
|
docker-compose.yaml
CHANGED
@@ -3,7 +3,7 @@ services:
|
|
3 |
build:
|
4 |
context: .
|
5 |
command: |
|
6 |
-
python -m src.train_optuna_callbacks experiment=
|
7 |
python -m src.create_artifacts && \
|
8 |
touch ./checkpoints/train_done.flag
|
9 |
volumes:
|
@@ -31,7 +31,7 @@ services:
|
|
31 |
build:
|
32 |
context: .
|
33 |
command: |
|
34 |
-
sh -c 'while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done && python -m src.train_optuna_callbacks experiment=
|
35 |
volumes:
|
36 |
- ./data:/app/data
|
37 |
- ./checkpoints:/app/checkpoints
|
@@ -52,13 +52,12 @@ services:
|
|
52 |
- driver: nvidia
|
53 |
count: 1
|
54 |
capabilities: [gpu]
|
55 |
-
|
56 |
|
57 |
-
|
58 |
build:
|
59 |
context: .
|
60 |
command: |
|
61 |
-
sh -c 'while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done && python -m src.
|
62 |
volumes:
|
63 |
- ./data:/app/data
|
64 |
- ./checkpoints:/app/checkpoints
|
@@ -67,14 +66,11 @@ services:
|
|
67 |
environment:
|
68 |
- PYTHONUNBUFFERED=1
|
69 |
- PYTHONPATH=/app
|
70 |
-
- SERVER_URL=http://localhost:8080
|
71 |
shm_size: '4g'
|
72 |
networks:
|
73 |
- default
|
74 |
env_file:
|
75 |
- .env
|
76 |
-
ports:
|
77 |
-
- "8080:8080"
|
78 |
deploy:
|
79 |
resources:
|
80 |
reservations:
|
@@ -82,37 +78,8 @@ services:
|
|
82 |
- driver: nvidia
|
83 |
count: 1
|
84 |
capabilities: [gpu]
|
85 |
-
|
86 |
-
client:
|
87 |
-
build:
|
88 |
-
context: .
|
89 |
-
command: |
|
90 |
-
sh -c 'until curl -s http://server:8080/health; do echo "Waiting for server to be ready..."; sleep 5; done && \
|
91 |
-
./run_client.sh'
|
92 |
-
volumes:
|
93 |
-
- ./data:/app/data
|
94 |
-
- ./checkpoints:/app/checkpoints
|
95 |
-
- ./artifacts:/app/artifacts
|
96 |
-
- ./logs:/app/logs
|
97 |
-
environment:
|
98 |
-
- PYTHONUNBUFFERED=1
|
99 |
-
- PYTHONPATH=/app
|
100 |
-
- SERVER_URL=http://server:8080
|
101 |
-
shm_size: '4g'
|
102 |
-
networks:
|
103 |
-
- default
|
104 |
-
env_file:
|
105 |
-
- .env
|
106 |
-
|
107 |
-
deploy:
|
108 |
-
resources:
|
109 |
-
reservations:
|
110 |
-
devices:
|
111 |
-
- driver: nvidia
|
112 |
-
count: 1
|
113 |
-
capabilities: [gpu]
|
114 |
-
|
115 |
|
|
|
116 |
volumes:
|
117 |
data:
|
118 |
checkpoints:
|
|
|
3 |
build:
|
4 |
context: .
|
5 |
command: |
|
6 |
+
python -m src.train_optuna_callbacks experiment=catdog_experiment_resnet ++task_name=train ++train=True ++test=False && \
|
7 |
python -m src.create_artifacts && \
|
8 |
touch ./checkpoints/train_done.flag
|
9 |
volumes:
|
|
|
31 |
build:
|
32 |
context: .
|
33 |
command: |
|
34 |
+
sh -c 'while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done && python -m src.train_optuna_callbacks experiment=catdog_experiment_resnet ++task_name=test ++train=False ++test=True'
|
35 |
volumes:
|
36 |
- ./data:/app/data
|
37 |
- ./checkpoints:/app/checkpoints
|
|
|
52 |
- driver: nvidia
|
53 |
count: 1
|
54 |
capabilities: [gpu]
|
|
|
55 |
|
56 |
+
inference:
|
57 |
build:
|
58 |
context: .
|
59 |
command: |
|
60 |
+
sh -c 'while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done && python -m src.infer experiment=catdog_experiment_resnet'
|
61 |
volumes:
|
62 |
- ./data:/app/data
|
63 |
- ./checkpoints:/app/checkpoints
|
|
|
66 |
environment:
|
67 |
- PYTHONUNBUFFERED=1
|
68 |
- PYTHONPATH=/app
|
|
|
69 |
shm_size: '4g'
|
70 |
networks:
|
71 |
- default
|
72 |
env_file:
|
73 |
- .env
|
|
|
|
|
74 |
deploy:
|
75 |
resources:
|
76 |
reservations:
|
|
|
78 |
- driver: nvidia
|
79 |
count: 1
|
80 |
capabilities: [gpu]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
+
|
83 |
volumes:
|
84 |
data:
|
85 |
checkpoints:
|
poetry.lock
CHANGED
@@ -3014,70 +3014,70 @@ test = ["objgraph", "psutil"]
|
|
3014 |
|
3015 |
[[package]]
|
3016 |
name = "grpcio"
|
3017 |
-
version = "1.68.
|
3018 |
description = "HTTP/2-based RPC framework"
|
3019 |
optional = false
|
3020 |
python-versions = ">=3.8"
|
3021 |
files = [
|
3022 |
-
{file = "grpcio-1.68.
|
3023 |
-
{file = "grpcio-1.68.
|
3024 |
-
{file = "grpcio-1.68.
|
3025 |
-
{file = "grpcio-1.68.
|
3026 |
-
{file = "grpcio-1.68.
|
3027 |
-
{file = "grpcio-1.68.
|
3028 |
-
{file = "grpcio-1.68.
|
3029 |
-
{file = "grpcio-1.68.
|
3030 |
-
{file = "grpcio-1.68.
|
3031 |
-
{file = "grpcio-1.68.
|
3032 |
-
{file = "grpcio-1.68.
|
3033 |
-
{file = "grpcio-1.68.
|
3034 |
-
{file = "grpcio-1.68.
|
3035 |
-
{file = "grpcio-1.68.
|
3036 |
-
{file = "grpcio-1.68.
|
3037 |
-
{file = "grpcio-1.68.
|
3038 |
-
{file = "grpcio-1.68.
|
3039 |
-
{file = "grpcio-1.68.
|
3040 |
-
{file = "grpcio-1.68.
|
3041 |
-
{file = "grpcio-1.68.
|
3042 |
-
{file = "grpcio-1.68.
|
3043 |
-
{file = "grpcio-1.68.
|
3044 |
-
{file = "grpcio-1.68.
|
3045 |
-
{file = "grpcio-1.68.
|
3046 |
-
{file = "grpcio-1.68.
|
3047 |
-
{file = "grpcio-1.68.
|
3048 |
-
{file = "grpcio-1.68.
|
3049 |
-
{file = "grpcio-1.68.
|
3050 |
-
{file = "grpcio-1.68.
|
3051 |
-
{file = "grpcio-1.68.
|
3052 |
-
{file = "grpcio-1.68.
|
3053 |
-
{file = "grpcio-1.68.
|
3054 |
-
{file = "grpcio-1.68.
|
3055 |
-
{file = "grpcio-1.68.
|
3056 |
-
{file = "grpcio-1.68.
|
3057 |
-
{file = "grpcio-1.68.
|
3058 |
-
{file = "grpcio-1.68.
|
3059 |
-
{file = "grpcio-1.68.
|
3060 |
-
{file = "grpcio-1.68.
|
3061 |
-
{file = "grpcio-1.68.
|
3062 |
-
{file = "grpcio-1.68.
|
3063 |
-
{file = "grpcio-1.68.
|
3064 |
-
{file = "grpcio-1.68.
|
3065 |
-
{file = "grpcio-1.68.
|
3066 |
-
{file = "grpcio-1.68.
|
3067 |
-
{file = "grpcio-1.68.
|
3068 |
-
{file = "grpcio-1.68.
|
3069 |
-
{file = "grpcio-1.68.
|
3070 |
-
{file = "grpcio-1.68.
|
3071 |
-
{file = "grpcio-1.68.
|
3072 |
-
{file = "grpcio-1.68.
|
3073 |
-
{file = "grpcio-1.68.
|
3074 |
-
{file = "grpcio-1.68.
|
3075 |
-
{file = "grpcio-1.68.
|
3076 |
-
{file = "grpcio-1.68.
|
3077 |
-
]
|
3078 |
-
|
3079 |
-
[package.extras]
|
3080 |
-
protobuf = ["grpcio-tools (>=1.68.
|
3081 |
|
3082 |
[[package]]
|
3083 |
name = "gto"
|
@@ -3600,13 +3600,13 @@ files = [
|
|
3600 |
|
3601 |
[[package]]
|
3602 |
name = "jsonargparse"
|
3603 |
-
version = "4.34.
|
3604 |
description = "Implement minimal boilerplate CLIs derived from type hints and parse from command line, config files and environment variables."
|
3605 |
optional = false
|
3606 |
python-versions = ">=3.8"
|
3607 |
files = [
|
3608 |
-
{file = "jsonargparse-4.34.
|
3609 |
-
{file = "jsonargparse-4.34.
|
3610 |
]
|
3611 |
|
3612 |
[package.dependencies]
|
@@ -5744,13 +5744,13 @@ tests = ["chardet", "parameterized", "pytest", "pytest-cov", "pytest-xdist[psuti
|
|
5744 |
|
5745 |
[[package]]
|
5746 |
name = "pydrive2"
|
5747 |
-
version = "1.21.
|
5748 |
description = "Google Drive API made easy. Maintained fork of PyDrive."
|
5749 |
optional = false
|
5750 |
python-versions = ">=3.8"
|
5751 |
files = [
|
5752 |
-
{file = "PyDrive2-1.21.
|
5753 |
-
{file = "pydrive2-1.21.
|
5754 |
]
|
5755 |
|
5756 |
[package.dependencies]
|
@@ -5759,7 +5759,7 @@ fsspec = {version = ">=2021.07.0", optional = true, markers = "extra == \"fsspec
|
|
5759 |
funcy = {version = ">=1.14", optional = true, markers = "extra == \"fsspec\""}
|
5760 |
google-api-python-client = ">=1.12.5"
|
5761 |
oauth2client = ">=4.0.0"
|
5762 |
-
pyOpenSSL = ">=19.1.0
|
5763 |
PyYAML = ">=3.0"
|
5764 |
tqdm = {version = ">=4.0.0", optional = true, markers = "extra == \"fsspec\""}
|
5765 |
|
@@ -5877,21 +5877,21 @@ tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"]
|
|
5877 |
|
5878 |
[[package]]
|
5879 |
name = "pyopenssl"
|
5880 |
-
version = "
|
5881 |
description = "Python wrapper module around the OpenSSL library"
|
5882 |
optional = false
|
5883 |
-
python-versions = ">=3.
|
5884 |
files = [
|
5885 |
-
{file = "pyOpenSSL-
|
5886 |
-
{file = "
|
5887 |
]
|
5888 |
|
5889 |
[package.dependencies]
|
5890 |
-
cryptography = ">=
|
5891 |
|
5892 |
[package.extras]
|
5893 |
-
docs = ["sphinx", "
|
5894 |
-
test = ["
|
5895 |
|
5896 |
[[package]]
|
5897 |
name = "pyparsing"
|
@@ -8397,4 +8397,4 @@ type = ["pytest-mypy"]
|
|
8397 |
[metadata]
|
8398 |
lock-version = "2.0"
|
8399 |
python-versions = "3.10.15"
|
8400 |
-
content-hash = "
|
|
|
3014 |
|
3015 |
[[package]]
|
3016 |
name = "grpcio"
|
3017 |
+
version = "1.68.1"
|
3018 |
description = "HTTP/2-based RPC framework"
|
3019 |
optional = false
|
3020 |
python-versions = ">=3.8"
|
3021 |
files = [
|
3022 |
+
{file = "grpcio-1.68.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:d35740e3f45f60f3c37b1e6f2f4702c23867b9ce21c6410254c9c682237da68d"},
|
3023 |
+
{file = "grpcio-1.68.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:d99abcd61760ebb34bdff37e5a3ba333c5cc09feda8c1ad42547bea0416ada78"},
|
3024 |
+
{file = "grpcio-1.68.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:f8261fa2a5f679abeb2a0a93ad056d765cdca1c47745eda3f2d87f874ff4b8c9"},
|
3025 |
+
{file = "grpcio-1.68.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0feb02205a27caca128627bd1df4ee7212db051019a9afa76f4bb6a1a80ca95e"},
|
3026 |
+
{file = "grpcio-1.68.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:919d7f18f63bcad3a0f81146188e90274fde800a94e35d42ffe9eadf6a9a6330"},
|
3027 |
+
{file = "grpcio-1.68.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:963cc8d7d79b12c56008aabd8b457f400952dbea8997dd185f155e2f228db079"},
|
3028 |
+
{file = "grpcio-1.68.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ccf2ebd2de2d6661e2520dae293298a3803a98ebfc099275f113ce1f6c2a80f1"},
|
3029 |
+
{file = "grpcio-1.68.1-cp310-cp310-win32.whl", hash = "sha256:2cc1fd04af8399971bcd4f43bd98c22d01029ea2e56e69c34daf2bf8470e47f5"},
|
3030 |
+
{file = "grpcio-1.68.1-cp310-cp310-win_amd64.whl", hash = "sha256:ee2e743e51cb964b4975de572aa8fb95b633f496f9fcb5e257893df3be854746"},
|
3031 |
+
{file = "grpcio-1.68.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:55857c71641064f01ff0541a1776bfe04a59db5558e82897d35a7793e525774c"},
|
3032 |
+
{file = "grpcio-1.68.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4b177f5547f1b995826ef529d2eef89cca2f830dd8b2c99ffd5fde4da734ba73"},
|
3033 |
+
{file = "grpcio-1.68.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:3522c77d7e6606d6665ec8d50e867f13f946a4e00c7df46768f1c85089eae515"},
|
3034 |
+
{file = "grpcio-1.68.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9d1fae6bbf0816415b81db1e82fb3bf56f7857273c84dcbe68cbe046e58e1ccd"},
|
3035 |
+
{file = "grpcio-1.68.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:298ee7f80e26f9483f0b6f94cc0a046caf54400a11b644713bb5b3d8eb387600"},
|
3036 |
+
{file = "grpcio-1.68.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cbb5780e2e740b6b4f2d208e90453591036ff80c02cc605fea1af8e6fc6b1bbe"},
|
3037 |
+
{file = "grpcio-1.68.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ddda1aa22495d8acd9dfbafff2866438d12faec4d024ebc2e656784d96328ad0"},
|
3038 |
+
{file = "grpcio-1.68.1-cp311-cp311-win32.whl", hash = "sha256:b33bd114fa5a83f03ec6b7b262ef9f5cac549d4126f1dc702078767b10c46ed9"},
|
3039 |
+
{file = "grpcio-1.68.1-cp311-cp311-win_amd64.whl", hash = "sha256:7f20ebec257af55694d8f993e162ddf0d36bd82d4e57f74b31c67b3c6d63d8b2"},
|
3040 |
+
{file = "grpcio-1.68.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:8829924fffb25386995a31998ccbbeaa7367223e647e0122043dfc485a87c666"},
|
3041 |
+
{file = "grpcio-1.68.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:3aed6544e4d523cd6b3119b0916cef3d15ef2da51e088211e4d1eb91a6c7f4f1"},
|
3042 |
+
{file = "grpcio-1.68.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:4efac5481c696d5cb124ff1c119a78bddbfdd13fc499e3bc0ca81e95fc573684"},
|
3043 |
+
{file = "grpcio-1.68.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ab2d912ca39c51f46baf2a0d92aa265aa96b2443266fc50d234fa88bf877d8e"},
|
3044 |
+
{file = "grpcio-1.68.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95c87ce2a97434dffe7327a4071839ab8e8bffd0054cc74cbe971fba98aedd60"},
|
3045 |
+
{file = "grpcio-1.68.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:e4842e4872ae4ae0f5497bf60a0498fa778c192cc7a9e87877abd2814aca9475"},
|
3046 |
+
{file = "grpcio-1.68.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:255b1635b0ed81e9f91da4fcc8d43b7ea5520090b9a9ad9340d147066d1d3613"},
|
3047 |
+
{file = "grpcio-1.68.1-cp312-cp312-win32.whl", hash = "sha256:7dfc914cc31c906297b30463dde0b9be48e36939575eaf2a0a22a8096e69afe5"},
|
3048 |
+
{file = "grpcio-1.68.1-cp312-cp312-win_amd64.whl", hash = "sha256:a0c8ddabef9c8f41617f213e527254c41e8b96ea9d387c632af878d05db9229c"},
|
3049 |
+
{file = "grpcio-1.68.1-cp313-cp313-linux_armv7l.whl", hash = "sha256:a47faedc9ea2e7a3b6569795c040aae5895a19dde0c728a48d3c5d7995fda385"},
|
3050 |
+
{file = "grpcio-1.68.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:390eee4225a661c5cd133c09f5da1ee3c84498dc265fd292a6912b65c421c78c"},
|
3051 |
+
{file = "grpcio-1.68.1-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:66a24f3d45c33550703f0abb8b656515b0ab777970fa275693a2f6dc8e35f1c1"},
|
3052 |
+
{file = "grpcio-1.68.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c08079b4934b0bf0a8847f42c197b1d12cba6495a3d43febd7e99ecd1cdc8d54"},
|
3053 |
+
{file = "grpcio-1.68.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8720c25cd9ac25dd04ee02b69256d0ce35bf8a0f29e20577427355272230965a"},
|
3054 |
+
{file = "grpcio-1.68.1-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:04cfd68bf4f38f5bb959ee2361a7546916bd9a50f78617a346b3aeb2b42e2161"},
|
3055 |
+
{file = "grpcio-1.68.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:c28848761a6520c5c6071d2904a18d339a796ebe6b800adc8b3f474c5ce3c3ad"},
|
3056 |
+
{file = "grpcio-1.68.1-cp313-cp313-win32.whl", hash = "sha256:77d65165fc35cff6e954e7fd4229e05ec76102d4406d4576528d3a3635fc6172"},
|
3057 |
+
{file = "grpcio-1.68.1-cp313-cp313-win_amd64.whl", hash = "sha256:a8040f85dcb9830d8bbb033ae66d272614cec6faceee88d37a88a9bd1a7a704e"},
|
3058 |
+
{file = "grpcio-1.68.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:eeb38ff04ab6e5756a2aef6ad8d94e89bb4a51ef96e20f45c44ba190fa0bcaad"},
|
3059 |
+
{file = "grpcio-1.68.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8a3869a6661ec8f81d93f4597da50336718bde9eb13267a699ac7e0a1d6d0bea"},
|
3060 |
+
{file = "grpcio-1.68.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:2c4cec6177bf325eb6faa6bd834d2ff6aa8bb3b29012cceb4937b86f8b74323c"},
|
3061 |
+
{file = "grpcio-1.68.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:12941d533f3cd45d46f202e3667be8ebf6bcb3573629c7ec12c3e211d99cfccf"},
|
3062 |
+
{file = "grpcio-1.68.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80af6f1e69c5e68a2be529990684abdd31ed6622e988bf18850075c81bb1ad6e"},
|
3063 |
+
{file = "grpcio-1.68.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:e8dbe3e00771bfe3d04feed8210fc6617006d06d9a2679b74605b9fed3e8362c"},
|
3064 |
+
{file = "grpcio-1.68.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:83bbf5807dc3ee94ce1de2dfe8a356e1d74101e4b9d7aa8c720cc4818a34aded"},
|
3065 |
+
{file = "grpcio-1.68.1-cp38-cp38-win32.whl", hash = "sha256:8cb620037a2fd9eeee97b4531880e439ebfcd6d7d78f2e7dcc3726428ab5ef63"},
|
3066 |
+
{file = "grpcio-1.68.1-cp38-cp38-win_amd64.whl", hash = "sha256:52fbf85aa71263380d330f4fce9f013c0798242e31ede05fcee7fbe40ccfc20d"},
|
3067 |
+
{file = "grpcio-1.68.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:cb400138e73969eb5e0535d1d06cae6a6f7a15f2cc74add320e2130b8179211a"},
|
3068 |
+
{file = "grpcio-1.68.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a1b988b40f2fd9de5c820f3a701a43339d8dcf2cb2f1ca137e2c02671cc83ac1"},
|
3069 |
+
{file = "grpcio-1.68.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:96f473cdacfdd506008a5d7579c9f6a7ff245a9ade92c3c0265eb76cc591914f"},
|
3070 |
+
{file = "grpcio-1.68.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:37ea3be171f3cf3e7b7e412a98b77685eba9d4fd67421f4a34686a63a65d99f9"},
|
3071 |
+
{file = "grpcio-1.68.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ceb56c4285754e33bb3c2fa777d055e96e6932351a3082ce3559be47f8024f0"},
|
3072 |
+
{file = "grpcio-1.68.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:dffd29a2961f3263a16d73945b57cd44a8fd0b235740cb14056f0612329b345e"},
|
3073 |
+
{file = "grpcio-1.68.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:025f790c056815b3bf53da850dd70ebb849fd755a4b1ac822cb65cd631e37d43"},
|
3074 |
+
{file = "grpcio-1.68.1-cp39-cp39-win32.whl", hash = "sha256:1098f03dedc3b9810810568060dea4ac0822b4062f537b0f53aa015269be0a76"},
|
3075 |
+
{file = "grpcio-1.68.1-cp39-cp39-win_amd64.whl", hash = "sha256:334ab917792904245a028f10e803fcd5b6f36a7b2173a820c0b5b076555825e1"},
|
3076 |
+
{file = "grpcio-1.68.1.tar.gz", hash = "sha256:44a8502dd5de653ae6a73e2de50a401d84184f0331d0ac3daeb044e66d5c5054"},
|
3077 |
+
]
|
3078 |
+
|
3079 |
+
[package.extras]
|
3080 |
+
protobuf = ["grpcio-tools (>=1.68.1)"]
|
3081 |
|
3082 |
[[package]]
|
3083 |
name = "gto"
|
|
|
3600 |
|
3601 |
[[package]]
|
3602 |
name = "jsonargparse"
|
3603 |
+
version = "4.34.1"
|
3604 |
description = "Implement minimal boilerplate CLIs derived from type hints and parse from command line, config files and environment variables."
|
3605 |
optional = false
|
3606 |
python-versions = ">=3.8"
|
3607 |
files = [
|
3608 |
+
{file = "jsonargparse-4.34.1-py3-none-any.whl", hash = "sha256:1d595080e080d4581ef821b8ec71364037052be4da0f9c7bdafe95962672cee6"},
|
3609 |
+
{file = "jsonargparse-4.34.1.tar.gz", hash = "sha256:6e0d1ab67b12b1086fe7bbe5ba429fbe9865c36493c9bb6aeb1e243047fdb58c"},
|
3610 |
]
|
3611 |
|
3612 |
[package.dependencies]
|
|
|
5744 |
|
5745 |
[[package]]
|
5746 |
name = "pydrive2"
|
5747 |
+
version = "1.21.1"
|
5748 |
description = "Google Drive API made easy. Maintained fork of PyDrive."
|
5749 |
optional = false
|
5750 |
python-versions = ">=3.8"
|
5751 |
files = [
|
5752 |
+
{file = "PyDrive2-1.21.1-py3-none-any.whl", hash = "sha256:d24b3334bc5c242e5ec58ad6ee7efbd2216aa92098c3eed353ce644f27a7e97b"},
|
5753 |
+
{file = "pydrive2-1.21.1.tar.gz", hash = "sha256:70da0244a29a6922e28620a32e251ac6ab018449f1bb0485e9a39114a069dde0"},
|
5754 |
]
|
5755 |
|
5756 |
[package.dependencies]
|
|
|
5759 |
funcy = {version = ">=1.14", optional = true, markers = "extra == \"fsspec\""}
|
5760 |
google-api-python-client = ">=1.12.5"
|
5761 |
oauth2client = ">=4.0.0"
|
5762 |
+
pyOpenSSL = ">=19.1.0"
|
5763 |
PyYAML = ">=3.0"
|
5764 |
tqdm = {version = ">=4.0.0", optional = true, markers = "extra == \"fsspec\""}
|
5765 |
|
|
|
5877 |
|
5878 |
[[package]]
|
5879 |
name = "pyopenssl"
|
5880 |
+
version = "24.3.0"
|
5881 |
description = "Python wrapper module around the OpenSSL library"
|
5882 |
optional = false
|
5883 |
+
python-versions = ">=3.7"
|
5884 |
files = [
|
5885 |
+
{file = "pyOpenSSL-24.3.0-py3-none-any.whl", hash = "sha256:e474f5a473cd7f92221cc04976e48f4d11502804657a08a989fb3be5514c904a"},
|
5886 |
+
{file = "pyopenssl-24.3.0.tar.gz", hash = "sha256:49f7a019577d834746bc55c5fce6ecbcec0f2b4ec5ce1cf43a9a173b8138bb36"},
|
5887 |
]
|
5888 |
|
5889 |
[package.dependencies]
|
5890 |
+
cryptography = ">=41.0.5,<45"
|
5891 |
|
5892 |
[package.extras]
|
5893 |
+
docs = ["sphinx (!=5.2.0,!=5.2.0.post0,!=7.2.5)", "sphinx_rtd_theme"]
|
5894 |
+
test = ["pretend", "pytest (>=3.0.1)", "pytest-rerunfailures"]
|
5895 |
|
5896 |
[[package]]
|
5897 |
name = "pyparsing"
|
|
|
8397 |
[metadata]
|
8398 |
lock-version = "2.0"
|
8399 |
python-versions = "3.10.15"
|
8400 |
+
content-hash = "78d6a9383381dfba8ea585d06d83d833b85434561a2c3f39d8eda83e5e1697d9"
|
pyproject.toml
CHANGED
@@ -77,9 +77,9 @@ nvitop = "^1.3.2"
|
|
77 |
gradio = "5.7.1"
|
78 |
gradio-client = "^1.5.0"
|
79 |
accelerate = "^1.1.1"
|
80 |
-
pyopenssl = "<23.0.0"
|
81 |
cryptography = "^44.0.0"
|
82 |
boto3 = "*"
|
|
|
83 |
|
84 |
[tool.poetry.dev-dependencies]
|
85 |
pytest-asyncio = "^0.20.3"
|
|
|
77 |
gradio = "5.7.1"
|
78 |
gradio-client = "^1.5.0"
|
79 |
accelerate = "^1.1.1"
|
|
|
80 |
cryptography = "^44.0.0"
|
81 |
boto3 = "*"
|
82 |
+
pyopenssl = "^24.3.0"
|
83 |
|
84 |
[tool.poetry.dev-dependencies]
|
85 |
pytest-asyncio = "^0.20.3"
|
src/infer.py
CHANGED
@@ -5,7 +5,7 @@ import torch.nn.functional as F
|
|
5 |
import matplotlib.pyplot as plt
|
6 |
from PIL import Image
|
7 |
from torchvision import transforms
|
8 |
-
from src.models.
|
9 |
from src.utils.logging_utils import setup_logger, task_wrapper, get_rich_progress
|
10 |
import hydra
|
11 |
from omegaconf import DictConfig, OmegaConf
|
@@ -13,6 +13,7 @@ from dotenv import load_dotenv, find_dotenv
|
|
13 |
import rootutils
|
14 |
import time
|
15 |
from loguru import logger
|
|
|
16 |
|
17 |
# Load environment variables
|
18 |
load_dotenv(find_dotenv(".env"))
|
@@ -93,9 +94,15 @@ def main_infer(cfg: DictConfig):
|
|
93 |
if flag_file.exists():
|
94 |
flag_file.unlink()
|
95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
# Load the trained model
|
97 |
-
model =
|
98 |
-
classes =
|
99 |
|
100 |
# Download an image for inference
|
101 |
download_image(cfg)
|
|
|
5 |
import matplotlib.pyplot as plt
|
6 |
from PIL import Image
|
7 |
from torchvision import transforms
|
8 |
+
from src.models.catdog_model_resnet import ResnetClassifier
|
9 |
from src.utils.logging_utils import setup_logger, task_wrapper, get_rich_progress
|
10 |
import hydra
|
11 |
from omegaconf import DictConfig, OmegaConf
|
|
|
13 |
import rootutils
|
14 |
import time
|
15 |
from loguru import logger
|
16 |
+
from src.utils.aws_s3_services import S3Handler
|
17 |
|
18 |
# Load environment variables
|
19 |
load_dotenv(find_dotenv(".env"))
|
|
|
94 |
if flag_file.exists():
|
95 |
flag_file.unlink()
|
96 |
|
97 |
+
# download the model from S3
|
98 |
+
s3_handler = S3Handler(bucket_name="deep-bucket-s3")
|
99 |
+
s3_handler.download_folder(
|
100 |
+
"checkpoints",
|
101 |
+
"checkpoints",
|
102 |
+
)
|
103 |
# Load the trained model
|
104 |
+
model = ResnetClassifier.load_from_checkpoint(checkpoint_path=cfg.ckpt_path)
|
105 |
+
classes = cfg.labels
|
106 |
|
107 |
# Download an image for inference
|
108 |
download_image(cfg)
|
src/models/catdog_model_resnet.py
CHANGED
@@ -9,6 +9,8 @@ import torch
|
|
9 |
class ResnetClassifier(L.LightningModule):
|
10 |
def __init__(
|
11 |
self,
|
|
|
|
|
12 |
num_classes: int = 2, # Binary classification with two classes
|
13 |
lr: float = 1e-3,
|
14 |
weight_decay: float = 1e-5,
|
@@ -21,7 +23,7 @@ class ResnetClassifier(L.LightningModule):
|
|
21 |
|
22 |
# Vision Transformer model initialization
|
23 |
self.model = timm.create_model(
|
24 |
-
|
25 |
)
|
26 |
|
27 |
# Define accuracy and F1 metrics for binary classification
|
@@ -90,5 +92,5 @@ class ResnetClassifier(L.LightningModule):
|
|
90 |
|
91 |
|
92 |
if __name__ == "__main__":
|
93 |
-
model =
|
94 |
print(model)
|
|
|
9 |
class ResnetClassifier(L.LightningModule):
|
10 |
def __init__(
|
11 |
self,
|
12 |
+
base_model: str = "efficientnet_b0",
|
13 |
+
pretrained: bool = True,
|
14 |
num_classes: int = 2, # Binary classification with two classes
|
15 |
lr: float = 1e-3,
|
16 |
weight_decay: float = 1e-5,
|
|
|
23 |
|
24 |
# Create the timm backbone named by base_model (defaults to efficientnet_b0)
|
25 |
self.model = timm.create_model(
|
26 |
+
base_model, pretrained=pretrained, num_classes=num_classes
|
27 |
)
|
28 |
|
29 |
# Define accuracy and F1 metrics for binary classification
|
|
|
92 |
|
93 |
|
94 |
if __name__ == "__main__":
|
95 |
+
model = ResnetClassifier()
|
96 |
print(model)
|
src/train_optuna_callbacks.py
CHANGED
@@ -19,6 +19,7 @@ from lightning.pytorch.callbacks import Callback
|
|
19 |
import optuna
|
20 |
from lightning.pytorch import Trainer
|
21 |
import json
|
|
|
22 |
|
23 |
# Load environment variables
|
24 |
load_dotenv(find_dotenv(".env"))
|
@@ -131,10 +132,10 @@ def objective(trial: optuna.trial.Trial, cfg: DictConfig, callbacks: List[Callba
|
|
131 |
"""Objective function for Optuna hyperparameter tuning."""
|
132 |
|
133 |
# Sample hyperparameters for the model
|
134 |
-
cfg.model.embed_dim = trial.suggest_categorical("embed_dim", [64, 128, 256])
|
135 |
-
cfg.model.depth = trial.suggest_int("depth", 2, 6)
|
136 |
cfg.model.lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
|
137 |
-
cfg.model.mlp_ratio = trial.suggest_float("mlp_ratio", 1.0, 4.0)
|
138 |
|
139 |
# Initialize data module and model
|
140 |
data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
|
@@ -207,6 +208,13 @@ def setup_trainer(cfg: DictConfig):
|
|
207 |
with open(Path(cfg.paths.ckpt_dir) / "train_done.flag", "w") as f:
|
208 |
f.write("Training completed successfully!")
|
209 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
# Testing phase with best hyperparameters
|
211 |
if cfg.get("test", False):
|
212 |
best_hyperparams_path = Path(cfg.paths.ckpt_dir) / "best_hyperparams.json"
|
|
|
19 |
import optuna
|
20 |
from lightning.pytorch import Trainer
|
21 |
import json
|
22 |
+
from src.utils.aws_s3_services import S3Handler
|
23 |
|
24 |
# Load environment variables
|
25 |
load_dotenv(find_dotenv(".env"))
|
|
|
132 |
"""Objective function for Optuna hyperparameter tuning."""
|
133 |
|
134 |
# Sample hyperparameters for the model
|
135 |
+
# cfg.model.embed_dim = trial.suggest_categorical("embed_dim", [64, 128, 256])
|
136 |
+
# cfg.model.depth = trial.suggest_int("depth", 2, 6)
|
137 |
cfg.model.lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
|
138 |
+
# cfg.model.mlp_ratio = trial.suggest_float("mlp_ratio", 1.0, 4.0)
|
139 |
|
140 |
# Initialize data module and model
|
141 |
data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
|
|
|
208 |
with open(Path(cfg.paths.ckpt_dir) / "train_done.flag", "w") as f:
|
209 |
f.write("Training completed successfully!")
|
210 |
|
211 |
+
# upload the checkpoints to S3
|
212 |
+
s3_handler = S3Handler(bucket_name="deep-bucket-s3")
|
213 |
+
s3_handler.upload_folder(
|
214 |
+
"checkpoints",
|
215 |
+
"checkpoints",
|
216 |
+
)
|
217 |
+
|
218 |
# Testing phase with best hyperparameters
|
219 |
if cfg.get("test", False):
|
220 |
best_hyperparams_path = Path(cfg.paths.ckpt_dir) / "best_hyperparams.json"
|
src/utils/aws_s3_services.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import boto3
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
from dotenv import load_dotenv, find_dotenv
|
5 |
+
|
6 |
+
# Load environment variables from .env file
|
7 |
+
load_dotenv(find_dotenv(".env"))
|
8 |
+
|
9 |
+
|
10 |
+
class S3Handler:
    """Upload and download folder trees between local disk and one S3 bucket.

    Credentials and region are read from the ``AWS_ACCESS_KEY_ID``,
    ``AWS_SECRET_ACCESS_KEY`` and ``AWS_REGION`` environment variables when
    the handler is constructed.
    """

    def __init__(self, bucket_name: str) -> None:
        """Create an S3 client bound to *bucket_name*.

        Args:
            bucket_name: Name of the bucket every upload/download targets.
        """
        self.bucket_name = bucket_name
        self.s3 = boto3.client(
            "s3",
            aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
            region_name=os.getenv("AWS_REGION"),
        )

    def upload_folder(self, source_folder, dest_folder, filenames=None) -> None:
        """Upload specified files, or all files, from a local folder to S3.

        Args:
            source_folder: Local source folder path (``str`` or ``Path``).
            dest_folder: Destination folder (key prefix) in S3. A trailing
                slash is optional.
            filenames: File names relative to ``source_folder``. If ``None``
                or empty, every file found recursively is uploaded.
        """
        source_folder = Path(source_folder)

        # Normalise the prefix so "folder" and "folder/" give the same keys
        # and an empty prefix does not produce keys with a leading slash.
        key_prefix = str(dest_folder)
        if key_prefix and not key_prefix.endswith("/"):
            key_prefix += "/"

        if filenames:
            candidates = [source_folder / name for name in filenames]
        else:
            # Directories are part of the walk, not missing files: drop them
            # here so they are not reported as "File not found" below.
            candidates = [p for p in source_folder.rglob("*") if p.is_file()]

        for file_path in candidates:
            if not file_path.is_file():
                print(f"File not found: {file_path}")
                continue

            # S3 keys always use forward slashes, whatever the local OS uses.
            relative_key = file_path.relative_to(source_folder).as_posix()
            s3_path = f"{key_prefix}{relative_key}"
            self.s3.upload_file(str(file_path), self.bucket_name, s3_path)
            print(f"Uploaded: {file_path} to {s3_path}")

    def download_folder(self, s3_folder, dest_folder) -> None:
        """Download all files from an S3 folder to a local folder.

        Only keys strictly inside ``s3_folder`` are fetched; sibling folders
        that merely share the same leading characters (``checkpoints`` vs
        ``checkpoints_test``) and zero-length folder-marker keys are skipped.

        Args:
            s3_folder: Source folder (key prefix) in S3. A trailing slash is
                optional.
            dest_folder: Local destination folder path; created on demand.
        """
        dest_folder = Path(dest_folder)

        # List with a trailing slash so the prefix matches a folder boundary
        # rather than any key that happens to start with the same characters.
        prefix = str(s3_folder)
        if prefix and not prefix.endswith("/"):
            prefix += "/"

        paginator = self.s3.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix):
            for obj in page.get("Contents", []):
                s3_path = obj["Key"]
                if not s3_path.startswith(prefix):
                    continue

                relative_key = s3_path[len(prefix):]
                # An empty remainder or a trailing slash is a folder marker,
                # not a file that can be written to disk.
                if not relative_key or relative_key.endswith("/"):
                    continue

                local_path = dest_folder / relative_key
                local_path.parent.mkdir(parents=True, exist_ok=True)
                self.s3.download_file(self.bucket_name, s3_path, str(local_path))
                print(f"Downloaded: {s3_path} to {local_path}")
|
64 |
+
|
65 |
+
|
66 |
+
# Usage Example
|
67 |
+
if __name__ == "__main__":
    # Usage example: round-trip the local checkpoints folder through the bucket.
    s3_handler = S3Handler(bucket_name="deep-bucket-s3")

    # No filenames are passed, so every file under ./checkpoints is uploaded.
    s3_handler.upload_folder(
        source_folder="checkpoints",
        dest_folder="checkpoints_test",
    )

    # Fetch the uploaded objects back into ./checkpoints.
    s3_handler.download_folder(
        s3_folder="checkpoints_test",
        dest_folder="checkpoints",
    )
|