# MoDE_Pretrained / config.yaml
# Author: mbreuss. Commit 5e21358, "Create config.yaml" (7.48 kB as uploaded).
---
# Data pipeline. Hydra instantiates UhaDataModule from `_target_`; most
# scalar settings are interpolated from top-level keys further down.
datamodule:
  # Per-sample transform switches.
  transforms:
    combine_goal_obs: false
    move_axis: false
    bytes_to_string: true
    adjust_type: null
    add_robot_information: false
  # Language-instruction tokenizer. Uses the same CLIP model name as the
  # agent's language_goal encoder (both read ${clip_lang_model_name}).
  language_encoders:
    _target_: medit.agents.input_encoders.goal_encoders.language_encoders.clip_tokens.TokenLangClip
    _recursive_: false
    model_name: ${clip_lang_model_name}
  _target_: oxe_torch_dataloader.uha.uha_datamodule.UhaDataModule
  _recursive_: false
  num_workers: ${num_workers}
  batch_size: ${batch_size}
  pin_memory: ${pin_memory}
  drop_last: ${drop_last}
  datasets:
    DATA_NAME: ${DATA_NAME}
    DATA_PATH: gs://gresearch/robotics
    load_camera_views: ${load_camera_views}
    dataset_size_limit: ${dataset_size_limit}
    action_proprio_normalization_type: bounds
    interleaved_dataset_cfg:
      shuffle_buffer_size: ${shuffle_buffer_size}
      balance_weights: true
      # Trajectory-level settings; action_horizon follows the agent's
      # action chunk length (${act_seq_len}).
      traj_transform_kwargs:
        goal_relabeling_strategy: ${goal_relabeling_strategy}
        goal_relabeling_kwargs: ${goal_relabeling_kwargs}
        window_size: ${window_size}
        action_horizon: ${act_seq_len}
        subsample_length: ${subsample_length}
        skip_unlabeled: ${skip_unlabeled}
      frame_transform_kwargs:
        # Per-camera image augmentations. `augment_order` names the ops
        # defined next to it (presumably applied in that order). The wrist
        # view has no random_resized_crop.
        image_augment_kwargs:
          primary:
            random_resized_crop:
              scale: [0.8, 1.0]
              ratio: [0.9, 1.1]
            random_brightness: [0.1]
            random_contrast: [0.9, 1.1]
            random_saturation: [0.9, 1.1]
            random_hue: [0.05]
            augment_order:
            - random_resized_crop
            - random_brightness
            - random_contrast
            - random_saturation
            - random_hue
          secondary:
            random_resized_crop:
              scale: [0.8, 1.0]
              ratio: [0.9, 1.1]
            random_brightness: [0.1]
            random_contrast: [0.9, 1.1]
            random_saturation: [0.9, 1.1]
            random_hue: [0.05]
            augment_order:
            - random_resized_crop
            - random_brightness
            - random_contrast
            - random_saturation
            - random_hue
          wrist:
            random_brightness: [0.1]
            random_contrast: [0.9, 1.1]
            random_saturation: [0.9, 1.1]
            random_hue: [0.05]
            augment_order:
            - random_brightness
            - random_contrast
            - random_saturation
            - random_hue
        # Square sizes, so height/width order does not matter here.
        resize_size:
          primary: [224, 224]
          secondary: [224, 224]
          wrist: [224, 224]
        # NOTE(review): 112 equals top-level gen_img_res but is hard-coded
        # here rather than interpolated; keep the two in sync if changed.
        resize_size_future_obs:
          primary: [112, 112]
          secondary: [112, 112]
          wrist: [112, 112]
        num_parallel_calls: 128
      traj_transform_threads: 64
      traj_read_threads: 32
# Training stack. AccelerateTrainer -> agent (DDPAgentWrapper) -> agent
# (MoDEAgent) -> model (GCDenoiser) -> inner_model (MoDeDiT). Modality names
# and schedule lengths are interpolated from top-level keys.
trainer:
  agent:
    agent:
      # Language-goal encoder; the CLIP backbone is frozen.
      language_goal:
        _target_: medit.agents.input_encoders.goal_encoders.language_encoders.clip_tokens.LangClip
        _recursive_: false
        freeze_backbone: true
        model_name: ${clip_lang_model_name}
      model:
        _target_: medit.agents.inner_models.edm_diffusion_policy.score_wrappers.GCDenoiser
        _recursive_: true
        sigma_data: 0.5
        inner_model:
          _target_: medit.agents.inner_models.modedit.MoDeDiT
          action_dim: ${act_dim}
          goal_dim: ${goal_dim}
          # NOTE(review): hard-coded 2048 instead of ${obs_dim} (512);
          # presumably the visual feature width of the resnet_type '50'
          # encoder. Confirm before changing either value.
          obs_dim: 2048
          goal_conditioned: true
          causal: true
          use_custom_attn_mask: false
          use_proprio: false
          state_dim: 8
          embed_dim: 1024
          n_layers: 12
          goal_seq_len: 1
          obs_seq_len: ${obs_seq_len}
          action_seq_len: ${act_seq_len}
          # Spelled "pdrob", not "pdrop". Presumably it mirrors the MoDeDiT
          # constructor argument name; verify before renaming.
          embed_pdrob: 0
          goal_drop: 0.1
          attn_pdrop: 0.3
          mlp_pdrop: 0.1
          n_heads: 8
          linear_output: true
          # Router / expert settings.
          cond_router: true
          num_experts: 4
          top_k: 2
          router_normalize: true
          use_goal_in_routing: false
          use_argmax: false
          use_shared_expert: false
          use_noise_token_as_input: true
          init_style: olmoe
      _target_: medit.agents.mode_agent.MoDEAgent
      _recursive_: false
      latent_dim: 1024
      multistep: 5
      sampler_type: ddim
      num_sampling_steps: 5
      # Same value as model.sigma_data above.
      # NOTE(review): presumably these must agree.
      sigma_data: 0.5
      sigma_min: 0.001
      sigma_max: 80
      noise_scheduler: exponential
      sigma_sample_density_type: loglogistic
      act_window_size: ${act_seq_len}
      act_dim: ${act_dim}
      seed: ${seed}
      obs_modalities: ${obs_modalities}
      goal_modalities: ${goal_modalities}
      img_modalities: ${img_modalities}
      lang_modalities: ${lang_modalities}
      target_modality: ${target_modality}
      entropy_gamma: 0.01
      router_z_delta: 0.0
      # Quoted so it stays the string '50' rather than the integer 50.
      resnet_type: '50'
    _target_: medit.agents.ddp_wrapper.DDPAgentWrapper
    _recursive_: false
    obs_modalities: ${obs_modalities}
    goal_modalities: ${goal_modalities}
    img_modalities: ${img_modalities}
    lang_modalities: ${lang_modalities}
    target_modality: ${target_modality}
  _target_: medit.trainers.accelerate_trainer.AccelerateTrainer
  _recursive_: false
  weight_decay:
    transformer_weight_decay: 0.1
    obs_encoder_weight_decay: 0.1
  # Written as a decimal: YAML 1.1 loaders (PyYAML) read `1e-4` as a string.
  perceptual_encoder_lr: 0.0001
  lr_scheduler: ${lr_scheduler}
  eval_every_n_steps: ${eval_every_n_steps}
  save_every_n_steps: ${save_every_n_steps}
  max_train_steps: ${max_train_steps}
  max_eval_steps: ${max_eval_steps}
  # Exponential moving average of weights.
  use_ema: true
  decay: ${decay}
  rampup_ratio: ${rampup_ratio}
  update_ema_every_n_steps: ${update_ema_every_n_steps}
  batch_size: ${batch_size}
  obs_modalities: ${obs_modalities}
  goal_modalities: ${goal_modalities}
  img_modalities: ${img_modalities}
  lang_modalities: ${lang_modalities}
  target_modality: ${target_modality}
# CLIP checkpoints. The language model name is interpolated into both the
# datamodule tokenizer and the agent's language_goal encoder.
# NOTE(review): vis_clip_model_name is not interpolated anywhere in this file.
vis_clip_model_name: 'ViT-B/16'
clip_lang_model_name: 'ViT-B/32'
# Forwarded to datamodule.datasets.DATA_NAME.
DATA_NAME: 'MO'
# Weights & Biases run metadata. name/group use Hydra's `now` resolver, so
# each launch gets a time-stamped run name grouped by date.
wandb:
  name: uha_${now:%H-%M-%S}
  group: ${now:%Y-%m-%d}
  project: simulation_eval
  entity: irl-masterthesis
  mode: null
# Learning-rate schedule, passed to the trainer as ${lr_scheduler}.
# Its timescale is tied to the total number of training steps.
lr_scheduler:
  _target_: medit.agents.utils.lr_schedulers.InverseSquareRootLRSchedule
  num_warmup_steps: 1000
  timescale: ${max_train_steps}
# Shared run settings and sequence/embedding sizes, consumed via ${...}.
# NOTE(review): log_dir, goal_window_size, obs_dim, gen_img_res,
# num_tokens_voltron and use_modality_encoder are not interpolated anywhere
# in this file.
log_dir: 'logs/'
window_size: 1
obs_seq_len: 1
goal_window_size: 1
seed: 42
obs_dim: 512
goal_dim: 512
# Action chunk length: feeds action_horizon, action_seq_len and act_window_size.
act_seq_len: 10
# EMA settings read by the trainer.
update_ema_every_n_steps: 1
decay: 0.999
rampup_ratio: 0.001
gen_img_res: 112
num_tokens_voltron: 10
# Read by goal_relabeling_kwargs.frame_diff.
img_gen_frame_diff: 3
use_modality_encoder: false
goal_relabeling_strategy: null
# Passed to datamodule traj_transform_kwargs as ${goal_relabeling_kwargs}.
# NOTE(review): goal_relabeling_strategy is null, so these bounds are
# presumably ignored by the loader. Confirm before tuning them.
goal_relabeling_kwargs:
  min_bound: 20
  max_bound: 50
  frame_diff: ${img_gen_frame_diff}
# Trajectory filtering options, forwarded to traj_transform_kwargs.
subsample_length: null
skip_unlabeled: true
# Camera streams requested from the dataset (datasets.load_camera_views).
load_camera_views: [primary, secondary, wrist]
# Modality names interpolated into the agent, DDP wrapper and trainer.
# Presumably these are keys of the batch dict; verify against the loader.
obs_modalities: observation
goal_modalities: task
img_modalities: [image_primary, image_secondary, image_wrist]
lang_modalities: [language_instruction]
target_modality: action
# DataLoader / dataset settings forwarded to the datamodule.
drop_last: true
pin_memory: true
num_workers: 0
batch_size: 512
shuffle_buffer_size: 400000
dataset_size_limit: null
# Action dimensionality (agent act_dim, MoDeDiT action_dim).
act_dim: 7
# Training schedule. max_train_steps also sets lr_scheduler.timescale.
# NOTE(review): gradient_accumulation_steps is not interpolated anywhere in
# this file.
gradient_accumulation_steps: 1
max_train_steps: 300000
max_eval_steps: 200
eval_every_n_steps: 5000
save_every_n_steps: 5000