VAE 也做 BF16

该权重不做修改时精度更好
This commit is contained in:
2026-01-18 21:14:55 +08:00
parent e1b029201e
commit a90efc6718
6 changed files with 67 additions and 16 deletions

View File

@@ -664,6 +664,10 @@ def get_latent_z(model, videos: Tensor) -> Tensor:
with profiler.profile_section("get_latent_z/encode"): with profiler.profile_section("get_latent_z/encode"):
b, c, t, h, w = videos.shape b, c, t, h, w = videos.shape
x = rearrange(videos, 'b c t h w -> (b t) c h w') x = rearrange(videos, 'b c t h w -> (b t) c h w')
vae_ctx = nullcontext()
if getattr(model, "vae_bf16", False) and model.device.type == "cuda":
vae_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
with vae_ctx:
z = model.encode_first_stage(x) z = model.encode_first_stage(x)
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
return z return z
@@ -879,6 +883,15 @@ def image_guided_synthesis_sim_mode(
# Reconstruct from latent to pixel space # Reconstruct from latent to pixel space
with profiler.profile_section("synthesis/decode_first_stage"): with profiler.profile_section("synthesis/decode_first_stage"):
if getattr(model, "vae_bf16", False):
if samples.dtype != torch.bfloat16:
samples = samples.to(dtype=torch.bfloat16)
vae_ctx = nullcontext()
if model.device.type == "cuda":
vae_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
with vae_ctx:
batch_images = model.decode_first_stage(samples)
else:
if samples.dtype != torch.float32: if samples.dtype != torch.float32:
samples = samples.float() samples = samples.float()
batch_images = model.decode_first_stage(samples) batch_images = model.decode_first_stage(samples)
@@ -944,6 +957,14 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
diffusion_autocast_dtype = torch.bfloat16 diffusion_autocast_dtype = torch.bfloat16
print(">>> diffusion backbone set to bfloat16") print(">>> diffusion backbone set to bfloat16")
if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
if args.vae_dtype == "bf16":
model.first_stage_model.to(dtype=torch.bfloat16)
else:
model.first_stage_model.to(dtype=torch.float32)
model.vae_bf16 = args.vae_dtype == "bf16"
print(f">>> VAE dtype set to {args.vae_dtype}")
encoder_mode = args.encoder_mode encoder_mode = args.encoder_mode
encoder_bf16 = encoder_mode in ("autocast", "bf16_full") encoder_bf16 = encoder_mode in ("autocast", "bf16_full")
encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32 encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32
@@ -957,9 +978,21 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
f">>> encoder mode set to {encoder_mode} (weights={encoder_weight_dtype})" f">>> encoder mode set to {encoder_mode} (weights={encoder_weight_dtype})"
) )
projector_mode = args.projector_mode
projector_bf16 = projector_mode in ("autocast", "bf16_full")
projector_weight_dtype = torch.bfloat16 if projector_mode == "bf16_full" else torch.float32
if hasattr(model, "image_proj_model") and model.image_proj_model is not None:
model.image_proj_model.to(dtype=projector_weight_dtype)
if hasattr(model, "state_projector") and model.state_projector is not None:
model.state_projector.to(dtype=projector_weight_dtype)
if hasattr(model, "action_projector") and model.action_projector is not None:
model.action_projector.to(dtype=projector_weight_dtype)
if hasattr(model, "projector_bf16"): if hasattr(model, "projector_bf16"):
model.projector_bf16 = args.projector_dtype == "bf16" model.projector_bf16 = projector_bf16
print(f">>> projector dtype set to {args.projector_dtype}") model.projector_mode = projector_mode
print(
f">>> projector mode set to {projector_mode} (weights={projector_weight_dtype})"
)
log_inference_precision(model) log_inference_precision(model)
@@ -1305,11 +1338,14 @@ def get_parser():
help="Dtype for diffusion backbone weights and sampling autocast." help="Dtype for diffusion backbone weights and sampling autocast."
) )
parser.add_argument( parser.add_argument(
"--projector_dtype", "--projector_mode",
type=str, type=str,
choices=["fp32", "bf16"], choices=["fp32", "autocast", "bf16_full"],
default="fp32", default="fp32",
help="Dtype for image/state/action projectors (autocast in forward)." help=
"Projector precision mode for image/state/action projectors: "
"fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, "
"bf16_full=bf16 weights + bf16 forward."
) )
parser.add_argument( parser.add_argument(
"--encoder_mode", "--encoder_mode",
@@ -1321,6 +1357,13 @@ def get_parser():
"fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, " "fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, "
"bf16_full=bf16 weights + bf16 forward." "bf16_full=bf16 weights + bf16 forward."
) )
parser.add_argument(
"--vae_dtype",
type=str,
choices=["fp32", "bf16"],
default="fp32",
help="Dtype for VAE/first_stage_model weights and forward autocast."
)
parser.add_argument( parser.add_argument(
"--n_action_steps", "--n_action_steps",
type=int, type=int,

View File

@@ -2032,6 +2032,13 @@ class LatentVisualDiffusion(LatentDiffusion):
target_dtype: torch.dtype | None) -> Tensor: target_dtype: torch.dtype | None) -> Tensor:
use_bf16 = (self.projector_bf16 and x.device.type == "cuda" use_bf16 = (self.projector_bf16 and x.device.type == "cuda"
and torch.cuda.is_bf16_supported()) and torch.cuda.is_bf16_supported())
if not use_bf16:
weight_dtype = None
for param in projector.parameters():
weight_dtype = param.dtype
break
if weight_dtype is not None and x.dtype != weight_dtype:
x = x.to(dtype=weight_dtype)
if use_bf16: if use_bf16:
with torch.autocast(device_type="cuda", dtype=torch.bfloat16): with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
out = projector(x) out = projector(x)

View File

@@ -22,6 +22,7 @@ dataset="unitree_g1_pack_camera"
--guidance_rescale 0.7 \ --guidance_rescale 0.7 \
--perframe_ae \ --perframe_ae \
--diffusion_dtype bf16 \ --diffusion_dtype bf16 \
--projector_dtype bf16 \ --projector_mode autocast \
--encoder_mode autocast #fp32/autocast/bf16_full --encoder_mode bf16_full \
--vae_dtype bf16
} 2>&1 | tee "${res_dir}/output.log" } 2>&1 | tee "${res_dir}/output.log"