Files
unifolm-world-model-action/configs/inference/world_model_interaction.yaml
2026-01-18 00:30:10 +08:00

245 lines
7.4 KiB
YAML

# World-model + action ("interaction") inference configuration.
# NOTE(review): the source had lost all indentation; the nesting below was
# reconstructed from the OmegaConf interpolation paths (${model.params.*})
# and the conventional LatentDiffusion config layout — confirm against the
# original file in the repository before relying on exact placement.
model:
  target: unifolm_wma.models.ddpms.LatentVisualDiffusion
  params:
    # --- video diffusion schedule / latent-space setup ---
    rescale_betas_zero_snr: true
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.012
    num_timesteps_cond: 1
    timesteps: 1000
    first_stage_key: video
    cond_stage_key: instruction
    cond_stage_trainable: false
    conditioning_key: hybrid
    image_size: [40, 64]  # latent H/W; presumably pixel 320x512 / 8 — confirm
    channels: 4
    scale_by_std: false
    scale_factor: 0.18215
    use_ema: false
    uncond_type: 'empty_seq'
    use_dynamic_rescale: true
    base_scale: 0.7
    fps_condition_type: 'fps'
    perframe_ae: true
    freeze_embedder: true
    # --- observation / agent dimensions ---
    n_obs_steps_imagen: 2
    n_obs_steps_acting: 2
    agent_state_dim: 16
    agent_action_dim: 16
    decision_making_only: false
    ###################### DP Related
    input_pertub: 0.1
    lr_scheduler: cosine
    lr_warmup_steps: 2000
    num_epochs: 30000
    gradient_accumulate_every: 1
    use_scheduler: true
    dp_use_ema: true
    dp_ema_config:
      target: unifolm_wma.models.diffusion_head.ema_model.EMAModel
      params:
        update_after_step: 0
        inv_gamma: 1.0
        power: 0.75
        min_value: 0.0
        max_value: 0.9999
    noise_scheduler_config:
      target: diffusers.DDIMScheduler
      params:
        num_train_timesteps: 1000
        beta_start: 0.0001
        beta_end: 0.02
        beta_schedule: squaredcos_cap_v2
        clip_sample: true
        set_alpha_to_one: true
        steps_offset: 0
        prediction_type: epsilon
    dp_optimizer_config:
      target: torch.optim.AdamW
      params:
        lr: 1.0e-4
        betas: [0.95, 0.999]
        eps: 1.0e-8
        weight_decay: 1.0e-6
    # --- world-model backbone (video UNet) ---
    wma_config:
      target: unifolm_wma.modules.networks.wma_model.WMAModel
      params:
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions:
          - 4
          - 2
          - 1
        num_res_blocks: 2
        channel_mult:
          - 1
          - 2
          - 4
          - 4
        dropout: 0.1
        num_head_channels: 64
        transformer_depth: 1
        context_dim: 1024
        use_linear: true
        use_checkpoint: true
        temporal_conv: true
        temporal_attention: true
        temporal_selfatt_only: true
        use_relative_position: false
        use_causal_attention: false
        temporal_length: 16
        addition_attention: true
        image_cross_attention: true
        default_fs: 10
        fs_condition: true
        cross_attention_scale_learnable: false
        n_obs_steps: ${model.params.n_obs_steps_imagen}
        num_stem_token: 16
        base_model_gen_only: false
    # --- action diffusion head ---
    # NOTE(review): unet_head_config / obs_encoder_config / stem_process_config
    # are placed here as siblings of wma_config; the original indentation is
    # lost, so verify they are not nested under wma_config.params instead.
    unet_head_config:
      target: unifolm_wma.models.diffusion_head.conditional_unet1d.ConditionalUnet1D
      params:
        input_dim: ${model.params.agent_action_dim}
        n_obs_steps: ${model.params.n_obs_steps_acting}
        diffusion_step_embed_dim: 128
        down_dims: [256, 512, 1024, 2048]
        kernel_size: 5
        n_groups: 8
        cond_predict_scale: true
        num_head_channels: ${model.params.wma_config.params.num_head_channels}
        horizon: ${model.params.wma_config.params.temporal_length}
        use_linear_attn: ${model.params.wma_config.params.use_linear}
        use_linear_act_proj: true
        act_proj_dim: 32
        cond_cross_attention: false
        context_dims: []
        image_size: ${model.params.image_size}
        imagen_cond_gradient: true
        last_frame_only: false
        use_imagen_mid_only: false
        use_z_only: false
    obs_encoder_config:
      target: unifolm_wma.models.diffusion_head.vision.multi_image_obs_encoder.MultiImageObsEncoder
      params:
        rgb_model_config:
          target: unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet
          params:
            name: resnet18
            weights: null
        resize_shape: null
        crop_shape: null
        random_crop: false
        use_group_norm: true
        share_rgb_model: false
        imagenet_norm: true
        use_spatial_softmax: true
        spatial_softmax_kp: 128
    ###################### Action Tokenization
    stem_process_config:
      target: unifolm_wma.modules.encoders.condition.SATokenProjector
      params:
        dim: 1024
        depth: 1
        dim_head: 64
        heads: 16
        num_queries: ${model.params.wma_config.params.num_stem_token}
        output_dim: 1024
        ff_mult: 4
        chunk_size: ${model.params.wma_config.params.temporal_length}
    # --- frozen first-stage VAE and text/image conditioning stacks ---
    first_stage_config:
      target: unifolm_wma.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
            - 1
            - 2
            - 4
            - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config:
      target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPEmbedder
      params:
        freeze: true
        layer: "penultimate"
    img_cond_stage_config:
      target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
      params:
        freeze: true
    image_proj_stage_config:
      target: unifolm_wma.modules.encoders.resampler.Resampler
      params:
        dim: 1024
        depth: 4
        dim_head: 64
        heads: 12
        num_queries: 16
        embedding_dim: 1280
        output_dim: 1024
        ff_mult: 4
        video_length: ${model.params.wma_config.params.temporal_length}
    # --- state/action normalization ---
    # NOTE(review): the original interpolated
    # ${model.params.wma_config.params.action_unet_config.params.input_dim},
    # a key that does not exist in this file (the head is `unet_head_config`,
    # whose input_dim is ${model.params.agent_action_dim}). Referencing
    # agent_action_dim directly resolves to the same value (16) and is
    # independent of where the head config is nested.
    normalization_config:
      input_shapes:
        observation.state: ${model.params.agent_action_dim}
      input_normalization_modes:
        observation.state: 'min_max'
      output_shapes:
        action: ${model.params.agent_action_dim}
      output_normalization_modes:
        action: 'min_max'
# Data module: a single `test` split used for interaction-time prompting.
data:
  target: unifolm_wma.utils.data.DataModuleFromConfig
  params:
    batch_size: 6
    num_workers: 12
    wrap: false
    test:
      target: unifolm_wma.data.wma_data.WMAData
      params:
        # NOTE(review): absolute, machine-specific path — adjust per deployment.
        data_dir: '/home/dyz/unifolm-world-model-action/examples/world_model_interaction_prompts'
        video_length: ${model.params.wma_config.params.temporal_length}
        frame_stride: 2
        load_raw_resolution: true
        resolution: [320, 512]
        spatial_transform: resize_center_crop
        crop_resolution: [320, 512]
        random_fs: false
        cond_robot_label_prob: 0.0
        normalization_mode: 'min_max'
        individual_normalization: true
        n_obs_steps: ${model.params.n_obs_steps_imagen}
        max_action_dim: ${model.params.agent_action_dim}
        max_state_dim: ${model.params.agent_state_dim}
        # Per-dataset sampling weights (sum to 1.0).
        dataset_and_weights:
          unitree_z1_stackbox: 0.2
          unitree_z1_dual_arm_stackbox: 0.2
          unitree_z1_dual_arm_stackbox_v2: 0.2
          unitree_z1_dual_arm_cleanup_pencils: 0.2
          unitree_g1_pack_camera: 0.2