把混合精度模型权重导出至本地文件,减少dtype开销

--export_casted_ckpt ckpts/unifolm_wma_dual_mixbf16.ckpt \
        --export_only
This commit is contained in:
2026-01-19 15:14:01 +08:00
parent cb334f308b
commit 7e501b17fd
20 changed files with 245 additions and 55 deletions

View File

@@ -441,57 +441,143 @@ def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
return file_list
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
    """Load model weights from checkpoint file.

    Args:
        model (nn.Module): Model instance.
        ckpt (str): Path to the checkpoint file.

    Returns:
        nn.Module: Model with loaded weights.
    """
    state_dict = torch.load(ckpt, map_location="cpu")
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
        try:
            model.load_state_dict(state_dict, strict=True)
        except Exception:
            # Fallback for legacy checkpoints: rename old "framestride_embed"
            # keys to "fps_embedding" and retry the strict load.
            new_pl_sd = OrderedDict()
            for k, v in state_dict.items():
                new_pl_sd[k] = v
            for k in list(new_pl_sd.keys()):
                if "framestride_embed" in k:
                    new_key = k.replace("framestride_embed", "fps_embedding")
                    new_pl_sd[new_key] = new_pl_sd[k]
                    del new_pl_sd[k]
            model.load_state_dict(new_pl_sd, strict=True)
    else:
        # Checkpoint stores weights under "module" with a fixed 16-character
        # key prefix (presumably a distributed-engine wrapper — TODO confirm);
        # strip the prefix before loading.
        new_pl_sd = OrderedDict()
        for key in state_dict['module'].keys():
            new_pl_sd[key[16:]] = state_dict['module'][key]
        model.load_state_dict(new_pl_sd)
    return model
def _load_state_dict(model: nn.Module,
state_dict: Mapping[str, torch.Tensor],
strict: bool = True,
assign: bool = False) -> None:
if assign:
try:
model.load_state_dict(state_dict, strict=strict, assign=True)
return
except TypeError:
warnings.warn(
"load_state_dict(assign=True) not supported; "
"falling back to copy load.")
model.load_state_dict(state_dict, strict=strict)
def load_model_checkpoint(model: nn.Module,
                          ckpt: str,
                          assign: bool | None = None) -> nn.Module:
    """Load model weights from checkpoint file.

    Args:
        model (nn.Module): Model instance.
        ckpt (str): Path to the checkpoint file.
        assign (bool | None): Whether to preserve checkpoint tensor dtypes
            via load_state_dict(assign=True). If None, auto-enable when a
            casted checkpoint metadata is detected.

    Returns:
        nn.Module: Model with loaded weights.
    """
    ckpt_data = torch.load(ckpt, map_location="cpu")
    is_mapping = isinstance(ckpt_data, Mapping)
    # Honour an explicit caller choice; otherwise enable assign-loading only
    # for checkpoints that carry "precision_metadata" (exported casted ckpts).
    if assign is None:
        use_assign = is_mapping and "precision_metadata" in ckpt_data
    else:
        use_assign = assign
    if is_mapping and "state_dict" in ckpt_data:
        state_dict = ckpt_data["state_dict"]
        try:
            _load_state_dict(model, state_dict, strict=True, assign=use_assign)
        except Exception:
            # Legacy checkpoints: rename "framestride_embed" keys to
            # "fps_embedding" and retry.
            renamed = OrderedDict(state_dict.items())
            for key in list(renamed):
                if "framestride_embed" in key:
                    new_key = key.replace("framestride_embed", "fps_embedding")
                    renamed[new_key] = renamed[key]
                    del renamed[key]
            _load_state_dict(model, renamed, strict=True, assign=use_assign)
    elif is_mapping and "module" in ckpt_data:
        # Weights nested under "module" with a fixed 16-character key prefix;
        # strip the prefix before loading.
        stripped = OrderedDict(
            (key[16:], ckpt_data['module'][key])
            for key in ckpt_data['module'].keys())
        _load_state_dict(model, stripped, strict=True, assign=use_assign)
    else:
        # Plain state dict saved directly at top level.
        _load_state_dict(model, ckpt_data, strict=True, assign=use_assign)
    print('>>> model checkpoint loaded.')
    return model
def maybe_cast_module(module: nn.Module | None,
dtype: torch.dtype,
label: str,
profiler: Optional[ProfilerManager] = None,
profile_name: Optional[str] = None) -> None:
if module is None:
return
try:
param = next(module.parameters())
except StopIteration:
print(f">>> {label} has no parameters; skip cast")
return
if param.dtype == dtype:
print(f">>> {label} already {dtype}; skip cast")
return
ctx = nullcontext()
if profiler is not None and profile_name:
ctx = profiler.profile_section(profile_name)
with ctx:
module.to(dtype=dtype)
print(f">>> {label} cast to {dtype}")
def save_casted_checkpoint(model: nn.Module,
                           save_path: str,
                           metadata: Optional[Dict[str, Any]] = None) -> None:
    """Persist the model's current (possibly mixed-precision) weights.

    Writes ``{"state_dict": ...}`` to ``save_path`` with every tensor
    detached and moved to CPU; a ``"precision_metadata"`` entry is added
    only when ``metadata`` is truthy. An empty ``save_path`` is a no-op.
    """
    if not save_path:
        return
    parent = os.path.dirname(save_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    cpu_state = {
        name: (value.detach().to("cpu")
               if isinstance(value, torch.Tensor) else value)
        for name, value in model.state_dict().items()
    }
    payload: Dict[str, Any] = {"state_dict": cpu_state}
    if metadata:
        payload["precision_metadata"] = metadata
    torch.save(payload, save_path)
    print(f">>> Saved casted checkpoint to {save_path}")
def _module_param_dtype(module: nn.Module | None) -> str:
if module is None:
return "None"
dtype_counts: Dict[str, int] = {}
for param in module.parameters():
return str(param.dtype)
return "no_params"
dtype_key = str(param.dtype)
dtype_counts[dtype_key] = dtype_counts.get(dtype_key, 0) + param.numel()
if not dtype_counts:
return "no_params"
if len(dtype_counts) == 1:
return next(iter(dtype_counts))
total = sum(dtype_counts.values())
parts = []
for dtype_key in sorted(dtype_counts.keys()):
ratio = dtype_counts[dtype_key] / total
parts.append(f"{dtype_key}={ratio:.1%}")
return f"mixed({', '.join(parts)})"
def log_inference_precision(model: nn.Module) -> None:
    """Print a one-line summary of the model's parameter dtypes and device.

    The device is read from the first parameter (assumes the whole model
    lives on a single device — TODO confirm); the dtype summary comes from
    _module_param_dtype and may report a mixed-precision breakdown.

    Fix: the original body was merge-mangled (duplicate parameter scans,
    a ``break`` inside ``except`` outside any loop, dead assignments).
    """
    try:
        device = str(next(model.parameters()).device)
    except StopIteration:
        device = "unknown"
    model_dtype = _module_param_dtype(model)
    print(f">>> inference precision: model={model_dtype}, device={device}")
for attr in [
@@ -966,16 +1052,25 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
diffusion_autocast_dtype = None
if args.diffusion_dtype == "bf16":
with profiler.profile_section("model_loading/diffusion_bf16"):
model.model.to(dtype=torch.bfloat16)
maybe_cast_module(
model.model,
torch.bfloat16,
"diffusion backbone",
profiler=profiler,
profile_name="model_loading/diffusion_bf16",
)
diffusion_autocast_dtype = torch.bfloat16
print(">>> diffusion backbone set to bfloat16")
if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
if args.vae_dtype == "bf16":
model.first_stage_model.to(dtype=torch.bfloat16)
else:
model.first_stage_model.to(dtype=torch.float32)
vae_weight_dtype = torch.bfloat16 if args.vae_dtype == "bf16" else torch.float32
maybe_cast_module(
model.first_stage_model,
vae_weight_dtype,
"VAE",
profiler=profiler,
profile_name="model_loading/vae_cast",
)
model.vae_bf16 = args.vae_dtype == "bf16"
print(f">>> VAE dtype set to {args.vae_dtype}")
@@ -983,9 +1078,21 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
encoder_bf16 = encoder_mode in ("autocast", "bf16_full")
encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32
if hasattr(model, "cond_stage_model") and model.cond_stage_model is not None:
model.cond_stage_model.to(dtype=encoder_weight_dtype)
maybe_cast_module(
model.cond_stage_model,
encoder_weight_dtype,
"cond_stage_model",
profiler=profiler,
profile_name="model_loading/encoder_cond_cast",
)
if hasattr(model, "embedder") and model.embedder is not None:
model.embedder.to(dtype=encoder_weight_dtype)
maybe_cast_module(
model.embedder,
encoder_weight_dtype,
"embedder",
profiler=profiler,
profile_name="model_loading/encoder_embedder_cast",
)
model.encoder_bf16 = encoder_bf16
model.encoder_mode = encoder_mode
print(
@@ -996,11 +1103,29 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
projector_bf16 = projector_mode in ("autocast", "bf16_full")
projector_weight_dtype = torch.bfloat16 if projector_mode == "bf16_full" else torch.float32
if hasattr(model, "image_proj_model") and model.image_proj_model is not None:
model.image_proj_model.to(dtype=projector_weight_dtype)
maybe_cast_module(
model.image_proj_model,
projector_weight_dtype,
"image_proj_model",
profiler=profiler,
profile_name="model_loading/projector_image_cast",
)
if hasattr(model, "state_projector") and model.state_projector is not None:
model.state_projector.to(dtype=projector_weight_dtype)
maybe_cast_module(
model.state_projector,
projector_weight_dtype,
"state_projector",
profiler=profiler,
profile_name="model_loading/projector_state_cast",
)
if hasattr(model, "action_projector") and model.action_projector is not None:
model.action_projector.to(dtype=projector_weight_dtype)
maybe_cast_module(
model.action_projector,
projector_weight_dtype,
"action_projector",
profiler=profiler,
profile_name="model_loading/projector_action_cast",
)
if hasattr(model, "projector_bf16"):
model.projector_bf16 = projector_bf16
model.projector_mode = projector_mode
@@ -1010,6 +1135,19 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
log_inference_precision(model)
if args.export_casted_ckpt:
metadata = {
"diffusion_dtype": args.diffusion_dtype,
"vae_dtype": args.vae_dtype,
"encoder_mode": args.encoder_mode,
"projector_mode": args.projector_mode,
"perframe_ae": args.perframe_ae,
}
save_casted_checkpoint(model, args.export_casted_ckpt, metadata)
if args.export_only:
print(">>> export_only set; skipping inference.")
return
profiler.record_memory("after_model_load")
# Run over data
@@ -1373,6 +1511,19 @@ def get_parser():
default="fp32",
help="Dtype for VAE/first_stage_model weights and forward autocast."
)
parser.add_argument(
"--export_casted_ckpt",
type=str,
default=None,
help=
"Save a checkpoint after applying precision settings (mixed dtypes preserved)."
)
parser.add_argument(
"--export_only",
action='store_true',
default=False,
help="Exit after exporting the casted checkpoint."
)
parser.add_argument(
"--step_log_every",
type=int,