Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- README.md +38 -0
- config.json +48 -0
- config.yaml +235 -0
- model.safetensors +3 -0
- replay.mp4 +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
replay.mp4 filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
library_name: lerobot
|
3 |
+
tags:
|
4 |
+
- model_hub_mixin
|
5 |
+
- pytorch_model_hub_mixin
|
6 |
+
- robotics
|
7 |
+
- dot
|
8 |
+
license: apache-2.0
|
9 |
+
datasets:
|
10 |
+
- lerobot/pusht_keypoints
|
11 |
+
pipeline_tag: robotics
|
12 |
+
---
|
13 |
+
|
14 |
+
# Model Card for "Decoder Only Transformer (DOT) Policy" for PushT keypoints dataset
|
15 |
+
|
16 |
+
Read more about the model and implementation details in the [DOT Policy repository](https://github.com/IliaLarchenko/dot_policy).
|
17 |
+
|
18 |
+
This model is trained using the [LeRobot library](https://huggingface.co/lerobot) and achieves state-of-the-art results on behavior cloning on the PushT keypoints dataset. It achieves a 94% success rate (and 0.985 average max reward), compared with ~78% for the previous state-of-the-art model and the 69% that I managed to reproduce using the VQ-BET implementation in LeRobot.
|
19 |
+
|
20 |
+
This is the best checkpoint for the model. These results are achievable assuming we have reliable validation and can select the best checkpoint based on the validation results (not always the case in robotics). If you are interested in more stable and reproducible results achievable without checkpoint selection, please refer to https://huggingface.co/IliaLarchenko/dot_pusht_keypoints
|
21 |
+
|
22 |
+
You can use this model by installing LeRobot from this branch: https://github.com/IliaLarchenko/lerobot/tree/dot
|
23 |
+
|
24 |
+
To train the model:
|
25 |
+
|
26 |
+
```bash
|
27 |
+
python lerobot/scripts/train.py policy=dot_pusht_keypoints env=pusht env.obs_type=environment_state_agent_pos
|
28 |
+
```
|
29 |
+
|
30 |
+
To evaluate the model:
|
31 |
+
|
32 |
+
```bash
|
33 |
+
python lerobot/scripts/eval.py -p IliaLarchenko/dot_pusht_keypoints_best eval.n_episodes=1000 eval.batch_size=100 seed=1000000
|
34 |
+
```
|
35 |
+
|
36 |
+
Model size:
|
37 |
+
- Total parameters: 2.1M
|
38 |
+
- Trainable parameters: 2.1M
|
config.json
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"alpha": 0.75,
|
3 |
+
"crop_scale": 1.0,
|
4 |
+
"dim_feedforward": 512,
|
5 |
+
"dim_model": 128,
|
6 |
+
"dropout": 0.1,
|
7 |
+
"inference_horizon": 30,
|
8 |
+
"input_normalization_modes": {
|
9 |
+
"observation.environment_state": "min_max",
|
10 |
+
"observation.state": "min_max"
|
11 |
+
},
|
12 |
+
"input_shapes": {
|
13 |
+
"observation.environment_state": [
|
14 |
+
16
|
15 |
+
],
|
16 |
+
"observation.state": [
|
17 |
+
2
|
18 |
+
]
|
19 |
+
},
|
20 |
+
"lookback_aug": 5,
|
21 |
+
"lookback_obs_steps": 10,
|
22 |
+
"lora_rank": 20,
|
23 |
+
"merge_lora": true,
|
24 |
+
"n_decoder_layers": 8,
|
25 |
+
"n_heads": 8,
|
26 |
+
"n_obs_steps": 3,
|
27 |
+
"noise_decay": 0.999995,
|
28 |
+
"output_normalization_modes": {
|
29 |
+
"action": "min_max"
|
30 |
+
},
|
31 |
+
"output_shapes": {
|
32 |
+
"action": [
|
33 |
+
2
|
34 |
+
]
|
35 |
+
},
|
36 |
+
"pre_norm": true,
|
37 |
+
"predict_every_n": 1,
|
38 |
+
"pretrained_backbone_weights": "ResNet18_Weights.IMAGENET1K_V1",
|
39 |
+
"rescale_shape": [
|
40 |
+
96,
|
41 |
+
96
|
42 |
+
],
|
43 |
+
"return_every_n": 2,
|
44 |
+
"state_noise": 0.01,
|
45 |
+
"train_alpha": 0.9,
|
46 |
+
"train_horizon": 30,
|
47 |
+
"vision_backbone": "resnet18"
|
48 |
+
}
|
config.yaml
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
resume: false
|
2 |
+
device: cuda
|
3 |
+
use_amp: true
|
4 |
+
seed: 100000
|
5 |
+
dataset_repo_id: lerobot/pusht_keypoints
|
6 |
+
video_backend: pyav
|
7 |
+
training:
|
8 |
+
offline_steps: 1000000
|
9 |
+
num_workers: 24
|
10 |
+
batch_size: 24
|
11 |
+
eval_freq: 10000
|
12 |
+
log_freq: 1000
|
13 |
+
save_checkpoint: true
|
14 |
+
save_freq: 50000
|
15 |
+
online_steps: 0
|
16 |
+
online_rollout_n_episodes: 1
|
17 |
+
online_rollout_batch_size: 1
|
18 |
+
online_steps_between_rollouts: 1
|
19 |
+
online_sampling_ratio: 0.5
|
20 |
+
online_env_seed: null
|
21 |
+
online_buffer_capacity: null
|
22 |
+
online_buffer_seed_size: 0
|
23 |
+
do_online_rollout_async: false
|
24 |
+
image_transforms:
|
25 |
+
enable: false
|
26 |
+
max_num_transforms: 3
|
27 |
+
random_order: false
|
28 |
+
brightness:
|
29 |
+
weight: 1
|
30 |
+
min_max:
|
31 |
+
- 0.8
|
32 |
+
- 1.2
|
33 |
+
contrast:
|
34 |
+
weight: 1
|
35 |
+
min_max:
|
36 |
+
- 0.8
|
37 |
+
- 1.2
|
38 |
+
saturation:
|
39 |
+
weight: 1
|
40 |
+
min_max:
|
41 |
+
- 0.5
|
42 |
+
- 1.5
|
43 |
+
hue:
|
44 |
+
weight: 1
|
45 |
+
min_max:
|
46 |
+
- -0.05
|
47 |
+
- 0.05
|
48 |
+
sharpness:
|
49 |
+
weight: 1
|
50 |
+
min_max:
|
51 |
+
- 0.8
|
52 |
+
- 1.2
|
53 |
+
save_model: true
|
54 |
+
grad_clip_norm: 50
|
55 |
+
lr: 0.0001
|
56 |
+
min_lr: 0.0001
|
57 |
+
lr_cycle_steps: 300000
|
58 |
+
weight_decay: 1.0e-05
|
59 |
+
delta_timestamps:
|
60 |
+
observation.environment_state:
|
61 |
+
- -1.5
|
62 |
+
- -1.4
|
63 |
+
- -1.3
|
64 |
+
- -1.2
|
65 |
+
- -1.1
|
66 |
+
- -1.0
|
67 |
+
- -0.9
|
68 |
+
- -0.8
|
69 |
+
- -0.7
|
70 |
+
- -0.6
|
71 |
+
- -0.5
|
72 |
+
- -0.1
|
73 |
+
- 0.0
|
74 |
+
observation.state:
|
75 |
+
- -1.5
|
76 |
+
- -1.4
|
77 |
+
- -1.3
|
78 |
+
- -1.2
|
79 |
+
- -1.1
|
80 |
+
- -1.0
|
81 |
+
- -0.9
|
82 |
+
- -0.8
|
83 |
+
- -0.7
|
84 |
+
- -0.6
|
85 |
+
- -0.5
|
86 |
+
- -0.1
|
87 |
+
- 0.0
|
88 |
+
action:
|
89 |
+
- -1.5
|
90 |
+
- -1.4
|
91 |
+
- -1.3
|
92 |
+
- -1.2
|
93 |
+
- -1.1
|
94 |
+
- -1.0
|
95 |
+
- -0.9
|
96 |
+
- -0.8
|
97 |
+
- -0.7
|
98 |
+
- -0.6
|
99 |
+
- -0.5
|
100 |
+
- -0.1
|
101 |
+
- 0.0
|
102 |
+
- 0.1
|
103 |
+
- 0.2
|
104 |
+
- 0.3
|
105 |
+
- 0.4
|
106 |
+
- 0.5
|
107 |
+
- 0.6
|
108 |
+
- 0.7
|
109 |
+
- 0.8
|
110 |
+
- 0.9
|
111 |
+
- 1.0
|
112 |
+
- 1.1
|
113 |
+
- 1.2
|
114 |
+
- 1.3
|
115 |
+
- 1.4
|
116 |
+
- 1.5
|
117 |
+
- 1.6
|
118 |
+
- 1.7
|
119 |
+
- 1.8
|
120 |
+
- 1.9
|
121 |
+
- 2.0
|
122 |
+
- 2.1
|
123 |
+
- 2.2
|
124 |
+
- 2.3
|
125 |
+
- 2.4
|
126 |
+
- 2.5
|
127 |
+
- 2.6
|
128 |
+
- 2.7
|
129 |
+
- 2.8
|
130 |
+
- 2.9
|
131 |
+
eval:
|
132 |
+
n_episodes: 100
|
133 |
+
batch_size: 100
|
134 |
+
use_async_envs: false
|
135 |
+
wandb:
|
136 |
+
enable: true
|
137 |
+
disable_artifact: false
|
138 |
+
project: pusht
|
139 |
+
notes: ''
|
140 |
+
fps: 10
|
141 |
+
env:
|
142 |
+
name: pusht
|
143 |
+
task: PushT-v0
|
144 |
+
image_size: 96
|
145 |
+
state_dim: 2
|
146 |
+
action_dim: 2
|
147 |
+
fps: ${fps}
|
148 |
+
episode_length: 300
|
149 |
+
gym:
|
150 |
+
obs_type: environment_state_agent_pos
|
151 |
+
render_mode: rgb_array
|
152 |
+
visualization_width: 384
|
153 |
+
visualization_height: 384
|
154 |
+
override_dataset_stats:
|
155 |
+
observation.environment_state:
|
156 |
+
min:
|
157 |
+
- 0.0
|
158 |
+
- 0.0
|
159 |
+
- 0.0
|
160 |
+
- 0.0
|
161 |
+
- 0.0
|
162 |
+
- 0.0
|
163 |
+
- 0.0
|
164 |
+
- 0.0
|
165 |
+
- 0.0
|
166 |
+
- 0.0
|
167 |
+
- 0.0
|
168 |
+
- 0.0
|
169 |
+
- 0.0
|
170 |
+
- 0.0
|
171 |
+
- 0.0
|
172 |
+
- 0.0
|
173 |
+
max:
|
174 |
+
- 512.0
|
175 |
+
- 512.0
|
176 |
+
- 512.0
|
177 |
+
- 512.0
|
178 |
+
- 512.0
|
179 |
+
- 512.0
|
180 |
+
- 512.0
|
181 |
+
- 512.0
|
182 |
+
- 512.0
|
183 |
+
- 512.0
|
184 |
+
- 512.0
|
185 |
+
- 512.0
|
186 |
+
- 512.0
|
187 |
+
- 512.0
|
188 |
+
- 512.0
|
189 |
+
- 512.0
|
190 |
+
observation.state:
|
191 |
+
min:
|
192 |
+
- 0.0
|
193 |
+
- 0.0
|
194 |
+
max:
|
195 |
+
- 512.0
|
196 |
+
- 512.0
|
197 |
+
action:
|
198 |
+
min:
|
199 |
+
- 0.0
|
200 |
+
- 0.0
|
201 |
+
max:
|
202 |
+
- 512.0
|
203 |
+
- 512.0
|
204 |
+
policy:
|
205 |
+
name: dot
|
206 |
+
n_obs_steps: 3
|
207 |
+
train_horizon: 30
|
208 |
+
inference_horizon: 30
|
209 |
+
lookback_obs_steps: 10
|
210 |
+
lookback_aug: 5
|
211 |
+
input_shapes:
|
212 |
+
observation.environment_state:
|
213 |
+
- 16
|
214 |
+
observation.state:
|
215 |
+
- ${env.state_dim}
|
216 |
+
output_shapes:
|
217 |
+
action:
|
218 |
+
- ${env.action_dim}
|
219 |
+
input_normalization_modes:
|
220 |
+
observation.environment_state: min_max
|
221 |
+
observation.state: min_max
|
222 |
+
output_normalization_modes:
|
223 |
+
action: min_max
|
224 |
+
state_noise: 0.01
|
225 |
+
noise_decay: 0.999995
|
226 |
+
pre_norm: true
|
227 |
+
dim_model: 128
|
228 |
+
n_heads: 8
|
229 |
+
dim_feedforward: 512
|
230 |
+
n_decoder_layers: 8
|
231 |
+
dropout: 0.1
|
232 |
+
alpha: 0.75
|
233 |
+
train_alpha: 0.9
|
234 |
+
predict_every_n: 1
|
235 |
+
return_every_n: 2
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1d5e02a6c29abeaf8b44b1c78dc953b7bb8ce8983ed9491e7fe19eb24cfd6c94
|
3 |
+
size 8534312
|
replay.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6e32ed981aca2d1f85511b1f310421dc766c26aad6427ff45c32439e8975187e
|
3 |
+
size 135178
|