修复混合精度vae相关的配置错误,确保在推理阶段正确使用了混合精度模型,并且导出了正确精度的检查点文件。
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import argparse, os, glob
|
||||
from contextlib import nullcontext
|
||||
import pandas as pd
|
||||
import random
|
||||
import torch
|
||||
@@ -38,6 +39,68 @@ def get_device_from_parameters(module: nn.Module) -> torch.device:
|
||||
return next(iter(module.parameters())).device
|
||||
|
||||
|
||||
def apply_precision_settings(model: nn.Module, args: argparse.Namespace) -> nn.Module:
|
||||
"""Apply precision settings to model components based on command-line arguments.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to apply precision settings to.
|
||||
args (argparse.Namespace): Parsed command-line arguments containing precision settings.
|
||||
|
||||
Returns:
|
||||
nn.Module: Model with precision settings applied.
|
||||
"""
|
||||
print(f">>> Applying precision settings:")
|
||||
print(f" - Diffusion dtype: {args.diffusion_dtype}")
|
||||
print(f" - Projector mode: {args.projector_mode}")
|
||||
print(f" - Encoder mode: {args.encoder_mode}")
|
||||
print(f" - VAE dtype: {args.vae_dtype}")
|
||||
|
||||
# 1. Set Diffusion backbone precision
|
||||
if args.diffusion_dtype == "bf16":
|
||||
# Convert diffusion model weights to bf16
|
||||
model.model.to(torch.bfloat16)
|
||||
model.diffusion_autocast_dtype = torch.bfloat16
|
||||
print(" ✓ Diffusion model weights converted to bfloat16")
|
||||
else:
|
||||
model.diffusion_autocast_dtype = None
|
||||
print(" ✓ Diffusion model using fp32")
|
||||
|
||||
# 2. Set Projector precision
|
||||
if args.projector_mode == "bf16_full":
|
||||
model.state_projector.to(torch.bfloat16)
|
||||
model.action_projector.to(torch.bfloat16)
|
||||
model.projector_autocast_dtype = None
|
||||
print(" ✓ Projectors converted to bfloat16")
|
||||
elif args.projector_mode == "autocast":
|
||||
model.projector_autocast_dtype = torch.bfloat16
|
||||
print(" ✓ Projectors will use autocast (weights fp32, compute bf16)")
|
||||
else:
|
||||
model.projector_autocast_dtype = None
|
||||
# fp32 mode: do nothing, keep original precision
|
||||
|
||||
# 3. Set Encoder precision
|
||||
if args.encoder_mode == "bf16_full":
|
||||
model.embedder.to(torch.bfloat16)
|
||||
model.image_proj_model.to(torch.bfloat16)
|
||||
model.encoder_autocast_dtype = None
|
||||
print(" ✓ Encoders converted to bfloat16")
|
||||
elif args.encoder_mode == "autocast":
|
||||
model.encoder_autocast_dtype = torch.bfloat16
|
||||
print(" ✓ Encoders will use autocast (weights fp32, compute bf16)")
|
||||
else:
|
||||
model.encoder_autocast_dtype = None
|
||||
# fp32 mode: do nothing, keep original precision
|
||||
|
||||
# 4. Set VAE precision
|
||||
if args.vae_dtype == "bf16":
|
||||
model.first_stage_model.to(torch.bfloat16)
|
||||
print(" ✓ VAE converted to bfloat16")
|
||||
else:
|
||||
print(" ✓ VAE kept in fp32 for best quality")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
|
||||
"""Save a list of frames to a video file.
|
||||
|
||||
@@ -262,6 +325,11 @@ def get_latent_z(model, videos: Tensor) -> Tensor:
|
||||
"""
|
||||
b, c, t, h, w = videos.shape
|
||||
x = rearrange(videos, 'b c t h w -> (b t) c h w')
|
||||
|
||||
# Auto-detect VAE dtype and convert input
|
||||
vae_dtype = next(model.first_stage_model.parameters()).dtype
|
||||
x = x.to(dtype=vae_dtype)
|
||||
|
||||
z = model.encode_first_stage(x)
|
||||
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
|
||||
return z
|
||||
@@ -363,10 +431,22 @@ def image_guided_synthesis_sim_mode(
|
||||
|
||||
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
||||
|
||||
# Auto-detect model dtype and convert inputs accordingly
|
||||
model_dtype = next(model.embedder.parameters()).dtype
|
||||
|
||||
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
||||
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
|
||||
cond_img_emb = model.embedder(cond_img)
|
||||
cond_img_emb = model.image_proj_model(cond_img_emb)
|
||||
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=model_dtype)
|
||||
|
||||
# Encoder autocast: weights stay fp32, compute in bf16
|
||||
enc_ac_dtype = getattr(model, 'encoder_autocast_dtype', None)
|
||||
if enc_ac_dtype is not None and model.device.type == 'cuda':
|
||||
enc_ctx = torch.autocast('cuda', dtype=enc_ac_dtype)
|
||||
else:
|
||||
enc_ctx = nullcontext()
|
||||
|
||||
with enc_ctx:
|
||||
cond_img_emb = model.embedder(cond_img)
|
||||
cond_img_emb = model.image_proj_model(cond_img_emb)
|
||||
|
||||
if model.model.conditioning_key == 'hybrid':
|
||||
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
|
||||
@@ -380,11 +460,22 @@ def image_guided_synthesis_sim_mode(
|
||||
prompts = [""] * batch_size
|
||||
cond_ins_emb = model.get_learned_conditioning(prompts)
|
||||
|
||||
cond_state_emb = model.state_projector(observation['observation.state'])
|
||||
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
|
||||
# Auto-detect projector dtype and convert inputs
|
||||
projector_dtype = next(model.state_projector.parameters()).dtype
|
||||
|
||||
cond_action_emb = model.action_projector(observation['action'])
|
||||
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
|
||||
# Projector autocast: weights stay fp32, compute in bf16
|
||||
proj_ac_dtype = getattr(model, 'projector_autocast_dtype', None)
|
||||
if proj_ac_dtype is not None and model.device.type == 'cuda':
|
||||
proj_ctx = torch.autocast('cuda', dtype=proj_ac_dtype)
|
||||
else:
|
||||
proj_ctx = nullcontext()
|
||||
|
||||
with proj_ctx:
|
||||
cond_state_emb = model.state_projector(observation['observation.state'].to(dtype=projector_dtype))
|
||||
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
|
||||
|
||||
cond_action_emb = model.action_projector(observation['action'].to(dtype=projector_dtype))
|
||||
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
|
||||
|
||||
if not sim_mode:
|
||||
cond_action_emb = torch.zeros_like(cond_action_emb)
|
||||
@@ -406,8 +497,17 @@ def image_guided_synthesis_sim_mode(
|
||||
kwargs.update({"unconditional_conditioning_img_nonetext": None})
|
||||
cond_mask = None
|
||||
cond_z0 = None
|
||||
|
||||
# Setup autocast context for diffusion sampling
|
||||
autocast_dtype = getattr(model, 'diffusion_autocast_dtype', None)
|
||||
if autocast_dtype is not None and model.device.type == 'cuda':
|
||||
autocast_ctx = torch.autocast('cuda', dtype=autocast_dtype)
|
||||
else:
|
||||
autocast_ctx = nullcontext()
|
||||
|
||||
if ddim_sampler is not None:
|
||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||
with autocast_ctx:
|
||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||
S=ddim_steps,
|
||||
conditioning=cond,
|
||||
batch_size=batch_size,
|
||||
@@ -464,6 +564,17 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
model.eval()
|
||||
print(f'>>> Load pre-trained model ...')
|
||||
|
||||
# Apply precision settings before moving to GPU
|
||||
model = apply_precision_settings(model, args)
|
||||
|
||||
# Export precision-converted checkpoint if requested
|
||||
if args.export_precision_ckpt:
|
||||
export_path = args.export_precision_ckpt
|
||||
os.makedirs(os.path.dirname(export_path) or '.', exist_ok=True)
|
||||
torch.save({"state_dict": model.state_dict()}, export_path)
|
||||
print(f">>> Precision-converted checkpoint saved to: {export_path}")
|
||||
return
|
||||
|
||||
# Build unnomalizer
|
||||
logging.info("***** Configing Data *****")
|
||||
data = instantiate_from_config(config.data)
|
||||
@@ -798,6 +909,35 @@ def get_parser():
|
||||
type=int,
|
||||
default=8,
|
||||
help="fps for the saving video")
|
||||
parser.add_argument(
|
||||
"--diffusion_dtype",
|
||||
type=str,
|
||||
choices=["fp32", "bf16"],
|
||||
default="bf16",
|
||||
help="Diffusion backbone precision (fp32/bf16)")
|
||||
parser.add_argument(
|
||||
"--projector_mode",
|
||||
type=str,
|
||||
choices=["fp32", "autocast", "bf16_full"],
|
||||
default="bf16_full",
|
||||
help="Projector precision mode (fp32/autocast/bf16_full)")
|
||||
parser.add_argument(
|
||||
"--encoder_mode",
|
||||
type=str,
|
||||
choices=["fp32", "autocast", "bf16_full"],
|
||||
default="bf16_full",
|
||||
help="Encoder precision mode (fp32/autocast/bf16_full)")
|
||||
parser.add_argument(
|
||||
"--vae_dtype",
|
||||
type=str,
|
||||
choices=["fp32", "bf16"],
|
||||
default="fp32",
|
||||
help="VAE precision (fp32/bf16, most affects image quality)")
|
||||
parser.add_argument(
|
||||
"--export_precision_ckpt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Export precision-converted checkpoint to this path, then exit.")
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
@@ -1105,6 +1105,10 @@ class LatentDiffusion(DDPM):
|
||||
else:
|
||||
reshape_back = False
|
||||
|
||||
# Align input dtype with VAE weights (e.g. fp32 samples → bf16 VAE)
|
||||
vae_dtype = next(self.first_stage_model.parameters()).dtype
|
||||
z = z.to(dtype=vae_dtype)
|
||||
|
||||
if not self.perframe_ae:
|
||||
z = 1. / self.scale_factor * z
|
||||
results = self.first_stage_model.decode(z, **kwargs)
|
||||
@@ -2457,7 +2461,6 @@ class DiffusionWrapper(pl.LightningModule):
|
||||
Returns:
|
||||
Output from the inner diffusion model (tensor or tuple, depending on the model).
|
||||
"""
|
||||
|
||||
if self.conditioning_key is None:
|
||||
out = self.diffusion_model(x, t)
|
||||
elif self.conditioning_key == 'concat':
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
2026-02-08 09:20:29.036523: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 09:20:29.301726: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 09:20:29.656318: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 09:20:29.656367: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 09:20:29.662840: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 09:20:29.718736: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 09:20:29.718991: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
2026-02-08 12:22:55.885867: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 12:22:55.890510: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 12:22:55.938683: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 12:22:55.938759: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 12:22:55.941091: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 12:22:55.952450: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 12:22:55.952933: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 09:20:31.661239: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
2026-02-08 12:22:56.593653: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
[rank: 0] Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
@@ -23,10 +23,19 @@ INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:149: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
>>> Applying precision settings:
|
||||
- Diffusion dtype: bf16
|
||||
- Projector mode: bf16_full
|
||||
- Encoder mode: bf16_full
|
||||
- VAE dtype: bf16
|
||||
✓ Diffusion model weights converted to bfloat16
|
||||
✓ Projectors converted to bfloat16
|
||||
✓ Encoders converted to bfloat16
|
||||
✓ VAE converted to bfloat16
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
@@ -106,7 +115,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
12%|█▎ | 1/8 [01:24<09:53, 84.82s/it]
|
||||
25%|██▌ | 2/8 [02:49<08:26, 84.48s/it]
|
||||
@@ -130,6 +139,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"gt_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4",
|
||||
"pred_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4",
|
||||
"psnr": 44.83864567508593
|
||||
"psnr": 30.44844270035179
|
||||
}
|
||||
@@ -4,7 +4,7 @@ dataset="unitree_z1_dual_arm_cleanup_pencils"
|
||||
{
|
||||
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||
--seed 123 \
|
||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
|
||||
--config configs/inference/world_model_interaction.yaml \
|
||||
--savedir "${res_dir}/output" \
|
||||
--bs 1 --height 320 --width 512 \
|
||||
@@ -20,5 +20,6 @@ dataset="unitree_z1_dual_arm_cleanup_pencils"
|
||||
--n_iter 8 \
|
||||
--timestep_spacing 'uniform_trailing' \
|
||||
--guidance_rescale 0.7 \
|
||||
--perframe_ae
|
||||
--perframe_ae \
|
||||
--vae_dtype bf16
|
||||
} 2>&1 | tee "${res_dir}/output.log"
|
||||
|
||||
Reference in New Issue
Block a user