diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py index 92025b8..27fcd2d 100644 --- a/scripts/evaluation/world_model_interaction.py +++ b/scripts/evaluation/world_model_interaction.py @@ -441,57 +441,143 @@ def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]: return file_list -def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module: - """Load model weights from checkpoint file. - - Args: - model (nn.Module): Model instance. - ckpt (str): Path to the checkpoint file. - - Returns: - nn.Module: Model with loaded weights. - """ - state_dict = torch.load(ckpt, map_location="cpu") - if "state_dict" in list(state_dict.keys()): - state_dict = state_dict["state_dict"] - try: - model.load_state_dict(state_dict, strict=True) - except: - new_pl_sd = OrderedDict() - for k, v in state_dict.items(): - new_pl_sd[k] = v - - for k in list(new_pl_sd.keys()): - if "framestride_embed" in k: - new_key = k.replace("framestride_embed", "fps_embedding") - new_pl_sd[new_key] = new_pl_sd[k] - del new_pl_sd[k] - model.load_state_dict(new_pl_sd, strict=True) - else: - new_pl_sd = OrderedDict() - for key in state_dict['module'].keys(): - new_pl_sd[key[16:]] = state_dict['module'][key] - model.load_state_dict(new_pl_sd) +def _load_state_dict(model: nn.Module, + state_dict: Mapping[str, torch.Tensor], + strict: bool = True, + assign: bool = False) -> None: + if assign: + try: + model.load_state_dict(state_dict, strict=strict, assign=True) + return + except TypeError: + warnings.warn( + "load_state_dict(assign=True) not supported; " + "falling back to copy load.") + model.load_state_dict(state_dict, strict=strict) + + +def load_model_checkpoint(model: nn.Module, + ckpt: str, + assign: bool | None = None) -> nn.Module: + """Load model weights from checkpoint file. + + Args: + model (nn.Module): Model instance. + ckpt (str): Path to the checkpoint file. + assign (bool | None): Whether to preserve checkpoint tensor dtypes + via load_state_dict(assign=True). If None, auto-enable when a + casted checkpoint metadata is detected. + + Returns: + nn.Module: Model with loaded weights. + """ + ckpt_data = torch.load(ckpt, map_location="cpu") + use_assign = False + if assign is not None: + use_assign = assign + elif isinstance(ckpt_data, Mapping) and "precision_metadata" in ckpt_data: + use_assign = True + if isinstance(ckpt_data, Mapping) and "state_dict" in ckpt_data: + state_dict = ckpt_data["state_dict"] + try: + _load_state_dict(model, state_dict, strict=True, assign=use_assign) + except Exception: + new_pl_sd = OrderedDict() + for k, v in state_dict.items(): + new_pl_sd[k] = v + + for k in list(new_pl_sd.keys()): + if "framestride_embed" in k: + new_key = k.replace("framestride_embed", "fps_embedding") + new_pl_sd[new_key] = new_pl_sd[k] + del new_pl_sd[k] + _load_state_dict(model, + new_pl_sd, + strict=True, + assign=use_assign) + elif isinstance(ckpt_data, Mapping) and "module" in ckpt_data: + new_pl_sd = OrderedDict() + for key in ckpt_data['module'].keys(): + new_pl_sd[key[16:]] = ckpt_data['module'][key] + _load_state_dict(model, new_pl_sd, strict=True, assign=use_assign) + else: + _load_state_dict(model, + ckpt_data, + strict=True, + assign=use_assign) print('>>> model checkpoint loaded.') return model +def maybe_cast_module(module: nn.Module | None, + dtype: torch.dtype, + label: str, + profiler: Optional[ProfilerManager] = None, + profile_name: Optional[str] = None) -> None: + if module is None: + return + try: + param = next(module.parameters()) + except StopIteration: + print(f">>> {label} has no parameters; skip cast") + return + if param.dtype == dtype: + print(f">>> {label} already {dtype}; skip cast") + return + ctx = nullcontext() + if profiler is not None and profile_name: + ctx = profiler.profile_section(profile_name) + with ctx: + module.to(dtype=dtype) + print(f">>> {label} cast to {dtype}") + + +def save_casted_checkpoint(model: nn.Module, + save_path: str, + metadata: Optional[Dict[str, Any]] = None) -> None: + if not save_path: + return + save_dir = os.path.dirname(save_path) + if save_dir: + os.makedirs(save_dir, exist_ok=True) + cpu_state = {} + for key, value in model.state_dict().items(): + if isinstance(value, torch.Tensor): + cpu_state[key] = value.detach().to("cpu") + else: + cpu_state[key] = value + payload: Dict[str, Any] = {"state_dict": cpu_state} + if metadata: + payload["precision_metadata"] = metadata + torch.save(payload, save_path) + print(f">>> Saved casted checkpoint to {save_path}") + + def _module_param_dtype(module: nn.Module | None) -> str: if module is None: return "None" + dtype_counts: Dict[str, int] = {} for param in module.parameters(): - return str(param.dtype) - return "no_params" + dtype_key = str(param.dtype) + dtype_counts[dtype_key] = dtype_counts.get(dtype_key, 0) + param.numel() + if not dtype_counts: + return "no_params" + if len(dtype_counts) == 1: + return next(iter(dtype_counts)) + total = sum(dtype_counts.values()) + parts = [] + for dtype_key in sorted(dtype_counts.keys()): + ratio = dtype_counts[dtype_key] / total + parts.append(f"{dtype_key}={ratio:.1%}") + return f"mixed({', '.join(parts)})" def log_inference_precision(model: nn.Module) -> None: - try: - param = next(model.parameters()) + device = "unknown" + for param in model.parameters(): device = str(param.device) - model_dtype = str(param.dtype) - except StopIteration: - device = "unknown" - model_dtype = "no_params" + break + model_dtype = _module_param_dtype(model) print(f">>> inference precision: model={model_dtype}, device={device}") for attr in [ @@ -966,16 +1052,25 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: diffusion_autocast_dtype = None if args.diffusion_dtype == "bf16": - with profiler.profile_section("model_loading/diffusion_bf16"): - model.model.to(dtype=torch.bfloat16) + maybe_cast_module( + model.model, + torch.bfloat16, + "diffusion backbone", + profiler=profiler, + profile_name="model_loading/diffusion_bf16", + ) diffusion_autocast_dtype = torch.bfloat16 print(">>> diffusion backbone set to bfloat16") if hasattr(model, "first_stage_model") and model.first_stage_model is not None: - if args.vae_dtype == "bf16": - model.first_stage_model.to(dtype=torch.bfloat16) - else: - model.first_stage_model.to(dtype=torch.float32) + vae_weight_dtype = torch.bfloat16 if args.vae_dtype == "bf16" else torch.float32 + maybe_cast_module( + model.first_stage_model, + vae_weight_dtype, + "VAE", + profiler=profiler, + profile_name="model_loading/vae_cast", + ) model.vae_bf16 = args.vae_dtype == "bf16" print(f">>> VAE dtype set to {args.vae_dtype}") @@ -983,9 +1078,21 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: encoder_bf16 = encoder_mode in ("autocast", "bf16_full") encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32 if hasattr(model, "cond_stage_model") and model.cond_stage_model is not None: - model.cond_stage_model.to(dtype=encoder_weight_dtype) + maybe_cast_module( + model.cond_stage_model, + encoder_weight_dtype, + "cond_stage_model", + profiler=profiler, + profile_name="model_loading/encoder_cond_cast", + ) if hasattr(model, "embedder") and model.embedder is not None: - model.embedder.to(dtype=encoder_weight_dtype) + maybe_cast_module( + model.embedder, + encoder_weight_dtype, + "embedder", + profiler=profiler, + profile_name="model_loading/encoder_embedder_cast", + ) model.encoder_bf16 = encoder_bf16 model.encoder_mode = encoder_mode print( @@ -996,11 +1103,29 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: projector_bf16 = projector_mode in ("autocast", "bf16_full") projector_weight_dtype = torch.bfloat16 if projector_mode == "bf16_full" else torch.float32 if hasattr(model, "image_proj_model") and model.image_proj_model is not None: - model.image_proj_model.to(dtype=projector_weight_dtype) + maybe_cast_module( + model.image_proj_model, + projector_weight_dtype, + "image_proj_model", + profiler=profiler, + profile_name="model_loading/projector_image_cast", + ) if hasattr(model, "state_projector") and model.state_projector is not None: - model.state_projector.to(dtype=projector_weight_dtype) + maybe_cast_module( + model.state_projector, + projector_weight_dtype, + "state_projector", + profiler=profiler, + profile_name="model_loading/projector_state_cast", + ) if hasattr(model, "action_projector") and model.action_projector is not None: - model.action_projector.to(dtype=projector_weight_dtype) + maybe_cast_module( + model.action_projector, + projector_weight_dtype, + "action_projector", + profiler=profiler, + profile_name="model_loading/projector_action_cast", + ) if hasattr(model, "projector_bf16"): model.projector_bf16 = projector_bf16 model.projector_mode = projector_mode @@ -1010,6 +1135,19 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: log_inference_precision(model) + if args.export_casted_ckpt: + metadata = { + "diffusion_dtype": args.diffusion_dtype, + "vae_dtype": args.vae_dtype, + "encoder_mode": args.encoder_mode, + "projector_mode": args.projector_mode, + "perframe_ae": args.perframe_ae, + } + save_casted_checkpoint(model, args.export_casted_ckpt, metadata) + if args.export_only: + print(">>> export_only set; skipping inference.") + return + profiler.record_memory("after_model_load") # Run over data @@ -1373,6 +1511,19 @@ def get_parser(): default="fp32", help="Dtype for VAE/first_stage_model weights and forward autocast." ) + parser.add_argument( + "--export_casted_ckpt", + type=str, + default=None, + help= + "Save a checkpoint after applying precision settings (mixed dtypes preserved)." + ) + parser.add_argument( + "--export_only", + action='store_true', + default=False, + help="Exit after exporting the casted checkpoint." + ) parser.add_argument( "--step_log_every", type=int, diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768747270.node-0.516840.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768747270.node-0.516840.0 new file mode 100644 index 0000000..c746d9e Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768747270.node-0.516840.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768751173.node-0.526794.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768751173.node-0.526794.0 new file mode 100644 index 0000000..597db7e Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768751173.node-0.526794.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768786150.node-0.561740.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768786150.node-0.561740.0 new file mode 100644 index 0000000..a210245 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768786150.node-0.561740.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768786449.node-0.562354.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768786449.node-0.562354.0 new file mode 100644 index 0000000..49e3dc6 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768786449.node-0.562354.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768789007.node-0.571347.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768789007.node-0.571347.0 new file mode 100644 index 0000000..44c6e77 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768789007.node-0.571347.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768789039.node-0.571801.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768789039.node-0.571801.0 new file mode 100644 index 0000000..4623fd6 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768789039.node-0.571801.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768790111.node-0.581917.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768790111.node-0.581917.0 new file mode 100644 index 0000000..8f81ac0 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768790111.node-0.581917.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768790598.node-0.586145.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768790598.node-0.586145.0 new file mode 100644 index 0000000..f2ee32a Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768790598.node-0.586145.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768791130.node-0.590214.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768791130.node-0.590214.0 new file mode 100644 index 0000000..86e6dd1 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768791130.node-0.590214.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768791284.node-0.591591.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768791284.node-0.591591.0 new file mode 100644 index 0000000..07d4f27 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768791284.node-0.591591.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768791727.node-0.595872.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768791727.node-0.595872.0 new file mode 100644 index 0000000..14f5000 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768791727.node-0.595872.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768791849.node-0.596965.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768791849.node-0.596965.0 new file mode 100644 index 0000000..d0c0fb0 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768791849.node-0.596965.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768792905.node-0.606250.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768792905.node-0.606250.0 new file mode 100644 index 0000000..cfd343f Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768792905.node-0.606250.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768793755.node-0.612877.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768793755.node-0.612877.0 new file mode 100644 index 0000000..55f0672 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768793755.node-0.612877.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768805201.node-0.634483.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768805201.node-0.634483.0 new file mode 100644 index 0000000..46bddc6 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768805201.node-0.634483.0 differ diff --git a/unitree_g1_pack_camera/case1/run_world_model_interaction.sh b/unitree_g1_pack_camera/case1/run_world_model_interaction.sh index 62784f3..899fae9 100644 --- a/unitree_g1_pack_camera/case1/run_world_model_interaction.sh +++ b/unitree_g1_pack_camera/case1/run_world_model_interaction.sh @@ -4,7 +4,7 @@ dataset="unitree_g1_pack_camera" { time CUDA_VISIBLE_DEVICES=1 python3 scripts/evaluation/world_model_interaction.py \ --seed 123 \ - --ckpt_path ckpts/unifolm_wma_dual.ckpt \ + --ckpt_path ckpts/unifolm_wma_dual_mixbf16.ckpt \ --config configs/inference/world_model_interaction.yaml \ --savedir "${res_dir}/output" \ --bs 1 --height 320 --width 512 \ @@ -22,7 +22,6 @@ dataset="unitree_g1_pack_camera" --guidance_rescale 0.7 \ --perframe_ae \ --diffusion_dtype bf16 \ - --projector_mode autocast \ - --encoder_mode bf16_full \ - --vae_dtype bf16 + --projector_mode bf16_full \ + --encoder_mode bf16_full } 2>&1 | tee "${res_dir}/output.log" diff --git a/unitree_g1_pack_camera/case2/output/tensorboard/events.out.tfevents.1768751353.node-0.527432.0 b/unitree_g1_pack_camera/case2/output/tensorboard/events.out.tfevents.1768751353.node-0.527432.0 new file mode 100644 index 0000000..bdf72dc Binary files /dev/null and b/unitree_g1_pack_camera/case2/output/tensorboard/events.out.tfevents.1768751353.node-0.527432.0 differ diff --git a/unitree_g1_pack_camera/case2/output/tensorboard/events.out.tfevents.1768751560.node-0.528608.0 b/unitree_g1_pack_camera/case2/output/tensorboard/events.out.tfevents.1768751560.node-0.528608.0 new file mode 100644 index 0000000..0f9288d Binary files /dev/null and b/unitree_g1_pack_camera/case2/output/tensorboard/events.out.tfevents.1768751560.node-0.528608.0 differ diff --git a/useful.sh b/useful.sh index 86a04b5..e74ec63 100644 --- a/useful.sh +++ b/useful.sh @@ -118,4 +118,44 @@ embedder: 2. attention mask 缓存到 GPU(已完成) - _get_attn_mask_aa 现在直接在目标 device 构造并缓存,避免每步 CPU→GPU 拷贝 - - 位置:src/unifolm_wma/modules/attention.py \ No newline at end of file + - 位置:src/unifolm_wma/modules/attention.py + + + + + + --export_casted_ckpt ckpts/unifolm_wma_dual_mixbf16.ckpt \ + --export_only + + + checkpoint or previous run; diffusion backbone and condition embedder are bfloat16 while + projector is fp32, so casts are skipped as expected. The "inference precision" line reflects the + top-level param dtype (possibly fp32 embeddings), explaining the apparent inconsistency with + submodule param dtypes. + + Explaining skip cast due to preserved dtypes + + The skip cast happens because checkpoint loading preserves stored dtypes, so weights already + match the target (bf16 or fp32), making casting unnecessary. The log shows first param dtype for + checks, which can differ from submodules but is expected. For forcing recast, exporting a new + casted checkpoint is needed since repeated casts to the same dtype no-op. + +• “skip cast” 的意思是该模块的权重已经是目标 dtype,所以 maybe_cast_module 直接跳过转换。你这段日志 + 里: + + - diffusion backbone / cond_stage_model / embedder 已经是 torch.bfloat16 + - projector 仍是 torch.float32(因为你设的是 projector_mode=autocast,权重保持 fp32) + - VAE 你设置为 fp32,也已经是 fp32 + + 所以全部都匹配目标精度了,自然会 “skip cast”。这通常说明你加载的 ckpt 本身就是混合精度权重(比如 + 之前用 --export_casted_ckpt 导出的)。 + + 另外这行: + + >>> inference precision: model=torch.float32 + + 只是取了模型第一个参数的 dtype(log_inference_precision 的实现如此),模型里可能还有一些主模块外 + 的 fp32 参数(比如位置编码等),所以看到整体是 fp32 很正常,不代表 diffusion 不在 bf16。后面的子 + 模块打印才更准确。 + + 如果你想确认是否从“原始 fp32 ckpt”重新 cast,只要把 --ckpt_path 指回原始 ckpt 跑一次即可。 \ No newline at end of file