---
# Training configuration for the UnifoLM world-model-action (WMA) system:
# a LatentVisualDiffusion video model (DynamiCrafter-style latent diffusion)
# coupled with a diffusion-policy action head (ConditionalUnet1D).
model:
  target: unifolm_wma.models.ddpms.LatentVisualDiffusion
  params:
    # --- Video latent-diffusion schedule / conditioning ---
    rescale_betas_zero_snr: true
    parameterization: "v"            # v-prediction objective
    linear_start: 0.00085
    linear_end: 0.012
    num_timesteps_cond: 1
    timesteps: 1000
    first_stage_key: video           # batch key holding the video frames
    cond_stage_key: instruction      # batch key holding the text instruction
    cond_stage_trainable: false
    conditioning_key: hybrid
    image_size: [40, 64]             # latent H x W (pixel size / VAE downsample)
    channels: 4                      # latent channels of the KL autoencoder
    scale_by_std: false
    scale_factor: 0.18215            # standard SD latent scaling constant
    use_ema: false
    uncond_type: 'empty_seq'
    use_dynamic_rescale: true
    base_scale: 0.7
    fps_condition_type: 'fps'
    perframe_ae: true                # run the VAE per frame to bound memory
    freeze_embedder: true

    # --- Agent observation/action dimensions ---
    n_obs_steps_imagen: 1            # observation steps fed to the video model
    n_obs_steps_acting: 1            # observation steps fed to the action head
    agent_state_dim: 16
    agent_action_dim: 16

    # ###################### DP Related ######################
    # Diffusion-policy (action head) training hyperparameters.
    input_pertub: 0.1                # input perturbation for DP noise targets
    lr_scheduler: cosine
    lr_warmup_steps: 2000
    num_epochs: 30000
    gradient_accumulate_every: 1
    use_scheduler: true
    dp_use_ema: true
    dp_ema_config:
      target: unifolm_wma.models.diffusion_head.ema_model.EMAModel
      params:
        update_after_step: 0
        inv_gamma: 1.0
        power: 0.75
        min_value: 0.0
        max_value: 0.9999
    noise_scheduler_config:
      target: diffusers.DDIMScheduler
      params:
        num_train_timesteps: 1000
        beta_start: 0.0001
        beta_end: 0.02
        beta_schedule: squaredcos_cap_v2
        clip_sample: true
        set_alpha_to_one: true
        steps_offset: 0
        prediction_type: epsilon     # DP head predicts noise (not v)
    dp_optimizer_config:
      target: torch.optim.AdamW
      params:
        lr: 1.0e-4
        betas: [0.95, 0.999]
        eps: 1.0e-8
        weight_decay: 1.0e-6

    # --- World-model UNet (spatio-temporal video backbone) ---
    wma_config:
      target: unifolm_wma.modules.networks.wma_model.WMAModel
      params:
        in_channels: 8               # 2x latent channels (noisy latent + cond concat)
        out_channels: 4
        model_channels: 320
        attention_resolutions:
          - 4
          - 2
          - 1
        num_res_blocks: 2
        channel_mult:
          - 1
          - 2
          - 4
          - 4
        dropout: 0.1
        num_head_channels: 64
        transformer_depth: 1
        context_dim: 1024
        use_linear: true
        use_checkpoint: true         # gradient checkpointing
        temporal_conv: true
        temporal_attention: true
        temporal_selfatt_only: true
        use_relative_position: false
        use_causal_attention: false
        temporal_length: 16          # frames per clip / policy horizon
        addition_attention: true
        image_cross_attention: true
        default_fs: 10
        fs_condition: true
        cross_attention_scale_learnable: false
        n_obs_steps: ${model.params.n_obs_steps_imagen}
        num_stem_token: 16
        base_model_gen_only: true

    # --- Diffusion-policy action head (1-D conditional UNet) ---
    unet_head_config:
      target: unifolm_wma.models.diffusion_head.conditional_unet1d.ConditionalUnet1D
      params:
        input_dim: ${model.params.agent_action_dim}
        n_obs_steps: ${model.params.n_obs_steps_acting}
        diffusion_step_embed_dim: 128
        down_dims: [256, 512, 1024, 2048]
        kernel_size: 5
        n_groups: 8
        cond_predict_scale: true
        num_head_channels: ${model.params.wma_config.params.num_head_channels}
        horizon: ${model.params.wma_config.params.temporal_length}
        use_linear_attn: ${model.params.wma_config.params.use_linear}
        use_linear_act_proj: true
        act_proj_dim: 32
        cond_cross_attention: false
        context_dims: []
        image_size: ${model.params.image_size}
        imagen_cond_gradient: true
        last_frame_only: false
        use_imagen_mid_only: false
        use_z_only: false

    # --- Visual observation encoder for the policy ---
    obs_encoder_config:
      target: unifolm_wma.models.diffusion_head.vision.multi_image_obs_encoder.MultiImageObsEncoder
      params:
        rgb_model_config:
          target: unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet
          params:
            name: resnet18
            weights: null            # no pretrained weights
        resize_shape: null
        crop_shape: null
        random_crop: false
        use_group_norm: true
        share_rgb_model: false
        imagenet_norm: true
        use_spatial_softmax: true
        spatial_softmax_kp: 128      # number of spatial-softmax keypoints

    # ###################### Action Tokenization ######################
    # Projects state/action streams into stem tokens for the world model.
    stem_process_config:
      target: unifolm_wma.modules.encoders.condition.SATokenProjector
      params:
        dim: 1024
        depth: 1
        dim_head: 64
        heads: 16
        num_queries: ${model.params.wma_config.params.num_stem_token}
        output_dim: 1024
        ff_mult: 4
        chunk_size: ${model.params.wma_config.params.temporal_length}

    # --- Frozen KL autoencoder (pixel <-> latent) ---
    first_stage_config:
      target: unifolm_wma.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
            - 1
            - 2
            - 4
            - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity  # no reconstruction loss (frozen VAE)

    # --- Frozen CLIP text encoder for instruction conditioning ---
    cond_stage_config:
      target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPEmbedder
      params:
        freeze: true
        layer: "penultimate"

    # --- Frozen CLIP image encoder + resampler for image conditioning ---
    img_cond_stage_config:
      target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
      params:
        freeze: true
    image_proj_stage_config:
      target: unifolm_wma.modules.encoders.resampler.Resampler
      params:
        dim: 1024
        depth: 4
        dim_head: 64
        heads: 12
        num_queries: 16
        embedding_dim: 1280          # OpenCLIP ViT-H image feature width
        output_dim: 1024
        ff_mult: 4
        video_length: ${model.params.wma_config.params.temporal_length}

    # --- State/action normalization (min-max per dimension) ---
    # NOTE(review): placed under model.params alongside the other *_config
    # entries; the interpolations below are absolute paths, so resolution is
    # unaffected by this key's nesting level — confirm against the consumer.
    # NOTE(review): the original paths referenced
    # model.params.wma_config.params.action_unet_config.params.input_dim,
    # which does not exist in this file and would fail interpolation; the
    # action U-Net is defined at model.params.unet_head_config — fixed below.
    normalization_config:
      input_shapes:
        observation.state: ${model.params.unet_head_config.params.input_dim}
      input_normalization_modes:
        observation.state: 'min_max'
      output_shapes:
        action: ${model.params.unet_head_config.params.input_dim}
      output_normalization_modes:
        action: 'min_max'