把混合精度模型权重导出至本地文件,减少dtype开销
--export_casted_ckpt ckpts/unifolm_wma_dual_mixbf16.ckpt \
--export_only
This commit is contained in:
@@ -441,57 +441,143 @@ def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
|
||||
return file_list
|
||||
|
||||
|
||||
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
    """Load model weights from checkpoint file.

    Handles two layouts: a payload wrapping the weights under "state_dict"
    (with a legacy-key remap fallback), and a payload wrapping them under
    "module" with a fixed-length key prefix.

    Args:
        model (nn.Module): Model instance.
        ckpt (str): Path to the checkpoint file.

    Returns:
        nn.Module: Model with loaded weights.
    """
    state_dict = torch.load(ckpt, map_location="cpu")
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
        try:
            model.load_state_dict(state_dict, strict=True)
        except Exception:
            # Older checkpoints name the fps embedding "framestride_embed";
            # remap those keys and retry a strict load.
            new_pl_sd = OrderedDict(state_dict)
            for k in list(new_pl_sd.keys()):
                if "framestride_embed" in k:
                    new_key = k.replace("framestride_embed", "fps_embedding")
                    new_pl_sd[new_key] = new_pl_sd.pop(k)
            model.load_state_dict(new_pl_sd, strict=True)
    else:
        # Weights live under "module" and every key is sliced at 16 chars —
        # presumably a DeepSpeed engine prefix; confirm against the exporter.
        new_pl_sd = OrderedDict()
        for key in state_dict['module'].keys():
            new_pl_sd[key[16:]] = state_dict['module'][key]
        model.load_state_dict(new_pl_sd)
    # Previously fell through returning None despite the documented return.
    return model
|
||||
def _load_state_dict(model: nn.Module,
                     state_dict: Mapping[str, torch.Tensor],
                     strict: bool = True,
                     assign: bool = False) -> None:
    """Load ``state_dict`` into ``model``, optionally with assign semantics.

    With ``assign=True`` the tensors from the checkpoint are adopted
    directly (preserving their dtypes) via ``load_state_dict(assign=True)``;
    on torch builds whose ``load_state_dict`` lacks that keyword, a warning
    is emitted and the regular copying load runs instead.
    """
    if assign:
        try:
            model.load_state_dict(state_dict, strict=strict, assign=True)
        except TypeError:
            warnings.warn(
                "load_state_dict(assign=True) not supported; "
                "falling back to copy load.")
        else:
            # Assign-mode load succeeded; nothing left to do.
            return
    model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
|
||||
def load_model_checkpoint(model: nn.Module,
                          ckpt: str,
                          assign: bool | None = None) -> nn.Module:
    """Populate ``model`` with weights read from ``ckpt``.

    Three checkpoint layouts are supported: a payload wrapping the weights
    under "state_dict" (with a legacy-key remap fallback), a payload
    wrapping them under "module" with a fixed-length key prefix, and a raw
    state dict.

    Args:
        model (nn.Module): Model instance.
        ckpt (str): Path to the checkpoint file.
        assign (bool | None): Whether to preserve checkpoint tensor dtypes
            via load_state_dict(assign=True). If None, auto-enable when the
            payload carries "precision_metadata" (i.e. it was written as a
            casted checkpoint).

    Returns:
        nn.Module: Model with loaded weights.
    """
    payload = torch.load(ckpt, map_location="cpu")
    if assign is None:
        use_assign = isinstance(payload,
                                Mapping) and "precision_metadata" in payload
    else:
        use_assign = assign
    is_mapping = isinstance(payload, Mapping)
    if is_mapping and "state_dict" in payload:
        weights = payload["state_dict"]
        try:
            _load_state_dict(model, weights, strict=True, assign=use_assign)
        except Exception:
            # Older checkpoints name the fps embedding "framestride_embed";
            # remap those keys and retry a strict load.
            remapped = OrderedDict(weights.items())
            for old_key in list(remapped):
                if "framestride_embed" in old_key:
                    fixed = old_key.replace("framestride_embed",
                                            "fps_embedding")
                    remapped[fixed] = remapped.pop(old_key)
            _load_state_dict(model, remapped, strict=True, assign=use_assign)
    elif is_mapping and "module" in payload:
        # Keys are sliced at 16 chars — presumably a DeepSpeed engine
        # prefix; confirm against the training export.
        stripped = OrderedDict(
            (key[16:], tensor) for key, tensor in payload['module'].items())
        _load_state_dict(model, stripped, strict=True, assign=use_assign)
    else:
        # Raw state dict saved without any wrapper.
        _load_state_dict(model, payload, strict=True, assign=use_assign)
    print('>>> model checkpoint loaded.')
    return model
|
||||
|
||||
|
||||
def maybe_cast_module(module: nn.Module | None,
                      dtype: torch.dtype,
                      label: str,
                      profiler: Optional[ProfilerManager] = None,
                      profile_name: Optional[str] = None) -> None:
    """Cast a submodule's parameters to ``dtype``, skipping no-op cases.

    Args:
        module: Submodule to cast; ``None`` is a silent no-op.
        dtype: Target parameter dtype.
        label: Human-readable name used in progress prints.
        profiler: Optional profiler manager; when given together with
            ``profile_name``, the cast runs inside its profiling section.
        profile_name: Section name passed to ``profiler.profile_section``.
    """
    if module is None:
        return
    first = next(module.parameters(), None)
    if first is None:
        # Parameter-free modules have nothing to cast.
        print(f">>> {label} has no parameters; skip cast")
        return
    if first.dtype == dtype:
        print(f">>> {label} already {dtype}; skip cast")
        return
    if profiler is not None and profile_name:
        section = profiler.profile_section(profile_name)
    else:
        section = nullcontext()
    with section:
        module.to(dtype=dtype)
    print(f">>> {label} cast to {dtype}")
|
||||
|
||||
|
||||
def save_casted_checkpoint(model: nn.Module,
                           save_path: str,
                           metadata: Optional[Dict[str, Any]] = None) -> None:
    """Serialize the model's current (possibly mixed-dtype) weights to disk.

    Tensors are detached and moved to CPU first so the file is
    device-independent. A truthy ``metadata`` dict is stored under the
    "precision_metadata" key of the payload.

    Args:
        model: Model whose state_dict to export.
        save_path: Destination file path; falsy values make this a no-op.
        metadata: Optional precision settings to embed in the payload.
    """
    if not save_path:
        return
    parent = os.path.dirname(save_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    cpu_state = {
        key: value.detach().to("cpu")
        if isinstance(value, torch.Tensor) else value
        for key, value in model.state_dict().items()
    }
    payload: Dict[str, Any] = {"state_dict": cpu_state}
    if metadata:
        payload["precision_metadata"] = metadata
    torch.save(payload, save_path)
    print(f">>> Saved casted checkpoint to {save_path}")
|
||||
|
||||
|
||||
def _module_param_dtype(module: nn.Module | None) -> str:
    """Summarize the parameter dtype(s) of ``module`` as a short string.

    The pasted block interleaved removed old-version lines (a loop that
    returned the first parameter's dtype, plus an early "no_params" return)
    with the added accumulation code, leaving the latter unreachable; this
    is the reconstructed added version.

    Returns:
        str: "None" when ``module`` is None; "no_params" when it has no
        parameters; the single dtype string when all parameters share one
        dtype; otherwise "mixed(dtype=ratio, ...)" with per-dtype element
        ratios sorted by dtype name.
    """
    if module is None:
        return "None"
    # Element counts per dtype string, e.g. {"torch.float32": 1234}.
    dtype_counts: Dict[str, int] = {}
    for param in module.parameters():
        dtype_key = str(param.dtype)
        dtype_counts[dtype_key] = dtype_counts.get(dtype_key, 0) + param.numel()
    if not dtype_counts:
        return "no_params"
    if len(dtype_counts) == 1:
        return next(iter(dtype_counts))
    total = sum(dtype_counts.values())
    parts = []
    for dtype_key in sorted(dtype_counts.keys()):
        ratio = dtype_counts[dtype_key] / total
        parts.append(f"{dtype_key}={ratio:.1%}")
    return f"mixed({', '.join(parts)})"
|
||||
|
||||
|
||||
def log_inference_precision(model: nn.Module) -> None:
|
||||
try:
|
||||
param = next(model.parameters())
|
||||
device = "unknown"
|
||||
for param in model.parameters():
|
||||
device = str(param.device)
|
||||
model_dtype = str(param.dtype)
|
||||
except StopIteration:
|
||||
device = "unknown"
|
||||
model_dtype = "no_params"
|
||||
break
|
||||
model_dtype = _module_param_dtype(model)
|
||||
|
||||
print(f">>> inference precision: model={model_dtype}, device={device}")
|
||||
for attr in [
|
||||
@@ -966,16 +1052,25 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
|
||||
diffusion_autocast_dtype = None
|
||||
if args.diffusion_dtype == "bf16":
|
||||
with profiler.profile_section("model_loading/diffusion_bf16"):
|
||||
model.model.to(dtype=torch.bfloat16)
|
||||
maybe_cast_module(
|
||||
model.model,
|
||||
torch.bfloat16,
|
||||
"diffusion backbone",
|
||||
profiler=profiler,
|
||||
profile_name="model_loading/diffusion_bf16",
|
||||
)
|
||||
diffusion_autocast_dtype = torch.bfloat16
|
||||
print(">>> diffusion backbone set to bfloat16")
|
||||
|
||||
if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
|
||||
if args.vae_dtype == "bf16":
|
||||
model.first_stage_model.to(dtype=torch.bfloat16)
|
||||
else:
|
||||
model.first_stage_model.to(dtype=torch.float32)
|
||||
vae_weight_dtype = torch.bfloat16 if args.vae_dtype == "bf16" else torch.float32
|
||||
maybe_cast_module(
|
||||
model.first_stage_model,
|
||||
vae_weight_dtype,
|
||||
"VAE",
|
||||
profiler=profiler,
|
||||
profile_name="model_loading/vae_cast",
|
||||
)
|
||||
model.vae_bf16 = args.vae_dtype == "bf16"
|
||||
print(f">>> VAE dtype set to {args.vae_dtype}")
|
||||
|
||||
@@ -983,9 +1078,21 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
encoder_bf16 = encoder_mode in ("autocast", "bf16_full")
|
||||
encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32
|
||||
if hasattr(model, "cond_stage_model") and model.cond_stage_model is not None:
|
||||
model.cond_stage_model.to(dtype=encoder_weight_dtype)
|
||||
maybe_cast_module(
|
||||
model.cond_stage_model,
|
||||
encoder_weight_dtype,
|
||||
"cond_stage_model",
|
||||
profiler=profiler,
|
||||
profile_name="model_loading/encoder_cond_cast",
|
||||
)
|
||||
if hasattr(model, "embedder") and model.embedder is not None:
|
||||
model.embedder.to(dtype=encoder_weight_dtype)
|
||||
maybe_cast_module(
|
||||
model.embedder,
|
||||
encoder_weight_dtype,
|
||||
"embedder",
|
||||
profiler=profiler,
|
||||
profile_name="model_loading/encoder_embedder_cast",
|
||||
)
|
||||
model.encoder_bf16 = encoder_bf16
|
||||
model.encoder_mode = encoder_mode
|
||||
print(
|
||||
@@ -996,11 +1103,29 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
projector_bf16 = projector_mode in ("autocast", "bf16_full")
|
||||
projector_weight_dtype = torch.bfloat16 if projector_mode == "bf16_full" else torch.float32
|
||||
if hasattr(model, "image_proj_model") and model.image_proj_model is not None:
|
||||
model.image_proj_model.to(dtype=projector_weight_dtype)
|
||||
maybe_cast_module(
|
||||
model.image_proj_model,
|
||||
projector_weight_dtype,
|
||||
"image_proj_model",
|
||||
profiler=profiler,
|
||||
profile_name="model_loading/projector_image_cast",
|
||||
)
|
||||
if hasattr(model, "state_projector") and model.state_projector is not None:
|
||||
model.state_projector.to(dtype=projector_weight_dtype)
|
||||
maybe_cast_module(
|
||||
model.state_projector,
|
||||
projector_weight_dtype,
|
||||
"state_projector",
|
||||
profiler=profiler,
|
||||
profile_name="model_loading/projector_state_cast",
|
||||
)
|
||||
if hasattr(model, "action_projector") and model.action_projector is not None:
|
||||
model.action_projector.to(dtype=projector_weight_dtype)
|
||||
maybe_cast_module(
|
||||
model.action_projector,
|
||||
projector_weight_dtype,
|
||||
"action_projector",
|
||||
profiler=profiler,
|
||||
profile_name="model_loading/projector_action_cast",
|
||||
)
|
||||
if hasattr(model, "projector_bf16"):
|
||||
model.projector_bf16 = projector_bf16
|
||||
model.projector_mode = projector_mode
|
||||
@@ -1010,6 +1135,19 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
|
||||
log_inference_precision(model)
|
||||
|
||||
if args.export_casted_ckpt:
|
||||
metadata = {
|
||||
"diffusion_dtype": args.diffusion_dtype,
|
||||
"vae_dtype": args.vae_dtype,
|
||||
"encoder_mode": args.encoder_mode,
|
||||
"projector_mode": args.projector_mode,
|
||||
"perframe_ae": args.perframe_ae,
|
||||
}
|
||||
save_casted_checkpoint(model, args.export_casted_ckpt, metadata)
|
||||
if args.export_only:
|
||||
print(">>> export_only set; skipping inference.")
|
||||
return
|
||||
|
||||
profiler.record_memory("after_model_load")
|
||||
|
||||
# Run over data
|
||||
@@ -1373,6 +1511,19 @@ def get_parser():
|
||||
default="fp32",
|
||||
help="Dtype for VAE/first_stage_model weights and forward autocast."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--export_casted_ckpt",
|
||||
type=str,
|
||||
default=None,
|
||||
help=
|
||||
"Save a checkpoint after applying precision settings (mixed dtypes preserved)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--export_only",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Exit after exporting the casted checkpoint."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--step_log_every",
|
||||
type=int,
|
||||
|
||||
Reference in New Issue
Block a user