整理代码
This commit is contained in:
@@ -752,13 +752,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
|
||||
vae = model.first_stage_model
|
||||
|
||||
# Channels-last memory format: cuDNN uses faster NHWC kernels
|
||||
if args.vae_channels_last:
|
||||
vae = vae.to(memory_format=torch.channels_last)
|
||||
vae._channels_last = True
|
||||
model.first_stage_model = vae
|
||||
print(">>> VAE converted to channels_last (NHWC) memory format")
|
||||
|
||||
# torch.compile: fuses GroupNorm+SiLU, conv chains, etc.
|
||||
if args.vae_compile:
|
||||
vae.decoder = torch.compile(vae.decoder, mode="reduce-overhead")
|
||||
@@ -1173,12 +1166,6 @@ def get_parser():
|
||||
default=False,
|
||||
help="Apply torch.compile to VAE decoder for kernel fusion."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_channels_last",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Convert VAE to channels-last (NHWC) memory format for faster cuDNN convolutions."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_decode_bs",
|
||||
type=int,
|
||||
|
||||
Reference in New Issue
Block a user