整理代码

This commit is contained in:
qhy
2026-02-10 12:46:12 +08:00
parent f1f92072e6
commit bb274870c2
3 changed files with 1 additions and 19 deletions

View File

@@ -752,13 +752,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
if hasattr(model, "first_stage_model") and model.first_stage_model is not None: if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
vae = model.first_stage_model vae = model.first_stage_model
# Channels-last memory format: cuDNN uses faster NHWC kernels
if args.vae_channels_last:
vae = vae.to(memory_format=torch.channels_last)
vae._channels_last = True
model.first_stage_model = vae
print(">>> VAE converted to channels_last (NHWC) memory format")
# torch.compile: fuses GroupNorm+SiLU, conv chains, etc. # torch.compile: fuses GroupNorm+SiLU, conv chains, etc.
if args.vae_compile: if args.vae_compile:
vae.decoder = torch.compile(vae.decoder, mode="reduce-overhead") vae.decoder = torch.compile(vae.decoder, mode="reduce-overhead")
@@ -1173,12 +1166,6 @@ def get_parser():
default=False, default=False,
help="Apply torch.compile to VAE decoder for kernel fusion." help="Apply torch.compile to VAE decoder for kernel fusion."
) )
parser.add_argument(
"--vae_channels_last",
action='store_true',
default=False,
help="Convert VAE to channels-last (NHWC) memory format for faster cuDNN convolutions."
)
parser.add_argument( parser.add_argument(
"--vae_decode_bs", "--vae_decode_bs",
type=int, type=int,

View File

@@ -99,16 +99,12 @@ class AutoencoderKL(pl.LightningModule):
print(f"Restored from {path}") print(f"Restored from {path}")
def encode(self, x, **kwargs): def encode(self, x, **kwargs):
if getattr(self, '_channels_last', False):
x = x.to(memory_format=torch.channels_last)
h = self.encoder(x) h = self.encoder(x)
moments = self.quant_conv(h) moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments) posterior = DiagonalGaussianDistribution(moments)
return posterior return posterior
def decode(self, z, **kwargs): def decode(self, z, **kwargs):
if getattr(self, '_channels_last', False):
z = z.to(memory_format=torch.channels_last)
z = self.post_quant_conv(z) z = self.post_quant_conv(z)
dec = self.decoder(z) dec = self.decoder(z)
return dec return dec

View File

@@ -24,6 +24,5 @@ dataset="unitree_g1_pack_camera"
--diffusion_dtype bf16 \ --diffusion_dtype bf16 \
--projector_mode bf16_full \ --projector_mode bf16_full \
--encoder_mode bf16_full \ --encoder_mode bf16_full \
--vae_dtype bf16 \ --vae_dtype bf16
--vae_channels_last
} 2>&1 | tee "${res_dir}/output.log" } 2>&1 | tee "${res_dir}/output.log"