VAE 也做 BF16
这个权重不做修改，精度更好
This commit is contained in:
@@ -664,7 +664,11 @@ def get_latent_z(model, videos: Tensor) -> Tensor:
|
|||||||
with profiler.profile_section("get_latent_z/encode"):
|
with profiler.profile_section("get_latent_z/encode"):
|
||||||
b, c, t, h, w = videos.shape
|
b, c, t, h, w = videos.shape
|
||||||
x = rearrange(videos, 'b c t h w -> (b t) c h w')
|
x = rearrange(videos, 'b c t h w -> (b t) c h w')
|
||||||
z = model.encode_first_stage(x)
|
vae_ctx = nullcontext()
|
||||||
|
if getattr(model, "vae_bf16", False) and model.device.type == "cuda":
|
||||||
|
vae_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
|
||||||
|
with vae_ctx:
|
||||||
|
z = model.encode_first_stage(x)
|
||||||
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
|
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
|
||||||
return z
|
return z
|
||||||
|
|
||||||
@@ -879,9 +883,18 @@ def image_guided_synthesis_sim_mode(
|
|||||||
|
|
||||||
# Reconstruct from latent to pixel space
|
# Reconstruct from latent to pixel space
|
||||||
with profiler.profile_section("synthesis/decode_first_stage"):
|
with profiler.profile_section("synthesis/decode_first_stage"):
|
||||||
if samples.dtype != torch.float32:
|
if getattr(model, "vae_bf16", False):
|
||||||
samples = samples.float()
|
if samples.dtype != torch.bfloat16:
|
||||||
batch_images = model.decode_first_stage(samples)
|
samples = samples.to(dtype=torch.bfloat16)
|
||||||
|
vae_ctx = nullcontext()
|
||||||
|
if model.device.type == "cuda":
|
||||||
|
vae_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
|
||||||
|
with vae_ctx:
|
||||||
|
batch_images = model.decode_first_stage(samples)
|
||||||
|
else:
|
||||||
|
if samples.dtype != torch.float32:
|
||||||
|
samples = samples.float()
|
||||||
|
batch_images = model.decode_first_stage(samples)
|
||||||
batch_variants = batch_images
|
batch_variants = batch_images
|
||||||
|
|
||||||
return batch_variants, actions, states
|
return batch_variants, actions, states
|
||||||
@@ -944,6 +957,14 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
diffusion_autocast_dtype = torch.bfloat16
|
diffusion_autocast_dtype = torch.bfloat16
|
||||||
print(">>> diffusion backbone set to bfloat16")
|
print(">>> diffusion backbone set to bfloat16")
|
||||||
|
|
||||||
|
if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
|
||||||
|
if args.vae_dtype == "bf16":
|
||||||
|
model.first_stage_model.to(dtype=torch.bfloat16)
|
||||||
|
else:
|
||||||
|
model.first_stage_model.to(dtype=torch.float32)
|
||||||
|
model.vae_bf16 = args.vae_dtype == "bf16"
|
||||||
|
print(f">>> VAE dtype set to {args.vae_dtype}")
|
||||||
|
|
||||||
encoder_mode = args.encoder_mode
|
encoder_mode = args.encoder_mode
|
||||||
encoder_bf16 = encoder_mode in ("autocast", "bf16_full")
|
encoder_bf16 = encoder_mode in ("autocast", "bf16_full")
|
||||||
encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32
|
encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32
|
||||||
@@ -957,9 +978,21 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
f">>> encoder mode set to {encoder_mode} (weights={encoder_weight_dtype})"
|
f">>> encoder mode set to {encoder_mode} (weights={encoder_weight_dtype})"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
projector_mode = args.projector_mode
|
||||||
|
projector_bf16 = projector_mode in ("autocast", "bf16_full")
|
||||||
|
projector_weight_dtype = torch.bfloat16 if projector_mode == "bf16_full" else torch.float32
|
||||||
|
if hasattr(model, "image_proj_model") and model.image_proj_model is not None:
|
||||||
|
model.image_proj_model.to(dtype=projector_weight_dtype)
|
||||||
|
if hasattr(model, "state_projector") and model.state_projector is not None:
|
||||||
|
model.state_projector.to(dtype=projector_weight_dtype)
|
||||||
|
if hasattr(model, "action_projector") and model.action_projector is not None:
|
||||||
|
model.action_projector.to(dtype=projector_weight_dtype)
|
||||||
if hasattr(model, "projector_bf16"):
|
if hasattr(model, "projector_bf16"):
|
||||||
model.projector_bf16 = args.projector_dtype == "bf16"
|
model.projector_bf16 = projector_bf16
|
||||||
print(f">>> projector dtype set to {args.projector_dtype}")
|
model.projector_mode = projector_mode
|
||||||
|
print(
|
||||||
|
f">>> projector mode set to {projector_mode} (weights={projector_weight_dtype})"
|
||||||
|
)
|
||||||
|
|
||||||
log_inference_precision(model)
|
log_inference_precision(model)
|
||||||
|
|
||||||
@@ -1305,11 +1338,14 @@ def get_parser():
|
|||||||
help="Dtype for diffusion backbone weights and sampling autocast."
|
help="Dtype for diffusion backbone weights and sampling autocast."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--projector_dtype",
|
"--projector_mode",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["fp32", "bf16"],
|
choices=["fp32", "autocast", "bf16_full"],
|
||||||
default="fp32",
|
default="fp32",
|
||||||
help="Dtype for image/state/action projectors (autocast in forward)."
|
help=
|
||||||
|
"Projector precision mode for image/state/action projectors: "
|
||||||
|
"fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, "
|
||||||
|
"bf16_full=bf16 weights + bf16 forward."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--encoder_mode",
|
"--encoder_mode",
|
||||||
@@ -1321,6 +1357,13 @@ def get_parser():
|
|||||||
"fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, "
|
"fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, "
|
||||||
"bf16_full=bf16 weights + bf16 forward."
|
"bf16_full=bf16 weights + bf16 forward."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vae_dtype",
|
||||||
|
type=str,
|
||||||
|
choices=["fp32", "bf16"],
|
||||||
|
default="fp32",
|
||||||
|
help="Dtype for VAE/first_stage_model weights and forward autocast."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--n_action_steps",
|
"--n_action_steps",
|
||||||
type=int,
|
type=int,
|
||||||
|
|||||||
@@ -2032,6 +2032,13 @@ class LatentVisualDiffusion(LatentDiffusion):
|
|||||||
target_dtype: torch.dtype | None) -> Tensor:
|
target_dtype: torch.dtype | None) -> Tensor:
|
||||||
use_bf16 = (self.projector_bf16 and x.device.type == "cuda"
|
use_bf16 = (self.projector_bf16 and x.device.type == "cuda"
|
||||||
and torch.cuda.is_bf16_supported())
|
and torch.cuda.is_bf16_supported())
|
||||||
|
if not use_bf16:
|
||||||
|
weight_dtype = None
|
||||||
|
for param in projector.parameters():
|
||||||
|
weight_dtype = param.dtype
|
||||||
|
break
|
||||||
|
if weight_dtype is not None and x.dtype != weight_dtype:
|
||||||
|
x = x.to(dtype=weight_dtype)
|
||||||
if use_bf16:
|
if use_bf16:
|
||||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||||
out = projector(x)
|
out = projector(x)
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -22,6 +22,7 @@ dataset="unitree_g1_pack_camera"
|
|||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae \
|
--perframe_ae \
|
||||||
--diffusion_dtype bf16 \
|
--diffusion_dtype bf16 \
|
||||||
--projector_dtype bf16 \
|
--projector_mode autocast \
|
||||||
--encoder_mode autocast #fp32/autocast/bf16_full
|
--encoder_mode bf16_full \
|
||||||
|
--vae_dtype bf16
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
Reference in New Issue
Block a user