复制模型对象,跳过加载模型
This commit is contained in:
@@ -1026,24 +1026,171 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
||||||
df = pd.read_csv(csv_path)
|
df = pd.read_csv(csv_path)
|
||||||
|
|
||||||
# Load config
|
# Load config (always needed for data setup)
|
||||||
with profiler.profile_section("model_loading/config"):
|
config = OmegaConf.load(args.config)
|
||||||
config = OmegaConf.load(args.config)
|
|
||||||
config['model']['params']['wma_config']['params'][
|
|
||||||
'use_checkpoint'] = False
|
|
||||||
model = instantiate_from_config(config.model)
|
|
||||||
model.perframe_ae = args.perframe_ae
|
|
||||||
|
|
||||||
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
prepared_path = args.ckpt_path + ".prepared.pt"
|
||||||
|
if os.path.exists(prepared_path):
|
||||||
|
# ---- Fast path: load the fully-prepared model ----
|
||||||
|
print(f">>> Loading prepared model from {prepared_path} ...")
|
||||||
|
with profiler.profile_section("model_loading/prepared"):
|
||||||
|
model = torch.load(prepared_path,
|
||||||
|
map_location=f"cuda:{gpu_no}",
|
||||||
|
weights_only=False,
|
||||||
|
mmap=True)
|
||||||
|
model.eval()
|
||||||
|
diffusion_autocast_dtype = (torch.bfloat16
|
||||||
|
if args.diffusion_dtype == "bf16"
|
||||||
|
else None)
|
||||||
|
print(f">>> Prepared model loaded.")
|
||||||
|
else:
|
||||||
|
# ---- Normal path: construct + checkpoint + casting ----
|
||||||
|
with profiler.profile_section("model_loading/config"):
|
||||||
|
config['model']['params']['wma_config']['params'][
|
||||||
|
'use_checkpoint'] = False
|
||||||
|
model = instantiate_from_config(config.model)
|
||||||
|
model.perframe_ae = args.perframe_ae
|
||||||
|
|
||||||
with profiler.profile_section("model_loading/checkpoint"):
|
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
||||||
model = load_model_checkpoint(model, args.ckpt_path,
|
|
||||||
device=f"cuda:{gpu_no}")
|
|
||||||
model.eval()
|
|
||||||
model = model.cuda(gpu_no) # move residual buffers not in state_dict
|
|
||||||
print(f'>>> Load pre-trained model ...')
|
|
||||||
|
|
||||||
# Build unnomalizer
|
with profiler.profile_section("model_loading/checkpoint"):
|
||||||
|
model = load_model_checkpoint(model, args.ckpt_path,
|
||||||
|
device=f"cuda:{gpu_no}")
|
||||||
|
model.eval()
|
||||||
|
model = model.cuda(gpu_no) # move residual buffers not in state_dict
|
||||||
|
print(f'>>> Load pre-trained model ...')
|
||||||
|
|
||||||
|
diffusion_autocast_dtype = None
|
||||||
|
if args.diffusion_dtype == "bf16":
|
||||||
|
maybe_cast_module(
|
||||||
|
model.model,
|
||||||
|
torch.bfloat16,
|
||||||
|
"diffusion backbone",
|
||||||
|
profiler=profiler,
|
||||||
|
profile_name="model_loading/diffusion_bf16",
|
||||||
|
)
|
||||||
|
diffusion_autocast_dtype = torch.bfloat16
|
||||||
|
print(">>> diffusion backbone set to bfloat16")
|
||||||
|
|
||||||
|
if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
|
||||||
|
vae_weight_dtype = torch.bfloat16 if args.vae_dtype == "bf16" else torch.float32
|
||||||
|
maybe_cast_module(
|
||||||
|
model.first_stage_model,
|
||||||
|
vae_weight_dtype,
|
||||||
|
"VAE",
|
||||||
|
profiler=profiler,
|
||||||
|
profile_name="model_loading/vae_cast",
|
||||||
|
)
|
||||||
|
model.vae_bf16 = args.vae_dtype == "bf16"
|
||||||
|
print(f">>> VAE dtype set to {args.vae_dtype}")
|
||||||
|
|
||||||
|
# --- VAE performance optimizations ---
|
||||||
|
if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
|
||||||
|
vae = model.first_stage_model
|
||||||
|
|
||||||
|
# Channels-last memory format: cuDNN uses faster NHWC kernels
|
||||||
|
if args.vae_channels_last:
|
||||||
|
vae = vae.to(memory_format=torch.channels_last)
|
||||||
|
vae._channels_last = True
|
||||||
|
model.first_stage_model = vae
|
||||||
|
print(">>> VAE converted to channels_last (NHWC) memory format")
|
||||||
|
|
||||||
|
# torch.compile: fuses GroupNorm+SiLU, conv chains, etc.
|
||||||
|
if args.vae_compile:
|
||||||
|
vae.decoder = torch.compile(vae.decoder, mode="reduce-overhead")
|
||||||
|
vae.encoder = torch.compile(vae.encoder, mode="reduce-overhead")
|
||||||
|
print(">>> VAE encoder/decoder compiled with torch.compile (reduce-overhead)")
|
||||||
|
|
||||||
|
# Batch decode size
|
||||||
|
vae_decode_bs = args.vae_decode_bs if args.vae_decode_bs > 0 else 9999
|
||||||
|
model.vae_decode_bs = vae_decode_bs
|
||||||
|
model.vae_encode_bs = vae_decode_bs
|
||||||
|
if args.vae_decode_bs > 0:
|
||||||
|
print(f">>> VAE encode/decode batch size set to {args.vae_decode_bs}")
|
||||||
|
else:
|
||||||
|
print(">>> VAE encode/decode batch size: all frames at once")
|
||||||
|
|
||||||
|
encoder_mode = args.encoder_mode
|
||||||
|
encoder_bf16 = encoder_mode in ("autocast", "bf16_full")
|
||||||
|
encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32
|
||||||
|
if hasattr(model, "cond_stage_model") and model.cond_stage_model is not None:
|
||||||
|
maybe_cast_module(
|
||||||
|
model.cond_stage_model,
|
||||||
|
encoder_weight_dtype,
|
||||||
|
"cond_stage_model",
|
||||||
|
profiler=profiler,
|
||||||
|
profile_name="model_loading/encoder_cond_cast",
|
||||||
|
)
|
||||||
|
if hasattr(model, "embedder") and model.embedder is not None:
|
||||||
|
maybe_cast_module(
|
||||||
|
model.embedder,
|
||||||
|
encoder_weight_dtype,
|
||||||
|
"embedder",
|
||||||
|
profiler=profiler,
|
||||||
|
profile_name="model_loading/encoder_embedder_cast",
|
||||||
|
)
|
||||||
|
model.encoder_bf16 = encoder_bf16
|
||||||
|
model.encoder_mode = encoder_mode
|
||||||
|
print(
|
||||||
|
f">>> encoder mode set to {encoder_mode} (weights={encoder_weight_dtype})"
|
||||||
|
)
|
||||||
|
|
||||||
|
projector_mode = args.projector_mode
|
||||||
|
projector_bf16 = projector_mode in ("autocast", "bf16_full")
|
||||||
|
projector_weight_dtype = torch.bfloat16 if projector_mode == "bf16_full" else torch.float32
|
||||||
|
if hasattr(model, "image_proj_model") and model.image_proj_model is not None:
|
||||||
|
maybe_cast_module(
|
||||||
|
model.image_proj_model,
|
||||||
|
projector_weight_dtype,
|
||||||
|
"image_proj_model",
|
||||||
|
profiler=profiler,
|
||||||
|
profile_name="model_loading/projector_image_cast",
|
||||||
|
)
|
||||||
|
if hasattr(model, "state_projector") and model.state_projector is not None:
|
||||||
|
maybe_cast_module(
|
||||||
|
model.state_projector,
|
||||||
|
projector_weight_dtype,
|
||||||
|
"state_projector",
|
||||||
|
profiler=profiler,
|
||||||
|
profile_name="model_loading/projector_state_cast",
|
||||||
|
)
|
||||||
|
if hasattr(model, "action_projector") and model.action_projector is not None:
|
||||||
|
maybe_cast_module(
|
||||||
|
model.action_projector,
|
||||||
|
projector_weight_dtype,
|
||||||
|
"action_projector",
|
||||||
|
profiler=profiler,
|
||||||
|
profile_name="model_loading/projector_action_cast",
|
||||||
|
)
|
||||||
|
if hasattr(model, "projector_bf16"):
|
||||||
|
model.projector_bf16 = projector_bf16
|
||||||
|
model.projector_mode = projector_mode
|
||||||
|
print(
|
||||||
|
f">>> projector mode set to {projector_mode} (weights={projector_weight_dtype})"
|
||||||
|
)
|
||||||
|
|
||||||
|
log_inference_precision(model)
|
||||||
|
|
||||||
|
if args.export_casted_ckpt:
|
||||||
|
metadata = {
|
||||||
|
"diffusion_dtype": args.diffusion_dtype,
|
||||||
|
"vae_dtype": args.vae_dtype,
|
||||||
|
"encoder_mode": args.encoder_mode,
|
||||||
|
"projector_mode": args.projector_mode,
|
||||||
|
"perframe_ae": args.perframe_ae,
|
||||||
|
}
|
||||||
|
save_casted_checkpoint(model, args.export_casted_ckpt, metadata)
|
||||||
|
if args.export_only:
|
||||||
|
print(">>> export_only set; skipping inference.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Save prepared model for fast loading next time
|
||||||
|
if prepared_path:
|
||||||
|
print(f">>> Saving prepared model to {prepared_path} ...")
|
||||||
|
torch.save(model, prepared_path)
|
||||||
|
print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
|
||||||
|
|
||||||
|
# Build normalizer (always needed, independent of model loading path)
|
||||||
logging.info("***** Configing Data *****")
|
logging.info("***** Configing Data *****")
|
||||||
with profiler.profile_section("data_loading"):
|
with profiler.profile_section("data_loading"):
|
||||||
data = instantiate_from_config(config.data)
|
data = instantiate_from_config(config.data)
|
||||||
@@ -1051,130 +1198,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
print(">>> Dataset is successfully loaded ...")
|
print(">>> Dataset is successfully loaded ...")
|
||||||
device = get_device_from_parameters(model)
|
device = get_device_from_parameters(model)
|
||||||
|
|
||||||
diffusion_autocast_dtype = None
|
|
||||||
if args.diffusion_dtype == "bf16":
|
|
||||||
maybe_cast_module(
|
|
||||||
model.model,
|
|
||||||
torch.bfloat16,
|
|
||||||
"diffusion backbone",
|
|
||||||
profiler=profiler,
|
|
||||||
profile_name="model_loading/diffusion_bf16",
|
|
||||||
)
|
|
||||||
diffusion_autocast_dtype = torch.bfloat16
|
|
||||||
print(">>> diffusion backbone set to bfloat16")
|
|
||||||
|
|
||||||
if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
|
|
||||||
vae_weight_dtype = torch.bfloat16 if args.vae_dtype == "bf16" else torch.float32
|
|
||||||
maybe_cast_module(
|
|
||||||
model.first_stage_model,
|
|
||||||
vae_weight_dtype,
|
|
||||||
"VAE",
|
|
||||||
profiler=profiler,
|
|
||||||
profile_name="model_loading/vae_cast",
|
|
||||||
)
|
|
||||||
model.vae_bf16 = args.vae_dtype == "bf16"
|
|
||||||
print(f">>> VAE dtype set to {args.vae_dtype}")
|
|
||||||
|
|
||||||
# --- VAE performance optimizations ---
|
|
||||||
if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
|
|
||||||
vae = model.first_stage_model
|
|
||||||
|
|
||||||
# Channels-last memory format: cuDNN uses faster NHWC kernels
|
|
||||||
if args.vae_channels_last:
|
|
||||||
vae = vae.to(memory_format=torch.channels_last)
|
|
||||||
vae._channels_last = True
|
|
||||||
model.first_stage_model = vae
|
|
||||||
print(">>> VAE converted to channels_last (NHWC) memory format")
|
|
||||||
|
|
||||||
# torch.compile: fuses GroupNorm+SiLU, conv chains, etc.
|
|
||||||
if args.vae_compile:
|
|
||||||
vae.decoder = torch.compile(vae.decoder, mode="reduce-overhead")
|
|
||||||
vae.encoder = torch.compile(vae.encoder, mode="reduce-overhead")
|
|
||||||
print(">>> VAE encoder/decoder compiled with torch.compile (reduce-overhead)")
|
|
||||||
|
|
||||||
# Batch decode size
|
|
||||||
vae_decode_bs = args.vae_decode_bs if args.vae_decode_bs > 0 else 9999
|
|
||||||
model.vae_decode_bs = vae_decode_bs
|
|
||||||
model.vae_encode_bs = vae_decode_bs
|
|
||||||
if args.vae_decode_bs > 0:
|
|
||||||
print(f">>> VAE encode/decode batch size set to {args.vae_decode_bs}")
|
|
||||||
else:
|
|
||||||
print(">>> VAE encode/decode batch size: all frames at once")
|
|
||||||
|
|
||||||
encoder_mode = args.encoder_mode
|
|
||||||
encoder_bf16 = encoder_mode in ("autocast", "bf16_full")
|
|
||||||
encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32
|
|
||||||
if hasattr(model, "cond_stage_model") and model.cond_stage_model is not None:
|
|
||||||
maybe_cast_module(
|
|
||||||
model.cond_stage_model,
|
|
||||||
encoder_weight_dtype,
|
|
||||||
"cond_stage_model",
|
|
||||||
profiler=profiler,
|
|
||||||
profile_name="model_loading/encoder_cond_cast",
|
|
||||||
)
|
|
||||||
if hasattr(model, "embedder") and model.embedder is not None:
|
|
||||||
maybe_cast_module(
|
|
||||||
model.embedder,
|
|
||||||
encoder_weight_dtype,
|
|
||||||
"embedder",
|
|
||||||
profiler=profiler,
|
|
||||||
profile_name="model_loading/encoder_embedder_cast",
|
|
||||||
)
|
|
||||||
model.encoder_bf16 = encoder_bf16
|
|
||||||
model.encoder_mode = encoder_mode
|
|
||||||
print(
|
|
||||||
f">>> encoder mode set to {encoder_mode} (weights={encoder_weight_dtype})"
|
|
||||||
)
|
|
||||||
|
|
||||||
projector_mode = args.projector_mode
|
|
||||||
projector_bf16 = projector_mode in ("autocast", "bf16_full")
|
|
||||||
projector_weight_dtype = torch.bfloat16 if projector_mode == "bf16_full" else torch.float32
|
|
||||||
if hasattr(model, "image_proj_model") and model.image_proj_model is not None:
|
|
||||||
maybe_cast_module(
|
|
||||||
model.image_proj_model,
|
|
||||||
projector_weight_dtype,
|
|
||||||
"image_proj_model",
|
|
||||||
profiler=profiler,
|
|
||||||
profile_name="model_loading/projector_image_cast",
|
|
||||||
)
|
|
||||||
if hasattr(model, "state_projector") and model.state_projector is not None:
|
|
||||||
maybe_cast_module(
|
|
||||||
model.state_projector,
|
|
||||||
projector_weight_dtype,
|
|
||||||
"state_projector",
|
|
||||||
profiler=profiler,
|
|
||||||
profile_name="model_loading/projector_state_cast",
|
|
||||||
)
|
|
||||||
if hasattr(model, "action_projector") and model.action_projector is not None:
|
|
||||||
maybe_cast_module(
|
|
||||||
model.action_projector,
|
|
||||||
projector_weight_dtype,
|
|
||||||
"action_projector",
|
|
||||||
profiler=profiler,
|
|
||||||
profile_name="model_loading/projector_action_cast",
|
|
||||||
)
|
|
||||||
if hasattr(model, "projector_bf16"):
|
|
||||||
model.projector_bf16 = projector_bf16
|
|
||||||
model.projector_mode = projector_mode
|
|
||||||
print(
|
|
||||||
f">>> projector mode set to {projector_mode} (weights={projector_weight_dtype})"
|
|
||||||
)
|
|
||||||
|
|
||||||
log_inference_precision(model)
|
|
||||||
|
|
||||||
if args.export_casted_ckpt:
|
|
||||||
metadata = {
|
|
||||||
"diffusion_dtype": args.diffusion_dtype,
|
|
||||||
"vae_dtype": args.vae_dtype,
|
|
||||||
"encoder_mode": args.encoder_mode,
|
|
||||||
"projector_mode": args.projector_mode,
|
|
||||||
"perframe_ae": args.perframe_ae,
|
|
||||||
}
|
|
||||||
save_casted_checkpoint(model, args.export_casted_ckpt, metadata)
|
|
||||||
if args.export_only:
|
|
||||||
print(">>> export_only set; skipping inference.")
|
|
||||||
return
|
|
||||||
|
|
||||||
profiler.record_memory("after_model_load")
|
profiler.record_memory("after_model_load")
|
||||||
|
|
||||||
# Run over data
|
# Run over data
|
||||||
|
|||||||
Reference in New Issue
Block a user