VAE 也做 BF16

这个权重不做修改,精度更好
This commit is contained in:
2026-01-18 21:14:55 +08:00
parent e1b029201e
commit a90efc6718
6 changed files with 67 additions and 16 deletions

View File

@@ -649,7 +649,7 @@ def prepare_init_input(start_idx: int,
return data, ori_state_dim, ori_action_dim
def get_latent_z(model, videos: Tensor) -> Tensor:
def get_latent_z(model, videos: Tensor) -> Tensor:
"""
Extracts latent features from a video batch using the model's first-stage encoder.
@@ -661,11 +661,15 @@ def get_latent_z(model, videos: Tensor) -> Tensor:
Tensor: Latent video tensor of shape [B, C, T, H, W].
"""
profiler = get_profiler()
with profiler.profile_section("get_latent_z/encode"):
b, c, t, h, w = videos.shape
x = rearrange(videos, 'b c t h w -> (b t) c h w')
z = model.encode_first_stage(x)
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
with profiler.profile_section("get_latent_z/encode"):
b, c, t, h, w = videos.shape
x = rearrange(videos, 'b c t h w -> (b t) c h w')
vae_ctx = nullcontext()
if getattr(model, "vae_bf16", False) and model.device.type == "cuda":
vae_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
with vae_ctx:
z = model.encode_first_stage(x)
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
return z
@@ -879,9 +883,18 @@ def image_guided_synthesis_sim_mode(
# Reconstruct from latent to pixel space
with profiler.profile_section("synthesis/decode_first_stage"):
if samples.dtype != torch.float32:
samples = samples.float()
batch_images = model.decode_first_stage(samples)
if getattr(model, "vae_bf16", False):
if samples.dtype != torch.bfloat16:
samples = samples.to(dtype=torch.bfloat16)
vae_ctx = nullcontext()
if model.device.type == "cuda":
vae_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
with vae_ctx:
batch_images = model.decode_first_stage(samples)
else:
if samples.dtype != torch.float32:
samples = samples.float()
batch_images = model.decode_first_stage(samples)
batch_variants = batch_images
return batch_variants, actions, states
@@ -944,6 +957,14 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
diffusion_autocast_dtype = torch.bfloat16
print(">>> diffusion backbone set to bfloat16")
if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
if args.vae_dtype == "bf16":
model.first_stage_model.to(dtype=torch.bfloat16)
else:
model.first_stage_model.to(dtype=torch.float32)
model.vae_bf16 = args.vae_dtype == "bf16"
print(f">>> VAE dtype set to {args.vae_dtype}")
encoder_mode = args.encoder_mode
encoder_bf16 = encoder_mode in ("autocast", "bf16_full")
encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32
@@ -957,9 +978,21 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
f">>> encoder mode set to {encoder_mode} (weights={encoder_weight_dtype})"
)
projector_mode = args.projector_mode
projector_bf16 = projector_mode in ("autocast", "bf16_full")
projector_weight_dtype = torch.bfloat16 if projector_mode == "bf16_full" else torch.float32
if hasattr(model, "image_proj_model") and model.image_proj_model is not None:
model.image_proj_model.to(dtype=projector_weight_dtype)
if hasattr(model, "state_projector") and model.state_projector is not None:
model.state_projector.to(dtype=projector_weight_dtype)
if hasattr(model, "action_projector") and model.action_projector is not None:
model.action_projector.to(dtype=projector_weight_dtype)
if hasattr(model, "projector_bf16"):
model.projector_bf16 = args.projector_dtype == "bf16"
print(f">>> projector dtype set to {args.projector_dtype}")
model.projector_bf16 = projector_bf16
model.projector_mode = projector_mode
print(
f">>> projector mode set to {projector_mode} (weights={projector_weight_dtype})"
)
log_inference_precision(model)
@@ -1305,11 +1338,14 @@ def get_parser():
help="Dtype for diffusion backbone weights and sampling autocast."
)
parser.add_argument(
"--projector_dtype",
"--projector_mode",
type=str,
choices=["fp32", "bf16"],
choices=["fp32", "autocast", "bf16_full"],
default="fp32",
help="Dtype for image/state/action projectors (autocast in forward)."
help=
"Projector precision mode for image/state/action projectors: "
"fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, "
"bf16_full=bf16 weights + bf16 forward."
)
parser.add_argument(
"--encoder_mode",
@@ -1321,6 +1357,13 @@ def get_parser():
"fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, "
"bf16_full=bf16 weights + bf16 forward."
)
parser.add_argument(
"--vae_dtype",
type=str,
choices=["fp32", "bf16"],
default="fp32",
help="Dtype for VAE/first_stage_model weights and forward autocast."
)
parser.add_argument(
"--n_action_steps",
type=int,