IliaLarchenko committed
Commit b1285d2 · verified · 1 Parent(s): d1620b2

Upload folder using huggingface_hub
Files changed (6)
  1. .gitattributes +1 -0
  2. README.md +38 -0
  3. config.json +48 -0
  4. config.yaml +235 -0
  5. model.safetensors +3 -0
  6. replay.mp4 +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ replay.mp4 filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,38 @@
+ ---
+ library_name: lerobot
+ tags:
+ - model_hub_mixin
+ - pytorch_model_hub_mixin
+ - robotics
+ - dot
+ license: apache-2.0
+ datasets:
+ - lerobot/pusht_keypoints
+ pipeline_tag: robotics
+ ---
+
+ # Model Card for the "Decoder Only Transformer (DOT) Policy" on the PushT keypoints dataset
+
+ Read more about the model and implementation details in the [DOT Policy repository](https://github.com/IliaLarchenko/dot_policy).
+
+ This model was trained with the [LeRobot library](https://huggingface.co/lerobot) and achieves state-of-the-art behavior-cloning results on the PushT keypoints dataset: a 94% success rate (and 0.985 average max reward), vs. ~78% for the previous state-of-the-art model and the 69% I managed to reproduce with the VQ-BeT implementation in LeRobot.
+
+ This is the best checkpoint from this training run. These results are achievable assuming reliable validation is available and the best checkpoint can be selected based on the validation results (not always the case in robotics). If you are interested in more stable and reproducible results that do not require checkpoint selection, please refer to https://huggingface.co/IliaLarchenko/dot_pusht_keypoints
+
+ You can use this model by installing LeRobot from this branch: https://github.com/IliaLarchenko/lerobot/tree/dot
+
+ To train the model:
+
+ ```bash
+ python lerobot/scripts/train.py policy=dot_pusht_keypoints env=pusht env.obs_type=environment_state_agent_pos
+ ```
+
+ To evaluate the model:
+
+ ```bash
+ python lerobot/scripts/eval.py -p IliaLarchenko/dot_pusht_keypoints_best eval.n_episodes=1000 eval.batch_size=100 seed=1000000
+ ```
+
+ Model size:
+ - Total parameters: 2.1M
+ - Trainable parameters: 2.1M
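
For a quick smoke test of the checkpoint outside the eval script, something like the sketch below should work. This is a minimal, unverified sketch: the import path and the `DOTPolicy` class name are assumptions about the `dot` branch (they follow LeRobot's usual policy layout), while `from_pretrained` and `select_action` are the standard `PyTorchModelHubMixin` and LeRobot policy interfaces suggested by the tags above.

```python
import torch

# Hypothetical import path, assuming the `dot` branch follows LeRobot's
# usual `lerobot/common/policies/<name>/modeling_<name>.py` layout.
from lerobot.common.policies.dot.modeling_dot import DOTPolicy

# Download and load the checkpoint via the PyTorchModelHubMixin interface.
policy = DOTPolicy.from_pretrained("IliaLarchenko/dot_pusht_keypoints_best")
policy.eval()

# Dummy PushT keypoints observation: 8 T-object keypoints flattened to 16
# values, plus the 2D agent position (shapes taken from config.json below).
observation = {
    "observation.environment_state": torch.zeros(1, 16),
    "observation.state": torch.zeros(1, 2),
}

with torch.no_grad():
    action = policy.select_action(observation)  # expected shape: (1, 2)
```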
config.json ADDED
@@ -0,0 +1,48 @@
+ {
+   "alpha": 0.75,
+   "crop_scale": 1.0,
+   "dim_feedforward": 512,
+   "dim_model": 128,
+   "dropout": 0.1,
+   "inference_horizon": 30,
+   "input_normalization_modes": {
+     "observation.environment_state": "min_max",
+     "observation.state": "min_max"
+   },
+   "input_shapes": {
+     "observation.environment_state": [
+       16
+     ],
+     "observation.state": [
+       2
+     ]
+   },
+   "lookback_aug": 5,
+   "lookback_obs_steps": 10,
+   "lora_rank": 20,
+   "merge_lora": true,
+   "n_decoder_layers": 8,
+   "n_heads": 8,
+   "n_obs_steps": 3,
+   "noise_decay": 0.999995,
+   "output_normalization_modes": {
+     "action": "min_max"
+   },
+   "output_shapes": {
+     "action": [
+       2
+     ]
+   },
+   "pre_norm": true,
+   "predict_every_n": 1,
+   "pretrained_backbone_weights": "ResNet18_Weights.IMAGENET1K_V1",
+   "rescale_shape": [
+     96,
+     96
+   ],
+   "return_every_n": 2,
+   "state_noise": 0.01,
+   "train_alpha": 0.9,
+   "train_horizon": 30,
+   "vision_backbone": "resnet18"
+ }
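
The two `input_shapes` above are the full PushT keypoints observation: `observation.environment_state` is the 8 keypoints of the T-shaped object flattened to 16 (x, y) values, and `observation.state` is the 2D agent position, both in a 512x512 workspace (see `override_dataset_stats` in config.yaml below). The sketch that follows shows what the `min_max` normalization implies for raw inputs; the [-1, 1] target range is LeRobot's `min_max` convention as I understand it, so treat it as an assumption.

```python
import torch

# min_max normalization as used by LeRobot (assumed to rescale to [-1, 1]).
def min_max_normalize(x: torch.Tensor, vmin: float, vmax: float) -> torch.Tensor:
    return (x - vmin) / (vmax - vmin) * 2.0 - 1.0

# Raw observations in pixel coordinates of the 512x512 PushT workspace.
env_state = torch.rand(1, 16) * 512.0  # 8 keypoints -> 16 values
agent_pos = torch.rand(1, 2) * 512.0   # agent (x, y)

batch = {
    "observation.environment_state": min_max_normalize(env_state, 0.0, 512.0),
    "observation.state": min_max_normalize(agent_pos, 0.0, 512.0),
}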
config.yaml ADDED
@@ -0,0 +1,235 @@
+ resume: false
+ device: cuda
+ use_amp: true
+ seed: 100000
+ dataset_repo_id: lerobot/pusht_keypoints
+ video_backend: pyav
+ training:
+   offline_steps: 1000000
+   num_workers: 24
+   batch_size: 24
+   eval_freq: 10000
+   log_freq: 1000
+   save_checkpoint: true
+   save_freq: 50000
+   online_steps: 0
+   online_rollout_n_episodes: 1
+   online_rollout_batch_size: 1
+   online_steps_between_rollouts: 1
+   online_sampling_ratio: 0.5
+   online_env_seed: null
+   online_buffer_capacity: null
+   online_buffer_seed_size: 0
+   do_online_rollout_async: false
+   image_transforms:
+     enable: false
+     max_num_transforms: 3
+     random_order: false
+     brightness:
+       weight: 1
+       min_max:
+       - 0.8
+       - 1.2
+     contrast:
+       weight: 1
+       min_max:
+       - 0.8
+       - 1.2
+     saturation:
+       weight: 1
+       min_max:
+       - 0.5
+       - 1.5
+     hue:
+       weight: 1
+       min_max:
+       - -0.05
+       - 0.05
+     sharpness:
+       weight: 1
+       min_max:
+       - 0.8
+       - 1.2
+   save_model: true
+   grad_clip_norm: 50
+   lr: 0.0001
+   min_lr: 0.0001
+   lr_cycle_steps: 300000
+   weight_decay: 1.0e-05
+   delta_timestamps:
+     observation.environment_state:
+     - -1.5
+     - -1.4
+     - -1.3
+     - -1.2
+     - -1.1
+     - -1.0
+     - -0.9
+     - -0.8
+     - -0.7
+     - -0.6
+     - -0.5
+     - -0.1
+     - 0.0
+     observation.state:
+     - -1.5
+     - -1.4
+     - -1.3
+     - -1.2
+     - -1.1
+     - -1.0
+     - -0.9
+     - -0.8
+     - -0.7
+     - -0.6
+     - -0.5
+     - -0.1
+     - 0.0
+     action:
+     - -1.5
+     - -1.4
+     - -1.3
+     - -1.2
+     - -1.1
+     - -1.0
+     - -0.9
+     - -0.8
+     - -0.7
+     - -0.6
+     - -0.5
+     - -0.1
+     - 0.0
+     - 0.1
+     - 0.2
+     - 0.3
+     - 0.4
+     - 0.5
+     - 0.6
+     - 0.7
+     - 0.8
+     - 0.9
+     - 1.0
+     - 1.1
+     - 1.2
+     - 1.3
+     - 1.4
+     - 1.5
+     - 1.6
+     - 1.7
+     - 1.8
+     - 1.9
+     - 2.0
+     - 2.1
+     - 2.2
+     - 2.3
+     - 2.4
+     - 2.5
+     - 2.6
+     - 2.7
+     - 2.8
+     - 2.9
+ eval:
+   n_episodes: 100
+   batch_size: 100
+   use_async_envs: false
+ wandb:
+   enable: true
+   disable_artifact: false
+   project: pusht
+   notes: ''
+ fps: 10
+ env:
+   name: pusht
+   task: PushT-v0
+   image_size: 96
+   state_dim: 2
+   action_dim: 2
+   fps: ${fps}
+   episode_length: 300
+   gym:
+     obs_type: environment_state_agent_pos
+     render_mode: rgb_array
+     visualization_width: 384
+     visualization_height: 384
+ override_dataset_stats:
+   observation.environment_state:
+     min:
+     - 0.0
+     - 0.0
+     - 0.0
+     - 0.0
+     - 0.0
+     - 0.0
+     - 0.0
+     - 0.0
+     - 0.0
+     - 0.0
+     - 0.0
+     - 0.0
+     - 0.0
+     - 0.0
+     - 0.0
+     - 0.0
+     max:
+     - 512.0
+     - 512.0
+     - 512.0
+     - 512.0
+     - 512.0
+     - 512.0
+     - 512.0
+     - 512.0
+     - 512.0
+     - 512.0
+     - 512.0
+     - 512.0
+     - 512.0
+     - 512.0
+     - 512.0
+     - 512.0
+   observation.state:
+     min:
+     - 0.0
+     - 0.0
+     max:
+     - 512.0
+     - 512.0
+   action:
+     min:
+     - 0.0
+     - 0.0
+     max:
+     - 512.0
+     - 512.0
+ policy:
+   name: dot
+   n_obs_steps: 3
+   train_horizon: 30
+   inference_horizon: 30
+   lookback_obs_steps: 10
+   lookback_aug: 5
+   input_shapes:
+     observation.environment_state:
+     - 16
+     observation.state:
+     - ${env.state_dim}
+   output_shapes:
+     action:
+     - ${env.action_dim}
+   input_normalization_modes:
+     observation.environment_state: min_max
+     observation.state: min_max
+   output_normalization_modes:
+     action: min_max
+   state_noise: 0.01
+   noise_decay: 0.999995
+   pre_norm: true
+   dim_model: 128
+   n_heads: 8
+   dim_feedforward: 512
+   n_decoder_layers: 8
+   dropout: 0.1
+   alpha: 0.75
+   train_alpha: 0.9
+   predict_every_n: 1
+   return_every_n: 2
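
The long `delta_timestamps` lists above do not look hand-written; at `fps: 10` they appear to follow from the policy parameters: one lookback observation about `lookback_obs_steps` frames in the past, jittered by +/- `lookback_aug` frames (hence -1.5 ... -0.5), the last `n_obs_steps - 1` frames up to the current one (-0.1, 0.0), and, for actions, `train_horizon - 1` future steps (0.1 ... 2.9). The sketch below is a hypothetical reconstruction that reproduces the lists exactly, not the actual generation code from the branch.

```python
FPS = 10                 # fps
LOOKBACK_OBS_STEPS = 10  # policy.lookback_obs_steps
LOOKBACK_AUG = 5         # policy.lookback_aug
N_OBS_STEPS = 3          # policy.n_obs_steps
TRAIN_HORIZON = 30       # policy.train_horizon

# Lookback observation ~1 s in the past, widened by +/- lookback_aug frames.
lookback = [-(LOOKBACK_OBS_STEPS + d) / FPS
            for d in range(LOOKBACK_AUG, -LOOKBACK_AUG - 1, -1)]
# Recent observations: the last n_obs_steps - 1 frames up to "now".
recent = [(i - (N_OBS_STEPS - 2)) / FPS for i in range(N_OBS_STEPS - 1)]

obs_ts = lookback + recent                                    # 13 timestamps
act_ts = obs_ts + [i / FPS for i in range(1, TRAIN_HORIZON)]  # + 29 future steps

assert obs_ts == [-1.5, -1.4, -1.3, -1.2, -1.1, -1.0,
                  -0.9, -0.8, -0.7, -0.6, -0.5, -0.1, 0.0]
assert len(act_ts) == 42 and act_ts[-1] == 2.9
```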
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1d5e02a6c29abeaf8b44b1c78dc953b7bb8ce8983ed9491e7fe19eb24cfd6c94
+ size 8534312
replay.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6e32ed981aca2d1f85511b1f310421dc766c26aad6427ff45c32439e8975187e
+ size 135178