ddim.py — torch.float16 → torch.bfloat16,修复 dtype 不匹配
attention.py — 4 处 softmax 都包裹了 torch.amp.autocast('cuda', enabled=False),阻止 autocast 将 bf16 提升到 fp32
This commit is contained in:
@@ -50,6 +50,20 @@ PEAK_BF16_TFLOPS = 61.0
|
|||||||
PEAK_FP32_TFLOPS = 30.5
|
PEAK_FP32_TFLOPS = 30.5
|
||||||
|
|
||||||
|
|
||||||
|
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
|
||||||
|
"""Compile ResBlock._forward in the hottest output_blocks for operator fusion."""
|
||||||
|
from unifolm_wma.modules.networks.wma_model import ResBlock
|
||||||
|
unet = model.model.diffusion_model
|
||||||
|
compiled = 0
|
||||||
|
for idx in hot_indices:
|
||||||
|
block = unet.output_blocks[idx]
|
||||||
|
for layer in block:
|
||||||
|
if isinstance(layer, ResBlock):
|
||||||
|
layer._forward = torch.compile(layer._forward, mode="default")
|
||||||
|
compiled += 1
|
||||||
|
print(f" ✓ torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
|
||||||
|
|
||||||
|
|
||||||
def load_model(args):
|
def load_model(args):
|
||||||
config = OmegaConf.load(args.config)
|
config = OmegaConf.load(args.config)
|
||||||
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
|
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
|
||||||
@@ -62,6 +76,7 @@ def load_model(args):
|
|||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
model.model.to(torch.bfloat16)
|
model.model.to(torch.bfloat16)
|
||||||
|
apply_torch_compile(model)
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|||||||
@@ -135,6 +135,21 @@ def apply_precision_settings(model: nn.Module, args: argparse.Namespace) -> nn.M
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
|
||||||
|
"""Compile ResBlock._forward in the hottest output_blocks for operator fusion."""
|
||||||
|
from unifolm_wma.modules.networks.wma_model import ResBlock
|
||||||
|
unet = model.model.diffusion_model
|
||||||
|
compiled = 0
|
||||||
|
for idx in hot_indices:
|
||||||
|
block = unet.output_blocks[idx]
|
||||||
|
for layer in block:
|
||||||
|
if isinstance(layer, ResBlock):
|
||||||
|
layer._forward = torch.compile(layer._forward, mode="default")
|
||||||
|
compiled += 1
|
||||||
|
print(f" ✓ torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
|
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
|
||||||
"""Save a list of frames to a video file.
|
"""Save a list of frames to a video file.
|
||||||
|
|
||||||
@@ -601,6 +616,9 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
# Apply precision settings before moving to GPU
|
# Apply precision settings before moving to GPU
|
||||||
model = apply_precision_settings(model, args)
|
model = apply_precision_settings(model, args)
|
||||||
|
|
||||||
|
# Compile hot ResBlocks for operator fusion
|
||||||
|
apply_torch_compile(model)
|
||||||
|
|
||||||
# Export precision-converted checkpoint if requested
|
# Export precision-converted checkpoint if requested
|
||||||
if args.export_precision_ckpt:
|
if args.export_precision_ckpt:
|
||||||
export_path = args.export_precision_ckpt
|
export_path = args.export_precision_ckpt
|
||||||
|
|||||||
@@ -209,9 +209,9 @@ class DDIMSampler(object):
|
|||||||
|
|
||||||
if precision is not None:
|
if precision is not None:
|
||||||
if precision == 16:
|
if precision == 16:
|
||||||
img = img.to(dtype=torch.float16)
|
img = img.to(dtype=torch.bfloat16)
|
||||||
action = action.to(dtype=torch.float16)
|
action = action.to(dtype=torch.bfloat16)
|
||||||
state = state.to(dtype=torch.float16)
|
state = state.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
if timesteps is None:
|
if timesteps is None:
|
||||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||||
|
|||||||
@@ -173,6 +173,7 @@ class CrossAttention(nn.Module):
|
|||||||
sim.masked_fill_(~(mask > 0.5), max_neg_value)
|
sim.masked_fill_(~(mask > 0.5), max_neg_value)
|
||||||
|
|
||||||
# attention, what we cannot get enough of
|
# attention, what we cannot get enough of
|
||||||
|
with torch.amp.autocast('cuda', enabled=False):
|
||||||
sim = sim.softmax(dim=-1)
|
sim = sim.softmax(dim=-1)
|
||||||
|
|
||||||
out = torch.einsum('b i j, b j d -> b i d', sim, v)
|
out = torch.einsum('b i j, b j d -> b i d', sim, v)
|
||||||
@@ -190,6 +191,7 @@ class CrossAttention(nn.Module):
|
|||||||
sim_ip = torch.einsum('b i d, b j d -> b i j', q,
|
sim_ip = torch.einsum('b i d, b j d -> b i j', q,
|
||||||
k_ip) * self.scale
|
k_ip) * self.scale
|
||||||
del k_ip
|
del k_ip
|
||||||
|
with torch.amp.autocast('cuda', enabled=False):
|
||||||
sim_ip = sim_ip.softmax(dim=-1)
|
sim_ip = sim_ip.softmax(dim=-1)
|
||||||
out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip)
|
out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip)
|
||||||
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
|
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
|
||||||
@@ -201,6 +203,7 @@ class CrossAttention(nn.Module):
|
|||||||
sim_as = torch.einsum('b i d, b j d -> b i j', q,
|
sim_as = torch.einsum('b i d, b j d -> b i j', q,
|
||||||
k_as) * self.scale
|
k_as) * self.scale
|
||||||
del k_as
|
del k_as
|
||||||
|
with torch.amp.autocast('cuda', enabled=False):
|
||||||
sim_as = sim_as.softmax(dim=-1)
|
sim_as = sim_as.softmax(dim=-1)
|
||||||
out_as = torch.einsum('b i j, b j d -> b i d', sim_as, v_as)
|
out_as = torch.einsum('b i j, b j d -> b i d', sim_as, v_as)
|
||||||
out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
|
out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
|
||||||
@@ -212,6 +215,7 @@ class CrossAttention(nn.Module):
|
|||||||
sim_aa = torch.einsum('b i d, b j d -> b i j', q,
|
sim_aa = torch.einsum('b i d, b j d -> b i j', q,
|
||||||
k_aa) * self.scale
|
k_aa) * self.scale
|
||||||
del k_aa
|
del k_aa
|
||||||
|
with torch.amp.autocast('cuda', enabled=False):
|
||||||
sim_aa = sim_aa.softmax(dim=-1)
|
sim_aa = sim_aa.softmax(dim=-1)
|
||||||
out_aa = torch.einsum('b i j, b j d -> b i d', sim_aa, v_aa)
|
out_aa = torch.einsum('b i j, b j d -> b i d', sim_aa, v_aa)
|
||||||
out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)
|
out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
2026-02-08 15:47:30.035545: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-08 16:49:41.598605: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
2026-02-08 15:47:30.038628: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-08 16:49:41.601687: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
2026-02-08 15:47:30.069635: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-08 16:49:41.632954: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-08 15:47:30.069671: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
2026-02-08 16:49:41.632986: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-08 15:47:30.071534: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-08 16:49:41.634849: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-08 15:47:30.080021: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-08 16:49:41.643134: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
2026-02-08 15:47:30.080300: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
2026-02-08 16:49:41.643414: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-08 15:47:30.746161: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
2026-02-08 16:49:42.320864: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
[rank: 0] Global seed set to 123
|
[rank: 0] Global seed set to 123
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
@@ -23,7 +23,7 @@ INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|||||||
INFO:root:Loaded ViT-H-14 model config.
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:183: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:198: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
>>> model checkpoint loaded.
|
>>> model checkpoint loaded.
|
||||||
>>> Load pre-trained model ...
|
>>> Load pre-trained model ...
|
||||||
@@ -38,6 +38,7 @@ INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|||||||
✓ VAE converted to bfloat16
|
✓ VAE converted to bfloat16
|
||||||
⚠ Found 601 fp32 params, converting to bf16
|
⚠ Found 601 fp32 params, converting to bf16
|
||||||
✓ All parameters converted to bfloat16
|
✓ All parameters converted to bfloat16
|
||||||
|
✓ torch.compile: 3 ResBlocks in output_blocks[5, 8, 9]
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -113,7 +114,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
|
|||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
12%|█▎ | 1/8 [01:15<08:45, 75.10s/it]
|
12%|█▎ | 1/8 [01:15<08:45, 75.10s/it]
|
||||||
25%|██▌ | 2/8 [02:26<07:17, 72.96s/it]
|
25%|██▌ | 2/8 [02:26<07:17, 72.96s/it]
|
||||||
@@ -137,6 +138,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 4: generating actions ...
|
>>> Step 4: generating actions ...
|
||||||
>>> Step 4: interacting with world model ...
|
>>> Step 4: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 5: generating actions ...
|
>>> Step 5: generating actions ...
|
||||||
>>> Step 5: interacting with world model ...
|
>>> Step 5: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4",
|
"gt_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4",
|
||||||
"pred_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4",
|
"pred_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4",
|
||||||
"psnr": 30.24435361473318
|
"psnr": 30.058508734449845
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user