对扩散主干做 BF16
量化对象:model.model(扩散 UNet/WMAModel 主体)
This commit is contained in:
@@ -729,6 +729,7 @@ def image_guided_synthesis_sim_mode(
|
||||
timestep_spacing: str = 'uniform',
|
||||
guidance_rescale: float = 0.0,
|
||||
sim_mode: bool = True,
|
||||
diffusion_autocast_dtype: Optional[torch.dtype] = None,
|
||||
**kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).
|
||||
@@ -752,6 +753,7 @@ def image_guided_synthesis_sim_mode(
|
||||
timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
|
||||
guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
|
||||
sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
|
||||
diffusion_autocast_dtype (Optional[torch.dtype]): Autocast dtype for diffusion sampling (e.g., torch.bfloat16).
|
||||
**kwargs: Additional arguments passed to the DDIM sampler.
|
||||
|
||||
Returns:
|
||||
@@ -814,25 +816,31 @@ def image_guided_synthesis_sim_mode(
|
||||
|
||||
if ddim_sampler is not None:
|
||||
with profiler.profile_section("synthesis/ddim_sampling"):
|
||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||
S=ddim_steps,
|
||||
conditioning=cond,
|
||||
batch_size=batch_size,
|
||||
shape=noise_shape[1:],
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=uc,
|
||||
eta=ddim_eta,
|
||||
cfg_img=None,
|
||||
mask=cond_mask,
|
||||
x0=cond_z0,
|
||||
fs=fs,
|
||||
timestep_spacing=timestep_spacing,
|
||||
guidance_rescale=guidance_rescale,
|
||||
**kwargs)
|
||||
autocast_ctx = nullcontext()
|
||||
if diffusion_autocast_dtype is not None and model.device.type == "cuda":
|
||||
autocast_ctx = torch.autocast("cuda", dtype=diffusion_autocast_dtype)
|
||||
with autocast_ctx:
|
||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||
S=ddim_steps,
|
||||
conditioning=cond,
|
||||
batch_size=batch_size,
|
||||
shape=noise_shape[1:],
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=uc,
|
||||
eta=ddim_eta,
|
||||
cfg_img=None,
|
||||
mask=cond_mask,
|
||||
x0=cond_z0,
|
||||
fs=fs,
|
||||
timestep_spacing=timestep_spacing,
|
||||
guidance_rescale=guidance_rescale,
|
||||
**kwargs)
|
||||
|
||||
# Reconstruct from latent to pixel space
|
||||
with profiler.profile_section("synthesis/decode_first_stage"):
|
||||
if samples.dtype != torch.float32:
|
||||
samples = samples.float()
|
||||
batch_images = model.decode_first_stage(samples)
|
||||
batch_variants = batch_images
|
||||
|
||||
@@ -889,6 +897,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
model = model.cuda(gpu_no)
|
||||
device = get_device_from_parameters(model)
|
||||
|
||||
diffusion_autocast_dtype = None
|
||||
if args.diffusion_dtype == "bf16":
|
||||
with profiler.profile_section("model_loading/diffusion_bf16"):
|
||||
model.model.to(dtype=torch.bfloat16)
|
||||
diffusion_autocast_dtype = torch.bfloat16
|
||||
print(">>> diffusion backbone set to bfloat16")
|
||||
|
||||
log_inference_precision(model)
|
||||
|
||||
profiler.record_memory("after_model_load")
|
||||
@@ -1027,7 +1042,8 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
fs=model_input_fs,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
sim_mode=False)
|
||||
sim_mode=False,
|
||||
diffusion_autocast_dtype=diffusion_autocast_dtype)
|
||||
|
||||
# Update future actions in the observation queues
|
||||
with profiler.profile_section("update_action_queues"):
|
||||
@@ -1071,7 +1087,8 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
fs=model_input_fs,
|
||||
text_input=False,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale)
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
diffusion_autocast_dtype=diffusion_autocast_dtype)
|
||||
|
||||
with profiler.profile_section("update_state_queues"):
|
||||
for step_idx in range(args.exe_steps):
|
||||
@@ -1223,6 +1240,13 @@ def get_parser():
|
||||
help=
|
||||
"Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--diffusion_dtype",
|
||||
type=str,
|
||||
choices=["fp32", "bf16"],
|
||||
default="fp32",
|
||||
help="Dtype for diffusion backbone weights and sampling autocast."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_action_steps",
|
||||
type=int,
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -2,7 +2,7 @@ res_dir="unitree_g1_pack_camera/case1"
|
||||
dataset="unitree_g1_pack_camera"
|
||||
|
||||
{
|
||||
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||
time CUDA_VISIBLE_DEVICES=1 python3 scripts/evaluation/world_model_interaction.py \
|
||||
--seed 123 \
|
||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||
--config configs/inference/world_model_interaction.yaml \
|
||||
@@ -20,5 +20,6 @@ dataset="unitree_g1_pack_camera"
|
||||
--n_iter 11 \
|
||||
--timestep_spacing 'uniform_trailing' \
|
||||
--guidance_rescale 0.7 \
|
||||
--perframe_ae
|
||||
--perframe_ae \
|
||||
--diffusion_dtype bf16
|
||||
} 2>&1 | tee "${res_dir}/output.log"
|
||||
|
||||
41
useful.sh
41
useful.sh
@@ -30,3 +30,44 @@ python3 psnr_score_for_challenge.py --gt_video unitree_g1_pack_camera/case1/unit
|
||||
1. torch.compile + AMP + TF32 + cudnn.benchmark
|
||||
2. 排查 .to()/copy/clone 的重复位置并移出循环
|
||||
3. 若需要更大幅度,再换采样器/降步数
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
A100 上我推荐 BF16 优先(稳定性更好、PSNR 更稳),FP16 作为速度优先方案。
|
||||
|
||||
下面是“分模块”的 消融方案(从稳到激进):
|
||||
|
||||
0)基线
|
||||
|
||||
- 全 FP32(你现在就是这个)
|
||||
|
||||
1)只对扩散主干做 BF16(最推荐)
|
||||
|
||||
- 量化对象:model.model(扩散 UNet/WMAModel 主体)
|
||||
- 保持 FP32:first_stage_model(VAE 编/解码)、cond_stage_model(文本)、embedder(图像)、image_proj_model
|
||||
- 预期:PSNR 基本不掉 or 极小波动
|
||||
|
||||
2)+ 轻量投影/MLP 做 BF16
|
||||
|
||||
- 增加:image_proj_model、state_projector、action_projector
|
||||
- 预期:几乎不影响 PSNR
|
||||
|
||||
3)+ 文本/图像编码做 BF16
|
||||
|
||||
- 增加:cond_stage_model、embedder
|
||||
- 预期:可能有轻微波动,通常仍可接受
|
||||
|
||||
4)VAE 也做 BF16(最容易伤 PSNR)
|
||||
|
||||
- 增加:first_stage_model
|
||||
- 预期:画质/PSNR 最敏感,建议最后做消融
|
||||
|
||||
———
|
||||
|
||||
具体建议(A100)
|
||||
|
||||
- 优先 BF16:稳定性好于 FP16
|
||||
- 只做半精度,不做 INT 量化:保持 PSNR
|
||||
- VAE 尽量 FP32:最影响画质的模块
|
||||
|
||||
Reference in New Issue
Block a user