第一次完整测例跑完
This commit is contained in:
213
configs/inference/base_model_inference.yaml
Normal file
213
configs/inference/base_model_inference.yaml
Normal file
@@ -0,0 +1,213 @@
|
||||
model:
|
||||
target: unifolm_wma.models.ddpms.LatentVisualDiffusion
|
||||
params:
|
||||
rescale_betas_zero_snr: True
|
||||
parameterization: "v"
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.012
|
||||
num_timesteps_cond: 1
|
||||
timesteps: 1000
|
||||
first_stage_key: video
|
||||
cond_stage_key: instruction
|
||||
cond_stage_trainable: False
|
||||
conditioning_key: hybrid
|
||||
image_size: [40, 64]
|
||||
channels: 4
|
||||
scale_by_std: False
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
uncond_type: 'empty_seq'
|
||||
use_dynamic_rescale: true
|
||||
base_scale: 0.7
|
||||
fps_condition_type: 'fps'
|
||||
perframe_ae: True
|
||||
freeze_embedder: True
|
||||
n_obs_steps_imagen: 1
|
||||
n_obs_steps_acting: 1
|
||||
agent_state_dim: 16
|
||||
agent_action_dim: 16
|
||||
|
||||
###################### DP Related
|
||||
input_pertub: 0.1
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 2000
|
||||
num_epochs: 30000
|
||||
gradient_accumulate_every: 1
|
||||
use_scheduler: True
|
||||
dp_use_ema: True
|
||||
|
||||
dp_ema_config:
|
||||
target: unifolm_wma.models.diffusion_head.ema_model.EMAModel
|
||||
params:
|
||||
update_after_step: 0
|
||||
inv_gamma: 1.0
|
||||
power: 0.75
|
||||
min_value: 0.0
|
||||
max_value: 0.9999
|
||||
|
||||
noise_scheduler_config:
|
||||
target: diffusers.DDIMScheduler
|
||||
params:
|
||||
num_train_timesteps: 1000
|
||||
beta_start: 0.0001
|
||||
beta_end: 0.02
|
||||
beta_schedule: squaredcos_cap_v2
|
||||
clip_sample: True
|
||||
set_alpha_to_one: True
|
||||
steps_offset: 0
|
||||
prediction_type: epsilon
|
||||
|
||||
dp_optimizer_config:
|
||||
target: torch.optim.AdamW
|
||||
params:
|
||||
lr: 1.0e-4
|
||||
betas: [0.95, 0.999]
|
||||
eps: 1.0e-8
|
||||
weight_decay: 1.0e-6
|
||||
|
||||
wma_config:
|
||||
target: unifolm_wma.modules.networks.wma_model.WMAModel
|
||||
params:
|
||||
in_channels: 8
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions:
|
||||
- 4
|
||||
- 2
|
||||
- 1
|
||||
num_res_blocks: 2
|
||||
channel_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
dropout: 0.1
|
||||
num_head_channels: 64
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
use_linear: true
|
||||
use_checkpoint: True
|
||||
temporal_conv: True
|
||||
temporal_attention: True
|
||||
temporal_selfatt_only: True
|
||||
use_relative_position: False
|
||||
use_causal_attention: False
|
||||
temporal_length: 16
|
||||
addition_attention: True
|
||||
image_cross_attention: True
|
||||
default_fs: 10
|
||||
fs_condition: True
|
||||
cross_attention_scale_learnable: False
|
||||
n_obs_steps: ${model.params.n_obs_steps_imagen}
|
||||
num_stem_token: 16
|
||||
base_model_gen_only: True
|
||||
|
||||
unet_head_config:
|
||||
target: unifolm_wma.models.diffusion_head.conditional_unet1d.ConditionalUnet1D
|
||||
params:
|
||||
input_dim: ${model.params.agent_action_dim}
|
||||
n_obs_steps: ${model.params.n_obs_steps_acting}
|
||||
diffusion_step_embed_dim: 128
|
||||
down_dims: [256, 512, 1024, 2048]
|
||||
kernel_size: 5
|
||||
n_groups: 8
|
||||
cond_predict_scale: True
|
||||
num_head_channels: ${model.params.wma_config.params.num_head_channels}
|
||||
horizon: ${model.params.wma_config.params.temporal_length}
|
||||
use_linear_attn: ${model.params.wma_config.params.use_linear}
|
||||
use_linear_act_proj: True
|
||||
act_proj_dim: 32
|
||||
cond_cross_attention: False
|
||||
context_dims: []
|
||||
image_size: ${model.params.image_size}
|
||||
imagen_cond_gradient: True
|
||||
last_frame_only: False
|
||||
use_imagen_mid_only: False
|
||||
use_z_only: False
|
||||
|
||||
obs_encoder_config:
|
||||
target: unifolm_wma.models.diffusion_head.vision.multi_image_obs_encoder.MultiImageObsEncoder
|
||||
params:
|
||||
rgb_model_config:
|
||||
target: unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet
|
||||
params:
|
||||
name: resnet18
|
||||
weights: null
|
||||
resize_shape: null
|
||||
crop_shape: null
|
||||
random_crop: False
|
||||
use_group_norm: True
|
||||
share_rgb_model: False
|
||||
imagenet_norm: True
|
||||
use_spatial_softmax: True
|
||||
spatial_softmax_kp: 128
|
||||
|
||||
###################### Action Tokenization
|
||||
stem_process_config:
|
||||
target: unifolm_wma.modules.encoders.condition.SATokenProjector
|
||||
params:
|
||||
dim: 1024
|
||||
depth: 1
|
||||
dim_head: 64
|
||||
heads: 16
|
||||
num_queries: ${model.params.wma_config.params.num_stem_token}
|
||||
output_dim: 1024
|
||||
ff_mult: 4
|
||||
chunk_size: ${model.params.wma_config.params.temporal_length}
|
||||
|
||||
first_stage_config:
|
||||
target: unifolm_wma.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
layer: "penultimate"
|
||||
|
||||
img_cond_stage_config:
|
||||
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
|
||||
params:
|
||||
freeze: true
|
||||
|
||||
image_proj_stage_config:
|
||||
target: unifolm_wma.modules.encoders.resampler.Resampler
|
||||
params:
|
||||
dim: 1024
|
||||
depth: 4
|
||||
dim_head: 64
|
||||
heads: 12
|
||||
num_queries: 16
|
||||
embedding_dim: 1280
|
||||
output_dim: 1024
|
||||
ff_mult: 4
|
||||
video_length: ${model.params.wma_config.params.temporal_length}
|
||||
|
||||
# Shapes/modes for state-action normalization.
normalization_config:
  input_shapes:
    # Fixed: interpolation previously pointed at `...action_unet_config...`,
    # a key that does not exist in this file — the action head is declared
    # as `unet_head_config` (matching configs/train/config.yaml).
    observation.state: ${model.params.wma_config.params.unet_head_config.params.input_dim}
  input_normalization_modes:
    observation.state: 'min_max'
  output_shapes:
    action: ${model.params.wma_config.params.unet_head_config.params.input_dim}
  output_normalization_modes:
    action: 'min_max'
|
||||
240
configs/inference/world_model_decision_making.yaml
Normal file
240
configs/inference/world_model_decision_making.yaml
Normal file
@@ -0,0 +1,240 @@
|
||||
model:
|
||||
target: unifolm_wma.models.ddpms.LatentVisualDiffusion
|
||||
params:
|
||||
rescale_betas_zero_snr: True
|
||||
parameterization: "v"
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.012
|
||||
num_timesteps_cond: 1
|
||||
timesteps: 1000
|
||||
first_stage_key: video
|
||||
cond_stage_key: instruction
|
||||
cond_stage_trainable: False
|
||||
conditioning_key: hybrid
|
||||
image_size: [40, 64]
|
||||
channels: 4
|
||||
scale_by_std: False
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
uncond_type: 'empty_seq'
|
||||
use_dynamic_rescale: true
|
||||
base_scale: 0.7
|
||||
fps_condition_type: 'fps'
|
||||
perframe_ae: True
|
||||
freeze_embedder: True
|
||||
n_obs_steps_imagen: 2
|
||||
n_obs_steps_acting: 2
|
||||
agent_state_dim: 16
|
||||
agent_action_dim: 16
|
||||
decision_making_only: True
|
||||
|
||||
###################### DP Related
|
||||
input_pertub: 0.1
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 2000
|
||||
num_epochs: 30000
|
||||
gradient_accumulate_every: 1
|
||||
use_scheduler: True
|
||||
dp_use_ema: True
|
||||
|
||||
dp_ema_config:
|
||||
target: unifolm_wma.models.diffusion_head.ema_model.EMAModel
|
||||
params:
|
||||
update_after_step: 0
|
||||
inv_gamma: 1.0
|
||||
power: 0.75
|
||||
min_value: 0.0
|
||||
max_value: 0.9999
|
||||
|
||||
noise_scheduler_config:
|
||||
target: diffusers.DDIMScheduler
|
||||
params:
|
||||
num_train_timesteps: 1000
|
||||
beta_start: 0.0001
|
||||
beta_end: 0.02
|
||||
beta_schedule: squaredcos_cap_v2
|
||||
clip_sample: True
|
||||
set_alpha_to_one: True
|
||||
steps_offset: 0
|
||||
prediction_type: epsilon
|
||||
|
||||
dp_optimizer_config:
|
||||
target: torch.optim.AdamW
|
||||
params:
|
||||
lr: 1.0e-4
|
||||
betas: [0.95, 0.999]
|
||||
eps: 1.0e-8
|
||||
weight_decay: 1.0e-6
|
||||
|
||||
wma_config:
|
||||
target: unifolm_wma.modules.networks.wma_model.WMAModel
|
||||
params:
|
||||
in_channels: 8
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions:
|
||||
- 4
|
||||
- 2
|
||||
- 1
|
||||
num_res_blocks: 2
|
||||
channel_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
dropout: 0.1
|
||||
num_head_channels: 64
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
use_linear: true
|
||||
use_checkpoint: True
|
||||
temporal_conv: True
|
||||
temporal_attention: True
|
||||
temporal_selfatt_only: True
|
||||
use_relative_position: False
|
||||
use_causal_attention: False
|
||||
temporal_length: 16
|
||||
addition_attention: True
|
||||
image_cross_attention: True
|
||||
default_fs: 10
|
||||
fs_condition: True
|
||||
cross_attention_scale_learnable: False
|
||||
n_obs_steps: ${model.params.n_obs_steps_imagen}
|
||||
num_stem_token: 16
|
||||
base_model_gen_only: False
|
||||
|
||||
unet_head_config:
|
||||
target: unifolm_wma.models.diffusion_head.conditional_unet1d.ConditionalUnet1D
|
||||
params:
|
||||
input_dim: ${model.params.agent_action_dim}
|
||||
n_obs_steps: ${model.params.n_obs_steps_acting}
|
||||
diffusion_step_embed_dim: 128
|
||||
down_dims: [256, 512, 1024, 2048]
|
||||
kernel_size: 5
|
||||
n_groups: 8
|
||||
cond_predict_scale: True
|
||||
num_head_channels: ${model.params.wma_config.params.num_head_channels}
|
||||
horizon: ${model.params.wma_config.params.temporal_length}
|
||||
use_linear_attn: ${model.params.wma_config.params.use_linear}
|
||||
use_linear_act_proj: True
|
||||
act_proj_dim: 32
|
||||
cond_cross_attention: False
|
||||
context_dims: []
|
||||
image_size: ${model.params.image_size}
|
||||
imagen_cond_gradient: True
|
||||
last_frame_only: False
|
||||
use_imagen_mid_only: False
|
||||
use_z_only: False
|
||||
|
||||
obs_encoder_config:
|
||||
target: unifolm_wma.models.diffusion_head.vision.multi_image_obs_encoder.MultiImageObsEncoder
|
||||
params:
|
||||
rgb_model_config:
|
||||
target: unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet
|
||||
params:
|
||||
name: resnet18
|
||||
weights: null
|
||||
resize_shape: null
|
||||
crop_shape: null
|
||||
random_crop: False
|
||||
use_group_norm: True
|
||||
share_rgb_model: False
|
||||
imagenet_norm: True
|
||||
use_spatial_softmax: True
|
||||
spatial_softmax_kp: 128
|
||||
|
||||
###################### Action Tokenization
|
||||
stem_process_config:
|
||||
target: unifolm_wma.modules.encoders.condition.SATokenProjector
|
||||
params:
|
||||
dim: 1024
|
||||
depth: 1
|
||||
dim_head: 64
|
||||
heads: 16
|
||||
num_queries: ${model.params.wma_config.params.num_stem_token}
|
||||
output_dim: 1024
|
||||
ff_mult: 4
|
||||
chunk_size: ${model.params.wma_config.params.temporal_length}
|
||||
|
||||
first_stage_config:
|
||||
target: unifolm_wma.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
layer: "penultimate"
|
||||
|
||||
img_cond_stage_config:
|
||||
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
|
||||
params:
|
||||
freeze: true
|
||||
|
||||
image_proj_stage_config:
|
||||
target: unifolm_wma.modules.encoders.resampler.Resampler
|
||||
params:
|
||||
dim: 1024
|
||||
depth: 4
|
||||
dim_head: 64
|
||||
heads: 12
|
||||
num_queries: 16
|
||||
embedding_dim: 1280
|
||||
output_dim: 1024
|
||||
ff_mult: 4
|
||||
video_length: ${model.params.wma_config.params.temporal_length}
|
||||
|
||||
# Shapes/modes for state-action normalization.
normalization_config:
  input_shapes:
    # Fixed: interpolation previously pointed at `...action_unet_config...`,
    # a key that does not exist in this file — the action head is declared
    # as `unet_head_config` (matching configs/train/config.yaml).
    observation.state: ${model.params.wma_config.params.unet_head_config.params.input_dim}
  input_normalization_modes:
    observation.state: 'min_max'
  output_shapes:
    action: ${model.params.wma_config.params.unet_head_config.params.input_dim}
  output_normalization_modes:
    action: 'min_max'
|
||||
|
||||
data:
|
||||
target: unifolm_wma.utils.data.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 6
|
||||
num_workers: 12
|
||||
wrap: False
|
||||
test:
|
||||
target: unifolm_wma.data.wma_data.WMAData
|
||||
params:
|
||||
data_dir: '/path/to/the/dataset/directory/that/contains/the/meta/folder/of/the/testing/case/under/a/transitions/folder' # e.g., /path/to/unifolm-world-model-action/examples/world_model_interaction_prompts
|
||||
video_length: ${model.params.wma_config.params.temporal_length}
|
||||
frame_stride: 2
|
||||
load_raw_resolution: True
|
||||
resolution: [320, 512]
|
||||
spatial_transform: resize_center_crop
|
||||
crop_resolution: [320, 512]
|
||||
random_fs: False
|
||||
cond_robot_label_prob: 0.0
|
||||
normalization_mode: 'min_max'
|
||||
individual_normalization: True
|
||||
n_obs_steps: ${model.params.n_obs_steps_imagen}
|
||||
max_action_dim: ${model.params.agent_action_dim}
|
||||
max_state_dim: ${model.params.agent_state_dim}
|
||||
dataset_and_weights:
|
||||
unitree_g1_pack_camera: 1.0
|
||||
244
configs/inference/world_model_interaction.yaml
Normal file
244
configs/inference/world_model_interaction.yaml
Normal file
@@ -0,0 +1,244 @@
|
||||
model:
|
||||
target: unifolm_wma.models.ddpms.LatentVisualDiffusion
|
||||
params:
|
||||
rescale_betas_zero_snr: True
|
||||
parameterization: "v"
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.012
|
||||
num_timesteps_cond: 1
|
||||
timesteps: 1000
|
||||
first_stage_key: video
|
||||
cond_stage_key: instruction
|
||||
cond_stage_trainable: False
|
||||
conditioning_key: hybrid
|
||||
image_size: [40, 64]
|
||||
channels: 4
|
||||
scale_by_std: False
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
uncond_type: 'empty_seq'
|
||||
use_dynamic_rescale: true
|
||||
base_scale: 0.7
|
||||
fps_condition_type: 'fps'
|
||||
perframe_ae: True
|
||||
freeze_embedder: True
|
||||
n_obs_steps_imagen: 2
|
||||
n_obs_steps_acting: 2
|
||||
agent_state_dim: 16
|
||||
agent_action_dim: 16
|
||||
decision_making_only: False
|
||||
|
||||
###################### DP Related
|
||||
input_pertub: 0.1
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 2000
|
||||
num_epochs: 30000
|
||||
gradient_accumulate_every: 1
|
||||
use_scheduler: True
|
||||
dp_use_ema: True
|
||||
|
||||
dp_ema_config:
|
||||
target: unifolm_wma.models.diffusion_head.ema_model.EMAModel
|
||||
params:
|
||||
update_after_step: 0
|
||||
inv_gamma: 1.0
|
||||
power: 0.75
|
||||
min_value: 0.0
|
||||
max_value: 0.9999
|
||||
|
||||
noise_scheduler_config:
|
||||
target: diffusers.DDIMScheduler
|
||||
params:
|
||||
num_train_timesteps: 1000
|
||||
beta_start: 0.0001
|
||||
beta_end: 0.02
|
||||
beta_schedule: squaredcos_cap_v2
|
||||
clip_sample: True
|
||||
set_alpha_to_one: True
|
||||
steps_offset: 0
|
||||
prediction_type: epsilon
|
||||
|
||||
dp_optimizer_config:
|
||||
target: torch.optim.AdamW
|
||||
params:
|
||||
lr: 1.0e-4
|
||||
betas: [0.95, 0.999]
|
||||
eps: 1.0e-8
|
||||
weight_decay: 1.0e-6
|
||||
|
||||
wma_config:
|
||||
target: unifolm_wma.modules.networks.wma_model.WMAModel
|
||||
params:
|
||||
in_channels: 8
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions:
|
||||
- 4
|
||||
- 2
|
||||
- 1
|
||||
num_res_blocks: 2
|
||||
channel_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
dropout: 0.1
|
||||
num_head_channels: 64
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
use_linear: true
|
||||
use_checkpoint: True
|
||||
temporal_conv: True
|
||||
temporal_attention: True
|
||||
temporal_selfatt_only: True
|
||||
use_relative_position: False
|
||||
use_causal_attention: False
|
||||
temporal_length: 16
|
||||
addition_attention: True
|
||||
image_cross_attention: True
|
||||
default_fs: 10
|
||||
fs_condition: True
|
||||
cross_attention_scale_learnable: False
|
||||
n_obs_steps: ${model.params.n_obs_steps_imagen}
|
||||
num_stem_token: 16
|
||||
base_model_gen_only: False
|
||||
|
||||
unet_head_config:
|
||||
target: unifolm_wma.models.diffusion_head.conditional_unet1d.ConditionalUnet1D
|
||||
params:
|
||||
input_dim: ${model.params.agent_action_dim}
|
||||
n_obs_steps: ${model.params.n_obs_steps_acting}
|
||||
diffusion_step_embed_dim: 128
|
||||
down_dims: [256, 512, 1024, 2048]
|
||||
kernel_size: 5
|
||||
n_groups: 8
|
||||
cond_predict_scale: True
|
||||
num_head_channels: ${model.params.wma_config.params.num_head_channels}
|
||||
horizon: ${model.params.wma_config.params.temporal_length}
|
||||
use_linear_attn: ${model.params.wma_config.params.use_linear}
|
||||
use_linear_act_proj: True
|
||||
act_proj_dim: 32
|
||||
cond_cross_attention: False
|
||||
context_dims: []
|
||||
image_size: ${model.params.image_size}
|
||||
imagen_cond_gradient: True
|
||||
last_frame_only: False
|
||||
use_imagen_mid_only: False
|
||||
use_z_only: False
|
||||
|
||||
obs_encoder_config:
|
||||
target: unifolm_wma.models.diffusion_head.vision.multi_image_obs_encoder.MultiImageObsEncoder
|
||||
params:
|
||||
rgb_model_config:
|
||||
target: unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet
|
||||
params:
|
||||
name: resnet18
|
||||
weights: null
|
||||
resize_shape: null
|
||||
crop_shape: null
|
||||
random_crop: False
|
||||
use_group_norm: True
|
||||
share_rgb_model: False
|
||||
imagenet_norm: True
|
||||
use_spatial_softmax: True
|
||||
spatial_softmax_kp: 128
|
||||
|
||||
###################### Action Tokenization
|
||||
stem_process_config:
|
||||
target: unifolm_wma.modules.encoders.condition.SATokenProjector
|
||||
params:
|
||||
dim: 1024
|
||||
depth: 1
|
||||
dim_head: 64
|
||||
heads: 16
|
||||
num_queries: ${model.params.wma_config.params.num_stem_token}
|
||||
output_dim: 1024
|
||||
ff_mult: 4
|
||||
chunk_size: ${model.params.wma_config.params.temporal_length}
|
||||
|
||||
first_stage_config:
|
||||
target: unifolm_wma.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
layer: "penultimate"
|
||||
|
||||
img_cond_stage_config:
|
||||
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
|
||||
params:
|
||||
freeze: true
|
||||
|
||||
image_proj_stage_config:
|
||||
target: unifolm_wma.modules.encoders.resampler.Resampler
|
||||
params:
|
||||
dim: 1024
|
||||
depth: 4
|
||||
dim_head: 64
|
||||
heads: 12
|
||||
num_queries: 16
|
||||
embedding_dim: 1280
|
||||
output_dim: 1024
|
||||
ff_mult: 4
|
||||
video_length: ${model.params.wma_config.params.temporal_length}
|
||||
|
||||
# Shapes/modes for state-action normalization.
normalization_config:
  input_shapes:
    # Fixed: interpolation previously pointed at `...action_unet_config...`,
    # a key that does not exist in this file — the action head is declared
    # as `unet_head_config` (matching configs/train/config.yaml).
    observation.state: ${model.params.wma_config.params.unet_head_config.params.input_dim}
  input_normalization_modes:
    observation.state: 'min_max'
  output_shapes:
    action: ${model.params.wma_config.params.unet_head_config.params.input_dim}
  output_normalization_modes:
    action: 'min_max'
|
||||
|
||||
data:
|
||||
target: unifolm_wma.utils.data.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 6
|
||||
num_workers: 12
|
||||
wrap: False
|
||||
test:
|
||||
target: unifolm_wma.data.wma_data.WMAData
|
||||
params:
|
||||
# Fixed: replaced a committed developer-local path (/home/dyz/...) with the
# same placeholder convention used by the sibling inference config.
data_dir: '/path/to/the/dataset/directory/that/contains/the/meta/folder/of/the/testing/case/under/a/transitions/folder' # e.g., /path/to/unifolm-world-model-action/examples/world_model_interaction_prompts
|
||||
video_length: ${model.params.wma_config.params.temporal_length}
|
||||
frame_stride: 2
|
||||
load_raw_resolution: True
|
||||
resolution: [320, 512]
|
||||
spatial_transform: resize_center_crop
|
||||
crop_resolution: [320, 512]
|
||||
random_fs: False
|
||||
cond_robot_label_prob: 0.0
|
||||
normalization_mode: 'min_max'
|
||||
individual_normalization: True
|
||||
n_obs_steps: ${model.params.n_obs_steps_imagen}
|
||||
max_action_dim: ${model.params.agent_action_dim}
|
||||
max_state_dim: ${model.params.agent_state_dim}
|
||||
dataset_and_weights:
|
||||
unitree_z1_stackbox: 0.2
|
||||
unitree_z1_dual_arm_stackbox: 0.2
|
||||
unitree_z1_dual_arm_stackbox_v2: 0.2
|
||||
unitree_z1_dual_arm_cleanup_pencils: 0.2
|
||||
unitree_g1_pack_camera: 0.2
|
||||
287
configs/train/config.yaml
Normal file
287
configs/train/config.yaml
Normal file
@@ -0,0 +1,287 @@
|
||||
model:
|
||||
pretrained_checkpoint: /path/to/pretrained/checkpoint
|
||||
base_learning_rate: 1.0e-05
|
||||
scale_lr: False
|
||||
target: unifolm_wma.models.ddpms.LatentVisualDiffusion
|
||||
params:
|
||||
rescale_betas_zero_snr: True
|
||||
parameterization: "v"
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.012
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: video
|
||||
cond_stage_key: instruction
|
||||
cond_stage_trainable: False
|
||||
image_proj_model_trainable: True
|
||||
conditioning_key: hybrid
|
||||
image_size: [40, 64]
|
||||
channels: 4
|
||||
scale_by_std: False
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
uncond_prob: 0.05
|
||||
uncond_type: 'empty_seq'
|
||||
rand_cond_frame: false
|
||||
use_dynamic_rescale: true
|
||||
base_scale: 0.7
|
||||
fps_condition_type: 'fps'
|
||||
perframe_ae: True
|
||||
freeze_embedder: True
|
||||
n_obs_steps_imagen: 2
|
||||
n_obs_steps_acting: 2
|
||||
agent_state_dim: 16
|
||||
agent_action_dim: 16
|
||||
decision_making_only: True
|
||||
|
||||
###################### DP Related
|
||||
input_pertub: 0.1
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 2000
|
||||
num_epochs: 60000
|
||||
gradient_accumulate_every: 1
|
||||
use_scheduler: True
|
||||
dp_use_ema: True
|
||||
|
||||
dp_ema_config:
|
||||
target: unifolm_wma.models.diffusion_head.ema_model.EMAModel
|
||||
params:
|
||||
update_after_step: 0
|
||||
inv_gamma: 1.0
|
||||
power: 0.75
|
||||
min_value: 0.0
|
||||
max_value: 0.9999
|
||||
|
||||
noise_scheduler_config:
|
||||
target: diffusers.DDIMScheduler
|
||||
params:
|
||||
num_train_timesteps: 1000
|
||||
beta_start: 0.0001
|
||||
beta_end: 0.02
|
||||
beta_schedule: squaredcos_cap_v2
|
||||
clip_sample: True
|
||||
set_alpha_to_one: True
|
||||
steps_offset: 0
|
||||
prediction_type: epsilon
|
||||
|
||||
dp_optimizer_config:
|
||||
target: torch.optim.AdamW
|
||||
params:
|
||||
lr: 1.0e-4
|
||||
betas: [0.95, 0.999]
|
||||
eps: 1.0e-8
|
||||
weight_decay: 1.0e-6
|
||||
|
||||
wma_config:
|
||||
target: unifolm_wma.modules.networks.wma_model.WMAModel
|
||||
params:
|
||||
in_channels: 8
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions:
|
||||
- 4
|
||||
- 2
|
||||
- 1
|
||||
num_res_blocks: 2
|
||||
channel_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
dropout: 0.1
|
||||
num_head_channels: 64
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
use_linear: true
|
||||
use_checkpoint: True
|
||||
temporal_conv: True
|
||||
temporal_attention: True
|
||||
temporal_selfatt_only: True
|
||||
use_relative_position: False
|
||||
use_causal_attention: False
|
||||
temporal_length: 16
|
||||
addition_attention: True
|
||||
image_cross_attention: True
|
||||
default_fs: 10
|
||||
fs_condition: True
|
||||
cross_attention_scale_learnable: False
|
||||
n_obs_steps: ${model.params.n_obs_steps_imagen}
|
||||
num_stem_token: 16
|
||||
base_model_gen_only: False
|
||||
|
||||
unet_head_config:
|
||||
target: unifolm_wma.models.diffusion_head.conditional_unet1d.ConditionalUnet1D
|
||||
params:
|
||||
input_dim: ${model.params.agent_action_dim}
|
||||
n_obs_steps: ${model.params.n_obs_steps_acting}
|
||||
diffusion_step_embed_dim: 128
|
||||
down_dims: [256, 512, 1024, 2048]
|
||||
kernel_size: 5
|
||||
n_groups: 8
|
||||
cond_predict_scale: True
|
||||
num_head_channels: ${model.params.wma_config.params.num_head_channels}
|
||||
horizon: ${model.params.wma_config.params.temporal_length}
|
||||
use_linear_attn: ${model.params.wma_config.params.use_linear}
|
||||
use_linear_act_proj: True
|
||||
act_proj_dim: 32
|
||||
cond_cross_attention: False
|
||||
context_dims: []
|
||||
image_size: ${model.params.image_size}
|
||||
imagen_cond_gradient: True
|
||||
last_frame_only: False
|
||||
use_imagen_mid_only: False
|
||||
use_z_only: False
|
||||
|
||||
obs_encoder_config:
|
||||
target: unifolm_wma.models.diffusion_head.vision.multi_image_obs_encoder.MultiImageObsEncoder
|
||||
params:
|
||||
rgb_model_config:
|
||||
target: unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet
|
||||
params:
|
||||
name: resnet18
|
||||
weights: null
|
||||
resize_shape: null
|
||||
crop_shape: null
|
||||
random_crop: False
|
||||
use_group_norm: True
|
||||
share_rgb_model: False
|
||||
imagenet_norm: True
|
||||
use_spatial_softmax: True
|
||||
spatial_softmax_kp: 128
|
||||
|
||||
###################### Action Tokenization
|
||||
stem_process_config:
|
||||
target: unifolm_wma.modules.encoders.condition.SATokenProjector
|
||||
params:
|
||||
dim: 1024
|
||||
depth: 1
|
||||
dim_head: 64
|
||||
heads: 16
|
||||
num_queries: ${model.params.wma_config.params.num_stem_token}
|
||||
output_dim: 1024
|
||||
ff_mult: 4
|
||||
chunk_size: ${model.params.wma_config.params.temporal_length}
|
||||
|
||||
first_stage_config:
|
||||
target: unifolm_wma.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
layer: "penultimate"
|
||||
|
||||
img_cond_stage_config:
|
||||
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
|
||||
params:
|
||||
freeze: true
|
||||
|
||||
image_proj_stage_config:
|
||||
target: unifolm_wma.modules.encoders.resampler.Resampler
|
||||
params:
|
||||
dim: 1024
|
||||
depth: 4
|
||||
dim_head: 64
|
||||
heads: 12
|
||||
num_queries: 16
|
||||
embedding_dim: 1280
|
||||
output_dim: 1024
|
||||
ff_mult: 4
|
||||
video_length: ${model.params.wma_config.params.temporal_length}
|
||||
|
||||
normalization_config:
|
||||
input_shapes:
|
||||
observation.state: ${model.params.wma_config.params.unet_head_config.params.input_dim}
|
||||
input_normalization_modes:
|
||||
observation.state: 'min_max'
|
||||
output_shapes:
|
||||
action: ${model.params.wma_config.params.unet_head_config.params.input_dim}
|
||||
output_normalization_modes:
|
||||
action: 'min_max'
|
||||
|
||||
data:
|
||||
target: unifolm_wma.utils.data.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 8
|
||||
num_workers: 12
|
||||
wrap: False
|
||||
train:
|
||||
target: unifolm_wma.data.wma_data.WMAData
|
||||
params:
|
||||
data_dir: '/path/to/training/dataset/directory'
|
||||
video_length: ${model.params.wma_config.params.temporal_length}
|
||||
frame_stride: 2
|
||||
load_raw_resolution: True
|
||||
resolution: [320, 512]
|
||||
spatial_transform: resize_center_crop
|
||||
crop_resolution: [320, 512]
|
||||
random_fs: False
|
||||
cond_robot_label_prob: 0.0
|
||||
normalization_mode: 'min_max'
|
||||
individual_normalization: True
|
||||
n_obs_steps: ${model.params.n_obs_steps_imagen}
|
||||
max_action_dim: ${model.params.agent_action_dim}
|
||||
max_state_dim: ${model.params.agent_state_dim}
|
||||
dataset_and_weights:
|
||||
unitree_z1_stackbox: 0.2
|
||||
unitree_z1_dual_arm_stackbox: 0.2
|
||||
unitree_z1_dual_arm_stackbox_v2: 0.2
|
||||
unitree_z1_dual_arm_cleanup_pencils: 0.2
|
||||
unitree_g1_pack_camera: 0.2
|
||||
|
||||
lightning:
|
||||
precision: 16
|
||||
trainer:
|
||||
benchmark: True
|
||||
accumulate_grad_batches: 2
|
||||
max_steps: 300000
|
||||
log_every_n_steps: 50
|
||||
val_check_interval: 1.0
|
||||
gradient_clip_algorithm: 'norm'
|
||||
gradient_clip_val: 0.5
|
||||
enable_model_summary: False
|
||||
callbacks:
|
||||
model_checkpoint:
|
||||
target: pytorch_lightning.callbacks.ModelCheckpoint
|
||||
params:
|
||||
every_n_train_steps: 1000
|
||||
filename: "{epoch}-{step}"
|
||||
save_weights_only: True
|
||||
metrics_over_trainsteps_checkpoint:
|
||||
target: pytorch_lightning.callbacks.ModelCheckpoint
|
||||
params:
|
||||
filename: '{epoch}-{step}'
|
||||
save_weights_only: True
|
||||
every_n_train_steps: 10000
|
||||
batch_logger:
|
||||
target: unifolm_wma.utils.callbacks.ImageLogger
|
||||
params:
|
||||
batch_frequency: 20000
|
||||
to_local: False
|
||||
max_images: 8
|
||||
log_images_kwargs:
|
||||
ddim_steps: 16
|
||||
unconditional_guidance_scale: 1.0
|
||||
timestep_spacing: uniform_trailing
|
||||
guidance_rescale: 0.7
|
||||
Reference in New Issue
Block a user