整理代码
This commit is contained in:
@@ -752,13 +752,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
|
if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
|
||||||
vae = model.first_stage_model
|
vae = model.first_stage_model
|
||||||
|
|
||||||
# Channels-last memory format: cuDNN uses faster NHWC kernels
|
|
||||||
if args.vae_channels_last:
|
|
||||||
vae = vae.to(memory_format=torch.channels_last)
|
|
||||||
vae._channels_last = True
|
|
||||||
model.first_stage_model = vae
|
|
||||||
print(">>> VAE converted to channels_last (NHWC) memory format")
|
|
||||||
|
|
||||||
# torch.compile: fuses GroupNorm+SiLU, conv chains, etc.
|
# torch.compile: fuses GroupNorm+SiLU, conv chains, etc.
|
||||||
if args.vae_compile:
|
if args.vae_compile:
|
||||||
vae.decoder = torch.compile(vae.decoder, mode="reduce-overhead")
|
vae.decoder = torch.compile(vae.decoder, mode="reduce-overhead")
|
||||||
@@ -1173,12 +1166,6 @@ def get_parser():
|
|||||||
default=False,
|
default=False,
|
||||||
help="Apply torch.compile to VAE decoder for kernel fusion."
|
help="Apply torch.compile to VAE decoder for kernel fusion."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--vae_channels_last",
|
|
||||||
action='store_true',
|
|
||||||
default=False,
|
|
||||||
help="Convert VAE to channels-last (NHWC) memory format for faster cuDNN convolutions."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vae_decode_bs",
|
"--vae_decode_bs",
|
||||||
type=int,
|
type=int,
|
||||||
|
|||||||
@@ -99,16 +99,12 @@ class AutoencoderKL(pl.LightningModule):
|
|||||||
print(f"Restored from {path}")
|
print(f"Restored from {path}")
|
||||||
|
|
||||||
def encode(self, x, **kwargs):
|
def encode(self, x, **kwargs):
|
||||||
if getattr(self, '_channels_last', False):
|
|
||||||
x = x.to(memory_format=torch.channels_last)
|
|
||||||
h = self.encoder(x)
|
h = self.encoder(x)
|
||||||
moments = self.quant_conv(h)
|
moments = self.quant_conv(h)
|
||||||
posterior = DiagonalGaussianDistribution(moments)
|
posterior = DiagonalGaussianDistribution(moments)
|
||||||
return posterior
|
return posterior
|
||||||
|
|
||||||
def decode(self, z, **kwargs):
|
def decode(self, z, **kwargs):
|
||||||
if getattr(self, '_channels_last', False):
|
|
||||||
z = z.to(memory_format=torch.channels_last)
|
|
||||||
z = self.post_quant_conv(z)
|
z = self.post_quant_conv(z)
|
||||||
dec = self.decoder(z)
|
dec = self.decoder(z)
|
||||||
return dec
|
return dec
|
||||||
|
|||||||
@@ -24,6 +24,5 @@ dataset="unitree_g1_pack_camera"
|
|||||||
--diffusion_dtype bf16 \
|
--diffusion_dtype bf16 \
|
||||||
--projector_mode bf16_full \
|
--projector_mode bf16_full \
|
||||||
--encoder_mode bf16_full \
|
--encoder_mode bf16_full \
|
||||||
--vae_dtype bf16 \
|
--vae_dtype bf16
|
||||||
--vae_channels_last
|
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
Reference in New Issue
Block a user