对扩散主干做 BF16

量化对象:model.model(扩散 UNet/WMAModel 主体)
This commit is contained in:
2026-01-18 17:14:16 +08:00
parent 7b499284bf
commit 2b634cde90
5 changed files with 141 additions and 75 deletions

View File

@@ -714,22 +714,23 @@ def preprocess_observation(
return return_observations return return_observations
def image_guided_synthesis_sim_mode( def image_guided_synthesis_sim_mode(
model: torch.nn.Module, model: torch.nn.Module,
prompts: list[str], prompts: list[str],
observation: dict, observation: dict,
noise_shape: tuple[int, int, int, int, int], noise_shape: tuple[int, int, int, int, int],
action_cond_step: int = 16, action_cond_step: int = 16,
n_samples: int = 1, n_samples: int = 1,
ddim_steps: int = 50, ddim_steps: int = 50,
ddim_eta: float = 1.0, ddim_eta: float = 1.0,
unconditional_guidance_scale: float = 1.0, unconditional_guidance_scale: float = 1.0,
fs: int | None = None, fs: int | None = None,
text_input: bool = True, text_input: bool = True,
timestep_spacing: str = 'uniform', timestep_spacing: str = 'uniform',
guidance_rescale: float = 0.0, guidance_rescale: float = 0.0,
sim_mode: bool = True, sim_mode: bool = True,
**kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: diffusion_autocast_dtype: Optional[torch.dtype] = None,
**kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text). Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).
@@ -750,9 +751,10 @@ def image_guided_synthesis_sim_mode(
fs (int | None): Frame index to condition on, broadcasted across the batch if specified. Default is None. fs (int | None): Frame index to condition on, broadcasted across the batch if specified. Default is None.
text_input (bool): Whether to use text prompt as conditioning. If False, uses empty strings. Default is True. text_input (bool): Whether to use text prompt as conditioning. If False, uses empty strings. Default is True.
timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace". timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance. guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model. sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
**kwargs: Additional arguments passed to the DDIM sampler. diffusion_autocast_dtype (Optional[torch.dtype]): Autocast dtype for diffusion sampling (e.g., torch.bfloat16).
**kwargs: Additional arguments passed to the DDIM sampler.
Returns: Returns:
batch_variants (torch.Tensor): Predicted pixel-space video frames [B, C, T, H, W]. batch_variants (torch.Tensor): Predicted pixel-space video frames [B, C, T, H, W].
@@ -810,31 +812,37 @@ def image_guided_synthesis_sim_mode(
uc = None uc = None
kwargs.update({"unconditional_conditioning_img_nonetext": None}) kwargs.update({"unconditional_conditioning_img_nonetext": None})
cond_mask = None cond_mask = None
cond_z0 = None cond_z0 = None
if ddim_sampler is not None: if ddim_sampler is not None:
with profiler.profile_section("synthesis/ddim_sampling"): with profiler.profile_section("synthesis/ddim_sampling"):
samples, actions, states, intermedia = ddim_sampler.sample( autocast_ctx = nullcontext()
S=ddim_steps, if diffusion_autocast_dtype is not None and model.device.type == "cuda":
conditioning=cond, autocast_ctx = torch.autocast("cuda", dtype=diffusion_autocast_dtype)
batch_size=batch_size, with autocast_ctx:
shape=noise_shape[1:], samples, actions, states, intermedia = ddim_sampler.sample(
verbose=False, S=ddim_steps,
unconditional_guidance_scale=unconditional_guidance_scale, conditioning=cond,
unconditional_conditioning=uc, batch_size=batch_size,
eta=ddim_eta, shape=noise_shape[1:],
cfg_img=None, verbose=False,
mask=cond_mask, unconditional_guidance_scale=unconditional_guidance_scale,
x0=cond_z0, unconditional_conditioning=uc,
fs=fs, eta=ddim_eta,
timestep_spacing=timestep_spacing, cfg_img=None,
guidance_rescale=guidance_rescale, mask=cond_mask,
**kwargs) x0=cond_z0,
fs=fs,
# Reconstruct from latent to pixel space timestep_spacing=timestep_spacing,
with profiler.profile_section("synthesis/decode_first_stage"): guidance_rescale=guidance_rescale,
batch_images = model.decode_first_stage(samples) **kwargs)
batch_variants = batch_images
# Reconstruct from latent to pixel space
with profiler.profile_section("synthesis/decode_first_stage"):
if samples.dtype != torch.float32:
samples = samples.float()
batch_images = model.decode_first_stage(samples)
batch_variants = batch_images
return batch_variants, actions, states return batch_variants, actions, states
@@ -889,6 +897,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
model = model.cuda(gpu_no) model = model.cuda(gpu_no)
device = get_device_from_parameters(model) device = get_device_from_parameters(model)
diffusion_autocast_dtype = None
if args.diffusion_dtype == "bf16":
with profiler.profile_section("model_loading/diffusion_bf16"):
model.model.to(dtype=torch.bfloat16)
diffusion_autocast_dtype = torch.bfloat16
print(">>> diffusion backbone set to bfloat16")
log_inference_precision(model) log_inference_precision(model)
profiler.record_memory("after_model_load") profiler.record_memory("after_model_load")
@@ -1014,20 +1029,21 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Use world-model in policy to generate action # Use world-model in policy to generate action
print(f'>>> Step {itr}: generating actions ...') print(f'>>> Step {itr}: generating actions ...')
with profiler.profile_section("action_generation"): with profiler.profile_section("action_generation"):
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode( pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
model, model,
sample['instruction'], sample['instruction'],
observation, observation,
noise_shape, noise_shape,
action_cond_step=args.exe_steps, action_cond_step=args.exe_steps,
ddim_steps=args.ddim_steps, ddim_steps=args.ddim_steps,
ddim_eta=args.ddim_eta, ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args. unconditional_guidance_scale=args.
unconditional_guidance_scale, unconditional_guidance_scale,
fs=model_input_fs, fs=model_input_fs,
timestep_spacing=args.timestep_spacing, timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale, guidance_rescale=args.guidance_rescale,
sim_mode=False) sim_mode=False,
diffusion_autocast_dtype=diffusion_autocast_dtype)
# Update future actions in the observation queues # Update future actions in the observation queues
with profiler.profile_section("update_action_queues"): with profiler.profile_section("update_action_queues"):
@@ -1058,20 +1074,21 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Interaction with the world-model # Interaction with the world-model
print(f'>>> Step {itr}: interacting with world model ...') print(f'>>> Step {itr}: interacting with world model ...')
with profiler.profile_section("world_model_interaction"): with profiler.profile_section("world_model_interaction"):
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode( pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
model, model,
"", "",
observation, observation,
noise_shape, noise_shape,
action_cond_step=args.exe_steps, action_cond_step=args.exe_steps,
ddim_steps=args.ddim_steps, ddim_steps=args.ddim_steps,
ddim_eta=args.ddim_eta, ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args. unconditional_guidance_scale=args.
unconditional_guidance_scale, unconditional_guidance_scale,
fs=model_input_fs, fs=model_input_fs,
text_input=False, text_input=False,
timestep_spacing=args.timestep_spacing, timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale) guidance_rescale=args.guidance_rescale,
diffusion_autocast_dtype=diffusion_autocast_dtype)
with profiler.profile_section("update_state_queues"): with profiler.profile_section("update_state_queues"):
for step_idx in range(args.exe_steps): for step_idx in range(args.exe_steps):
@@ -1216,13 +1233,20 @@ def get_parser():
help= help=
"Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)." "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
) )
parser.add_argument( parser.add_argument(
"--perframe_ae", "--perframe_ae",
action='store_true', action='store_true',
default=False, default=False,
help= help=
"Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024." "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
) )
parser.add_argument(
"--diffusion_dtype",
type=str,
choices=["fp32", "bf16"],
default="fp32",
help="Dtype for diffusion backbone weights and sampling autocast."
)
parser.add_argument( parser.add_argument(
"--n_action_steps", "--n_action_steps",
type=int, type=int,

View File

@@ -2,7 +2,7 @@ res_dir="unitree_g1_pack_camera/case1"
dataset="unitree_g1_pack_camera" dataset="unitree_g1_pack_camera"
{ {
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \ time CUDA_VISIBLE_DEVICES=1 python3 scripts/evaluation/world_model_interaction.py \
--seed 123 \ --seed 123 \
--ckpt_path ckpts/unifolm_wma_dual.ckpt \ --ckpt_path ckpts/unifolm_wma_dual.ckpt \
--config configs/inference/world_model_interaction.yaml \ --config configs/inference/world_model_interaction.yaml \
@@ -20,5 +20,6 @@ dataset="unitree_g1_pack_camera"
--n_iter 11 \ --n_iter 11 \
--timestep_spacing 'uniform_trailing' \ --timestep_spacing 'uniform_trailing' \
--guidance_rescale 0.7 \ --guidance_rescale 0.7 \
--perframe_ae --perframe_ae \
--diffusion_dtype bf16
} 2>&1 | tee "${res_dir}/output.log" } 2>&1 | tee "${res_dir}/output.log"

View File

@@ -29,4 +29,45 @@ python3 psnr_score_for_challenge.py --gt_video unitree_g1_pack_camera/case1/unit
1. torch.compile + AMP + TF32 + cudnn.benchmark 1. torch.compile + AMP + TF32 + cudnn.benchmark
2. 排查 .to()/copy/clone 的重复位置并移出循环 2. 排查 .to()/copy/clone 的重复位置并移出循环
3. 若需要更大幅度,再换采样器/降步数 3. 若需要更大幅度,再换采样器/降步数
A100 上我推荐 BF16 优先(稳定性更好、PSNR 更稳),FP16 作为速度优先方案。
下面是“分模块”的 消融方案(从稳到激进):
0)基线
- 全 FP32(你现在就是这个)
1)只对扩散主干做 BF16(最推荐)
- 量化对象:model.model(扩散 UNet/WMAModel 主体)
- 保持 FP32:first_stage_model(VAE 编/解码)、cond_stage_model(文本)、embedder(图像)、image_proj_model
- 预期:PSNR 基本不掉 or 极小波动
2)+ 轻量投影/MLP 做 BF16
- 增加:image_proj_model、state_projector、action_projector
- 预期:几乎不影响 PSNR
3)+ 文本/图像编码做 BF16
- 增加:cond_stage_model、embedder
- 预期:可能有轻微波动,通常仍可接受
4)VAE 也做 BF16(最容易伤 PSNR)
- 增加:first_stage_model
- 预期:画质/PSNR 最敏感,建议最后做消融
———
具体建议(A100):
- 优先 BF16(稳定性好于 FP16)
- 只做半精度,不做 INT 量化:保持 PSNR
- VAE 尽量 FP32(最影响画质的模块)