mbreuss committed
Commit 5e21358
1 Parent(s): 5a29d7b

Create config.yaml

Files changed (1): config.yaml (+274, -0)

config.yaml ADDED
@@ -0,0 +1,274 @@
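The new file is a single Hydra-style training configuration for the MoDE diffusion policy defined in `medit`. It is reproduced below in three parts: the `datamodule`, the `trainer`, and the top-level values that the `${...}` references resolve to.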
```yaml
datamodule:
  transforms:
    combine_goal_obs: false
    move_axis: false
    bytes_to_string: true
    adjust_type: null
    add_robot_information: false
    language_encoders:
      _target_: medit.agents.input_encoders.goal_encoders.language_encoders.clip_tokens.TokenLangClip
      _recursive_: false
      model_name: ${clip_lang_model_name}
  _target_: oxe_torch_dataloader.uha.uha_datamodule.UhaDataModule
  _recursive_: false
  num_workers: ${num_workers}
  batch_size: ${batch_size}
  pin_memory: ${pin_memory}
  drop_last: ${drop_last}
  datasets:
    DATA_NAME: ${DATA_NAME}
    DATA_PATH: gs://gresearch/robotics
    load_camera_views: ${load_camera_views}
    dataset_size_limit: ${dataset_size_limit}
    action_proprio_normalization_type: bounds
  interleaved_dataset_cfg:
    shuffle_buffer_size: ${shuffle_buffer_size}
    balance_weights: true
    traj_transform_kwargs:
      goal_relabeling_strategy: ${goal_relabeling_strategy}
      goal_relabeling_kwargs: ${goal_relabeling_kwargs}
      window_size: ${window_size}
      action_horizon: ${act_seq_len}
      subsample_length: ${subsample_length}
      skip_unlabeled: ${skip_unlabeled}
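    # Per-camera-view augmentation ("primary", "secondary", "wrist");
    # `augment_order` lists which ops are applied and in what order.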
    frame_transform_kwargs:
      image_augment_kwargs:
        primary:
          random_resized_crop:
            scale:
            - 0.8
            - 1.0
            ratio:
            - 0.9
            - 1.1
          random_brightness:
          - 0.1
          random_contrast:
          - 0.9
          - 1.1
          random_saturation:
          - 0.9
          - 1.1
          random_hue:
          - 0.05
          augment_order:
          - random_resized_crop
          - random_brightness
          - random_contrast
          - random_saturation
          - random_hue
        secondary:
          random_resized_crop:
            scale:
            - 0.8
            - 1.0
            ratio:
            - 0.9
            - 1.1
          random_brightness:
          - 0.1
          random_contrast:
          - 0.9
          - 1.1
          random_saturation:
          - 0.9
          - 1.1
          random_hue:
          - 0.05
          augment_order:
          - random_resized_crop
          - random_brightness
          - random_contrast
          - random_saturation
          - random_hue
        wrist:
          random_brightness:
          - 0.1
          random_contrast:
          - 0.9
          - 1.1
          random_saturation:
          - 0.9
          - 1.1
          random_hue:
          - 0.05
          augment_order:
          - random_brightness
          - random_contrast
          - random_saturation
          - random_hue
      resize_size:
        primary:
        - 224
        - 224
        secondary:
        - 224
        - 224
        wrist:
        - 224
        - 224
      resize_size_future_obs:
        primary:
        - 112
        - 112
        secondary:
        - 112
        - 112
        wrist:
        - 112
        - 112
      num_parallel_calls: 128
    traj_transform_threads: 64
    traj_read_threads: 32
```
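The `_target_` / `_recursive_` keys follow Hydra's instantiation convention, and every `${...}` reference resolves against the top-level keys collected in the last block below. A minimal consumption sketch, assuming Hydra/OmegaConf are installed and this project's `medit` and `oxe_torch_dataloader` packages are importable (the file path is illustrative):

```python
import hydra
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")   # the file added in this commit

# `${...}` references resolve against the top-level keys on access:
print(cfg.datamodule.batch_size)      # 512, via ${batch_size}

# Hydra builds the class named by `_target_`; since the datamodule sets
# `_recursive_: false`, nested configs such as `transforms` are passed
# through as-is for UhaDataModule to handle itself.
datamodule = hydra.utils.instantiate(cfg.datamodule)
```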
```yaml
trainer:
  agent:
    agent:
      language_goal:
        _target_: medit.agents.input_encoders.goal_encoders.language_encoders.clip_tokens.LangClip
        _recursive_: false
        freeze_backbone: true
        model_name: ${clip_lang_model_name}
      model:
        _target_: medit.agents.inner_models.edm_diffusion_policy.score_wrappers.GCDenoiser
        _recursive_: true
        sigma_data: 0.5
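        # Inner model: the MoDE diffusion transformer; its feed-forward layers
        # are a noise-conditioned mixture of experts (routing sketch below).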
        inner_model:
          _target_: medit.agents.inner_models.modedit.MoDeDiT
          action_dim: ${act_dim}
          goal_dim: ${goal_dim}
          obs_dim: 2048
          goal_conditioned: true
          causal: true
          use_custom_attn_mask: false
          use_proprio: false
          state_dim: 8
          embed_dim: 1024
          n_layers: 12
          goal_seq_len: 1
          obs_seq_len: ${obs_seq_len}
          action_seq_len: ${act_seq_len}
          embed_pdrob: 0
          goal_drop: 0.1
          attn_pdrop: 0.3
          mlp_pdrop: 0.1
          n_heads: 8
          linear_output: true
          cond_router: true
          num_experts: 4
          top_k: 2
          router_normalize: true
          use_goal_in_routing: false
          use_argmax: false
          use_shared_expert: false
          use_noise_token_as_input: true
          init_style: olmoe
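      # Agent-level EDM sampling settings: 5 DDIM steps over an exponential
      # sigma schedule between sigma_min and sigma_max (sketch at the end).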
      _target_: medit.agents.mode_agent.MoDEAgent
      _recursive_: false
      latent_dim: 1024
      multistep: 5
      sampler_type: ddim
      num_sampling_steps: 5
      sigma_data: 0.5
      sigma_min: 0.001
      sigma_max: 80
      noise_scheduler: exponential
      sigma_sample_density_type: loglogistic
      act_window_size: ${act_seq_len}
      act_dim: ${act_dim}
      seed: ${seed}
      obs_modalities: ${obs_modalities}
      goal_modalities: ${goal_modalities}
      img_modalities: ${img_modalities}
      lang_modalities: ${lang_modalities}
      target_modality: ${target_modality}
      entropy_gamma: 0.01
      router_z_delta: 0.0
      resnet_type: '50'
    _target_: medit.agents.ddp_wrapper.DDPAgentWrapper
    _recursive_: false
    obs_modalities: ${obs_modalities}
    goal_modalities: ${goal_modalities}
    img_modalities: ${img_modalities}
    lang_modalities: ${lang_modalities}
    target_modality: ${target_modality}
  _target_: medit.trainers.accelerate_trainer.AccelerateTrainer
  _recursive_: false
  weight_decay:
    transformer_weight_decay: 0.1
    obs_encoder_weight_decay: 0.1
  perceptual_encoder_lr: 0.0001
  lr_scheduler: ${lr_scheduler}
  eval_every_n_steps: ${eval_every_n_steps}
  save_every_n_steps: ${save_every_n_steps}
  max_train_steps: ${max_train_steps}
  max_eval_steps: ${max_eval_steps}
  use_ema: true
  decay: ${decay}
  rampup_ratio: ${rampup_ratio}
  update_ema_every_n_steps: ${update_ema_every_n_steps}
  batch_size: ${batch_size}
  obs_modalities: ${obs_modalities}
  goal_modalities: ${goal_modalities}
  img_modalities: ${img_modalities}
  lang_modalities: ${lang_modalities}
  target_modality: ${target_modality}
```
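In the `inner_model` above, `num_experts: 4` and `top_k: 2` parameterize the mixture-of-experts feed-forward layers, `use_noise_token_as_input: true` routes on the noise embedding, and `router_normalize: true` renormalizes the selected gate weights. The repository's own routing code is not part of this commit; the sketch below is a generic top-2 router under those assumptions (class and attribute names are illustrative, not `medit`'s):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NoiseCondMoE(nn.Module):
    """Illustrative top-k MoE feed-forward (NOT the medit implementation)."""

    def __init__(self, dim: int = 1024, num_experts: int = 4, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(dim, num_experts)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
            for _ in range(num_experts)
        )

    def forward(self, x: torch.Tensor, noise_emb: torch.Tensor) -> torch.Tensor:
        # Route on the noise embedding (`use_noise_token_as_input: true`).
        logits = self.gate(noise_emb)                   # (batch, num_experts)
        weights, idx = logits.topk(self.top_k, dim=-1)  # keep top-2 experts
        weights = F.softmax(weights, dim=-1)            # `router_normalize: true`
        out = torch.zeros_like(x)
        for k in range(self.top_k):
            for e, expert in enumerate(self.experts):
                sel = idx[:, k] == e                    # samples routed to expert e
                if sel.any():
                    out[sel] += weights[sel, k].unsqueeze(-1) * expert(x[sel])
        return out

moe = NoiseCondMoE()
y = moe(torch.randn(8, 1024), torch.randn(8, 1024))    # -> (8, 1024)
```

The last block collects the top-level values that the `${...}` references above resolve to: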
```yaml
vis_clip_model_name: ViT-B/16
clip_lang_model_name: ViT-B/32
DATA_NAME: MO
wandb:
  name: uha_${now:%H-%M-%S}
  group: ${now:%Y-%m-%d}
  project: simulation_eval
  entity: irl-masterthesis
  mode: null
lr_scheduler:
  _target_: medit.agents.utils.lr_schedulers.InverseSquareRootLRSchedule
  num_warmup_steps: 1000
  timescale: ${max_train_steps}
log_dir: logs/
window_size: 1
obs_seq_len: 1
goal_window_size: 1
seed: 42
obs_dim: 512
goal_dim: 512
act_seq_len: 10
update_ema_every_n_steps: 1
decay: 0.999
rampup_ratio: 0.001
gen_img_res: 112
num_tokens_voltron: 10
img_gen_frame_diff: 3
use_modality_encoder: false
goal_relabeling_strategy: null
goal_relabeling_kwargs:
  min_bound: 20
  max_bound: 50
  frame_diff: ${img_gen_frame_diff}
subsample_length: null
skip_unlabeled: true
load_camera_views:
- primary
- secondary
- wrist
obs_modalities: observation
goal_modalities: task
img_modalities:
- image_primary
- image_secondary
- image_wrist
lang_modalities:
- language_instruction
target_modality: action
drop_last: true
pin_memory: true
num_workers: 0
gradient_accumulation_steps: 1
act_dim: 7
max_train_steps: 300000
max_eval_steps: 200
eval_every_n_steps: 5000
save_every_n_steps: 5000
shuffle_buffer_size: 400000
batch_size: 512
dataset_size_limit: null
```
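Finally, `noise_scheduler: exponential` with `num_sampling_steps: 5`, `sigma_min: 0.001`, and `sigma_max: 80` fixes the inference-time noise levels. Assuming the usual k-diffusion-style exponential schedule (log-uniform spacing from `sigma_max` down to `sigma_min`; not verified against `medit`'s scheduler), the five sigmas can be computed directly:

```python
import math
import torch

def exponential_sigmas(n: int = 5, sigma_min: float = 0.001, sigma_max: float = 80.0):
    # Geometric spacing from sigma_max down to sigma_min, plus a final 0
    # (the convention used by k-diffusion's get_sigmas_exponential).
    sigmas = torch.exp(torch.linspace(math.log(sigma_max), math.log(sigma_min), n))
    return torch.cat([sigmas, sigmas.new_zeros(1)])

print(exponential_sigmas())
# ~ tensor([8.0e+01, 4.8e+00, 2.8e-01, 1.7e-02, 1.0e-03, 0.0e+00])
```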