VAE优化,模型直接加载至GPU
This commit is contained in:
@@ -222,7 +222,7 @@ data:
|
|||||||
test:
|
test:
|
||||||
target: unifolm_wma.data.wma_data.WMAData
|
target: unifolm_wma.data.wma_data.WMAData
|
||||||
params:
|
params:
|
||||||
data_dir: '/home/dyz/unifolm-world-model-action/examples/world_model_interaction_prompts'
|
data_dir: '/home/qhy/unifolm-world-model-action/examples/world_model_interaction_prompts'
|
||||||
video_length: ${model.params.wma_config.params.temporal_length}
|
video_length: ${model.params.wma_config.params.temporal_length}
|
||||||
frame_stride: 2
|
frame_stride: 2
|
||||||
load_raw_resolution: True
|
load_raw_resolution: True
|
||||||
|
|||||||
@@ -458,7 +458,8 @@ def _load_state_dict(model: nn.Module,
|
|||||||
|
|
||||||
def load_model_checkpoint(model: nn.Module,
|
def load_model_checkpoint(model: nn.Module,
|
||||||
ckpt: str,
|
ckpt: str,
|
||||||
assign: bool | None = None) -> nn.Module:
|
assign: bool | None = None,
|
||||||
|
device: str | torch.device = "cpu") -> nn.Module:
|
||||||
"""Load model weights from checkpoint file.
|
"""Load model weights from checkpoint file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -467,11 +468,12 @@ def load_model_checkpoint(model: nn.Module,
|
|||||||
assign (bool | None): Whether to preserve checkpoint tensor dtypes
|
assign (bool | None): Whether to preserve checkpoint tensor dtypes
|
||||||
via load_state_dict(assign=True). If None, auto-enable when a
|
via load_state_dict(assign=True). If None, auto-enable when a
|
||||||
casted checkpoint metadata is detected.
|
casted checkpoint metadata is detected.
|
||||||
|
device (str | torch.device): Target device for loaded tensors.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
nn.Module: Model with loaded weights.
|
nn.Module: Model with loaded weights.
|
||||||
"""
|
"""
|
||||||
ckpt_data = torch.load(ckpt, map_location="cpu")
|
ckpt_data = torch.load(ckpt, map_location=device, mmap=True)
|
||||||
use_assign = False
|
use_assign = False
|
||||||
if assign is not None:
|
if assign is not None:
|
||||||
use_assign = assign
|
use_assign = assign
|
||||||
@@ -1035,8 +1037,10 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
||||||
|
|
||||||
with profiler.profile_section("model_loading/checkpoint"):
|
with profiler.profile_section("model_loading/checkpoint"):
|
||||||
model = load_model_checkpoint(model, args.ckpt_path)
|
model = load_model_checkpoint(model, args.ckpt_path,
|
||||||
|
device=f"cuda:{gpu_no}")
|
||||||
model.eval()
|
model.eval()
|
||||||
|
model = model.cuda(gpu_no) # move residual buffers not in state_dict
|
||||||
print(f'>>> Load pre-trained model ...')
|
print(f'>>> Load pre-trained model ...')
|
||||||
|
|
||||||
# Build unnomalizer
|
# Build unnomalizer
|
||||||
@@ -1045,9 +1049,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
data = instantiate_from_config(config.data)
|
data = instantiate_from_config(config.data)
|
||||||
data.setup()
|
data.setup()
|
||||||
print(">>> Dataset is successfully loaded ...")
|
print(">>> Dataset is successfully loaded ...")
|
||||||
|
|
||||||
with profiler.profile_section("model_to_cuda"):
|
|
||||||
model = model.cuda(gpu_no)
|
|
||||||
device = get_device_from_parameters(model)
|
device = get_device_from_parameters(model)
|
||||||
|
|
||||||
diffusion_autocast_dtype = None
|
diffusion_autocast_dtype = None
|
||||||
@@ -1074,6 +1075,32 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
model.vae_bf16 = args.vae_dtype == "bf16"
|
model.vae_bf16 = args.vae_dtype == "bf16"
|
||||||
print(f">>> VAE dtype set to {args.vae_dtype}")
|
print(f">>> VAE dtype set to {args.vae_dtype}")
|
||||||
|
|
||||||
|
# --- VAE performance optimizations ---
|
||||||
|
if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
|
||||||
|
vae = model.first_stage_model
|
||||||
|
|
||||||
|
# Channels-last memory format: cuDNN uses faster NHWC kernels
|
||||||
|
if args.vae_channels_last:
|
||||||
|
vae = vae.to(memory_format=torch.channels_last)
|
||||||
|
vae._channels_last = True
|
||||||
|
model.first_stage_model = vae
|
||||||
|
print(">>> VAE converted to channels_last (NHWC) memory format")
|
||||||
|
|
||||||
|
# torch.compile: fuses GroupNorm+SiLU, conv chains, etc.
|
||||||
|
if args.vae_compile:
|
||||||
|
vae.decoder = torch.compile(vae.decoder, mode="reduce-overhead")
|
||||||
|
vae.encoder = torch.compile(vae.encoder, mode="reduce-overhead")
|
||||||
|
print(">>> VAE encoder/decoder compiled with torch.compile (reduce-overhead)")
|
||||||
|
|
||||||
|
# Batch decode size
|
||||||
|
vae_decode_bs = args.vae_decode_bs if args.vae_decode_bs > 0 else 9999
|
||||||
|
model.vae_decode_bs = vae_decode_bs
|
||||||
|
model.vae_encode_bs = vae_decode_bs
|
||||||
|
if args.vae_decode_bs > 0:
|
||||||
|
print(f">>> VAE encode/decode batch size set to {args.vae_decode_bs}")
|
||||||
|
else:
|
||||||
|
print(">>> VAE encode/decode batch size: all frames at once")
|
||||||
|
|
||||||
encoder_mode = args.encoder_mode
|
encoder_mode = args.encoder_mode
|
||||||
encoder_bf16 = encoder_mode in ("autocast", "bf16_full")
|
encoder_bf16 = encoder_mode in ("autocast", "bf16_full")
|
||||||
encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32
|
encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32
|
||||||
@@ -1511,6 +1538,24 @@ def get_parser():
|
|||||||
default="fp32",
|
default="fp32",
|
||||||
help="Dtype for VAE/first_stage_model weights and forward autocast."
|
help="Dtype for VAE/first_stage_model weights and forward autocast."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vae_compile",
|
||||||
|
action='store_true',
|
||||||
|
default=False,
|
||||||
|
help="Apply torch.compile to VAE decoder for kernel fusion."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vae_channels_last",
|
||||||
|
action='store_true',
|
||||||
|
default=False,
|
||||||
|
help="Convert VAE to channels-last (NHWC) memory format for faster cuDNN convolutions."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vae_decode_bs",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="VAE decode batch size (0=all frames at once). Reduces kernel launch overhead."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--export_casted_ckpt",
|
"--export_casted_ckpt",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@@ -99,13 +99,16 @@ class AutoencoderKL(pl.LightningModule):
|
|||||||
print(f"Restored from {path}")
|
print(f"Restored from {path}")
|
||||||
|
|
||||||
def encode(self, x, **kwargs):
|
def encode(self, x, **kwargs):
|
||||||
|
if getattr(self, '_channels_last', False):
|
||||||
|
x = x.to(memory_format=torch.channels_last)
|
||||||
h = self.encoder(x)
|
h = self.encoder(x)
|
||||||
moments = self.quant_conv(h)
|
moments = self.quant_conv(h)
|
||||||
posterior = DiagonalGaussianDistribution(moments)
|
posterior = DiagonalGaussianDistribution(moments)
|
||||||
return posterior
|
return posterior
|
||||||
|
|
||||||
def decode(self, z, **kwargs):
|
def decode(self, z, **kwargs):
|
||||||
|
if getattr(self, '_channels_last', False):
|
||||||
|
z = z.to(memory_format=torch.channels_last)
|
||||||
z = self.post_quant_conv(z)
|
z = self.post_quant_conv(z)
|
||||||
dec = self.decoder(z)
|
dec = self.decoder(z)
|
||||||
return dec
|
return dec
|
||||||
|
|||||||
@@ -1073,15 +1073,19 @@ class LatentDiffusion(DDPM):
|
|||||||
if not self.perframe_ae:
|
if not self.perframe_ae:
|
||||||
encoder_posterior = self.first_stage_model.encode(x)
|
encoder_posterior = self.first_stage_model.encode(x)
|
||||||
results = self.get_first_stage_encoding(encoder_posterior).detach()
|
results = self.get_first_stage_encoding(encoder_posterior).detach()
|
||||||
else: ## Consume less GPU memory but slower
|
else: ## Batch encode with configurable batch size
|
||||||
results = []
|
bs = getattr(self, 'vae_encode_bs', 1)
|
||||||
for index in range(x.shape[0]):
|
if bs >= x.shape[0]:
|
||||||
frame_batch = self.first_stage_model.encode(x[index:index +
|
encoder_posterior = self.first_stage_model.encode(x)
|
||||||
1, :, :, :])
|
results = self.get_first_stage_encoding(encoder_posterior).detach()
|
||||||
frame_result = self.get_first_stage_encoding(
|
else:
|
||||||
frame_batch).detach()
|
results = []
|
||||||
results.append(frame_result)
|
for i in range(0, x.shape[0], bs):
|
||||||
results = torch.cat(results, dim=0)
|
frame_batch = self.first_stage_model.encode(x[i:i + bs])
|
||||||
|
frame_result = self.get_first_stage_encoding(
|
||||||
|
frame_batch).detach()
|
||||||
|
results.append(frame_result)
|
||||||
|
results = torch.cat(results, dim=0)
|
||||||
|
|
||||||
if reshape_back:
|
if reshape_back:
|
||||||
results = rearrange(results, '(b t) c h w -> b c t h w', b=b, t=t)
|
results = rearrange(results, '(b t) c h w -> b c t h w', b=b, t=t)
|
||||||
@@ -1105,16 +1109,21 @@ class LatentDiffusion(DDPM):
|
|||||||
else:
|
else:
|
||||||
reshape_back = False
|
reshape_back = False
|
||||||
|
|
||||||
|
z = 1. / self.scale_factor * z
|
||||||
|
|
||||||
if not self.perframe_ae:
|
if not self.perframe_ae:
|
||||||
z = 1. / self.scale_factor * z
|
|
||||||
results = self.first_stage_model.decode(z, **kwargs)
|
results = self.first_stage_model.decode(z, **kwargs)
|
||||||
else:
|
else:
|
||||||
results = []
|
bs = getattr(self, 'vae_decode_bs', 1)
|
||||||
for index in range(z.shape[0]):
|
if bs >= z.shape[0]:
|
||||||
frame_z = 1. / self.scale_factor * z[index:index + 1, :, :, :]
|
# all frames in one batch
|
||||||
frame_result = self.first_stage_model.decode(frame_z, **kwargs)
|
results = self.first_stage_model.decode(z, **kwargs)
|
||||||
results.append(frame_result)
|
else:
|
||||||
results = torch.cat(results, dim=0)
|
results = []
|
||||||
|
for i in range(0, z.shape[0], bs):
|
||||||
|
results.append(
|
||||||
|
self.first_stage_model.decode(z[i:i + bs], **kwargs))
|
||||||
|
results = torch.cat(results, dim=0)
|
||||||
|
|
||||||
if reshape_back:
|
if reshape_back:
|
||||||
results = rearrange(results, '(b t) c h w -> b c t h w', b=b, t=t)
|
results = rearrange(results, '(b t) c h w -> b c t h w', b=b, t=t)
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ from unifolm_wma.utils.utils import instantiate_from_config
|
|||||||
|
|
||||||
|
|
||||||
def nonlinearity(x):
|
def nonlinearity(x):
|
||||||
# swish
|
# swish / SiLU — single fused CUDA kernel instead of x * sigmoid(x)
|
||||||
return x * torch.sigmoid(x)
|
return torch.nn.functional.silu(x)
|
||||||
|
|
||||||
|
|
||||||
def Normalize(in_channels, num_groups=32):
|
def Normalize(in_channels, num_groups=32):
|
||||||
|
|||||||
Reference in New Issue
Block a user