对扩散主干做 BF16
量化对象:model.model(扩散 UNet/WMAModel 主体)
This commit is contained in:
@@ -729,6 +729,7 @@ def image_guided_synthesis_sim_mode(
|
|||||||
timestep_spacing: str = 'uniform',
|
timestep_spacing: str = 'uniform',
|
||||||
guidance_rescale: float = 0.0,
|
guidance_rescale: float = 0.0,
|
||||||
sim_mode: bool = True,
|
sim_mode: bool = True,
|
||||||
|
diffusion_autocast_dtype: Optional[torch.dtype] = None,
|
||||||
**kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
**kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).
|
Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).
|
||||||
@@ -752,6 +753,7 @@ def image_guided_synthesis_sim_mode(
|
|||||||
timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
|
timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
|
||||||
guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
|
guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
|
||||||
sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
|
sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
|
||||||
|
diffusion_autocast_dtype (Optional[torch.dtype]): Autocast dtype for diffusion sampling (e.g., torch.bfloat16).
|
||||||
**kwargs: Additional arguments passed to the DDIM sampler.
|
**kwargs: Additional arguments passed to the DDIM sampler.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -814,25 +816,31 @@ def image_guided_synthesis_sim_mode(
|
|||||||
|
|
||||||
if ddim_sampler is not None:
|
if ddim_sampler is not None:
|
||||||
with profiler.profile_section("synthesis/ddim_sampling"):
|
with profiler.profile_section("synthesis/ddim_sampling"):
|
||||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
autocast_ctx = nullcontext()
|
||||||
S=ddim_steps,
|
if diffusion_autocast_dtype is not None and model.device.type == "cuda":
|
||||||
conditioning=cond,
|
autocast_ctx = torch.autocast("cuda", dtype=diffusion_autocast_dtype)
|
||||||
batch_size=batch_size,
|
with autocast_ctx:
|
||||||
shape=noise_shape[1:],
|
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||||
verbose=False,
|
S=ddim_steps,
|
||||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
conditioning=cond,
|
||||||
unconditional_conditioning=uc,
|
batch_size=batch_size,
|
||||||
eta=ddim_eta,
|
shape=noise_shape[1:],
|
||||||
cfg_img=None,
|
verbose=False,
|
||||||
mask=cond_mask,
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
x0=cond_z0,
|
unconditional_conditioning=uc,
|
||||||
fs=fs,
|
eta=ddim_eta,
|
||||||
timestep_spacing=timestep_spacing,
|
cfg_img=None,
|
||||||
guidance_rescale=guidance_rescale,
|
mask=cond_mask,
|
||||||
**kwargs)
|
x0=cond_z0,
|
||||||
|
fs=fs,
|
||||||
|
timestep_spacing=timestep_spacing,
|
||||||
|
guidance_rescale=guidance_rescale,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
# Reconstruct from latent to pixel space
|
# Reconstruct from latent to pixel space
|
||||||
with profiler.profile_section("synthesis/decode_first_stage"):
|
with profiler.profile_section("synthesis/decode_first_stage"):
|
||||||
|
if samples.dtype != torch.float32:
|
||||||
|
samples = samples.float()
|
||||||
batch_images = model.decode_first_stage(samples)
|
batch_images = model.decode_first_stage(samples)
|
||||||
batch_variants = batch_images
|
batch_variants = batch_images
|
||||||
|
|
||||||
@@ -889,6 +897,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
model = model.cuda(gpu_no)
|
model = model.cuda(gpu_no)
|
||||||
device = get_device_from_parameters(model)
|
device = get_device_from_parameters(model)
|
||||||
|
|
||||||
|
diffusion_autocast_dtype = None
|
||||||
|
if args.diffusion_dtype == "bf16":
|
||||||
|
with profiler.profile_section("model_loading/diffusion_bf16"):
|
||||||
|
model.model.to(dtype=torch.bfloat16)
|
||||||
|
diffusion_autocast_dtype = torch.bfloat16
|
||||||
|
print(">>> diffusion backbone set to bfloat16")
|
||||||
|
|
||||||
log_inference_precision(model)
|
log_inference_precision(model)
|
||||||
|
|
||||||
profiler.record_memory("after_model_load")
|
profiler.record_memory("after_model_load")
|
||||||
@@ -1027,7 +1042,8 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
fs=model_input_fs,
|
fs=model_input_fs,
|
||||||
timestep_spacing=args.timestep_spacing,
|
timestep_spacing=args.timestep_spacing,
|
||||||
guidance_rescale=args.guidance_rescale,
|
guidance_rescale=args.guidance_rescale,
|
||||||
sim_mode=False)
|
sim_mode=False,
|
||||||
|
diffusion_autocast_dtype=diffusion_autocast_dtype)
|
||||||
|
|
||||||
# Update future actions in the observation queues
|
# Update future actions in the observation queues
|
||||||
with profiler.profile_section("update_action_queues"):
|
with profiler.profile_section("update_action_queues"):
|
||||||
@@ -1071,7 +1087,8 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
fs=model_input_fs,
|
fs=model_input_fs,
|
||||||
text_input=False,
|
text_input=False,
|
||||||
timestep_spacing=args.timestep_spacing,
|
timestep_spacing=args.timestep_spacing,
|
||||||
guidance_rescale=args.guidance_rescale)
|
guidance_rescale=args.guidance_rescale,
|
||||||
|
diffusion_autocast_dtype=diffusion_autocast_dtype)
|
||||||
|
|
||||||
with profiler.profile_section("update_state_queues"):
|
with profiler.profile_section("update_state_queues"):
|
||||||
for step_idx in range(args.exe_steps):
|
for step_idx in range(args.exe_steps):
|
||||||
@@ -1223,6 +1240,13 @@ def get_parser():
|
|||||||
help=
|
help=
|
||||||
"Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
|
"Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--diffusion_dtype",
|
||||||
|
type=str,
|
||||||
|
choices=["fp32", "bf16"],
|
||||||
|
default="fp32",
|
||||||
|
help="Dtype for diffusion backbone weights and sampling autocast."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--n_action_steps",
|
"--n_action_steps",
|
||||||
type=int,
|
type=int,
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
@@ -2,7 +2,7 @@ res_dir="unitree_g1_pack_camera/case1"
|
|||||||
dataset="unitree_g1_pack_camera"
|
dataset="unitree_g1_pack_camera"
|
||||||
|
|
||||||
{
|
{
|
||||||
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
time CUDA_VISIBLE_DEVICES=1 python3 scripts/evaluation/world_model_interaction.py \
|
||||||
--seed 123 \
|
--seed 123 \
|
||||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||||
--config configs/inference/world_model_interaction.yaml \
|
--config configs/inference/world_model_interaction.yaml \
|
||||||
@@ -20,5 +20,6 @@ dataset="unitree_g1_pack_camera"
|
|||||||
--n_iter 11 \
|
--n_iter 11 \
|
||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae
|
--perframe_ae \
|
||||||
|
--diffusion_dtype bf16
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
41
useful.sh
41
useful.sh
@@ -30,3 +30,44 @@ python3 psnr_score_for_challenge.py --gt_video unitree_g1_pack_camera/case1/unit
|
|||||||
1. torch.compile + AMP + TF32 + cudnn.benchmark
|
1. torch.compile + AMP + TF32 + cudnn.benchmark
|
||||||
2. 排查 .to()/copy/clone 的重复位置并移出循环
|
2. 排查 .to()/copy/clone 的重复位置并移出循环
|
||||||
3. 若需要更大幅度,再换采样器/降步数
|
3. 若需要更大幅度,再换采样器/降步数
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
A100 上我推荐 BF16 优先(稳定性更好、PSNR 更稳),FP16 作为速度优先方案。
|
||||||
|
|
||||||
|
下面是“分模块”的 消融方案(从稳到激进):
|
||||||
|
|
||||||
|
0)基线
|
||||||
|
|
||||||
|
- 全 FP32(你现在就是这个)
|
||||||
|
|
||||||
|
1)只对扩散主干做 BF16(最推荐)
|
||||||
|
|
||||||
|
- 量化对象:model.model(扩散 UNet/WMAModel 主体)
|
||||||
|
- 保持 FP32:first_stage_model(VAE 编/解码)、cond_stage_model(文本)、embedder(图像)、image_proj_model
|
||||||
|
- 预期:PSNR 基本不掉 or 极小波动
|
||||||
|
|
||||||
|
2)+ 轻量投影/MLP 做 BF16
|
||||||
|
|
||||||
|
- 增加:image_proj_model、state_projector、action_projector
|
||||||
|
- 预期:几乎不影响 PSNR
|
||||||
|
|
||||||
|
3)+ 文本/图像编码做 BF16
|
||||||
|
|
||||||
|
- 增加:cond_stage_model、embedder
|
||||||
|
- 预期:可能有轻微波动,通常仍可接受
|
||||||
|
|
||||||
|
4)VAE 也做 BF16(最容易伤 PSNR)
|
||||||
|
|
||||||
|
- 增加:first_stage_model
|
||||||
|
- 预期:画质/PSNR 最敏感,建议最后做消融
|
||||||
|
|
||||||
|
———
|
||||||
|
|
||||||
|
具体建议(A100)
|
||||||
|
|
||||||
|
- 优先 BF16:稳定性好于 FP16
|
||||||
|
- 只做半精度,不做 INT 量化:保持 PSNR
|
||||||
|
- VAE 尽量 FP32:最影响画质的模块
|
||||||
|
|||||||
Reference in New Issue
Block a user