Export the mixed-precision model weights to a local file to reduce dtype-cast overhead

--export_casted_ckpt ckpts/unifolm_wma_dual_mixbf16.ckpt \
        --export_only
2026-01-19 15:14:01 +08:00
parent cb334f308b
commit 7e501b17fd
20 changed files with 245 additions and 55 deletions


@@ -118,4 +118,44 @@ embedder
2. Attention mask cached on the GPU — done
- _get_attn_mask_aa now constructs and caches the mask directly on the target device, avoiding a CPU→GPU copy on every step
- Location: src/unifolm_wma/modules/attention.py
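The caching pattern described above can be sketched as follows. This is a minimal illustration, not the repo's actual `_get_attn_mask_aa`; the cache key, causal-mask shape, and function name `get_attn_mask` are assumptions:

```python
import torch

# Cache keyed by (sequence length, device) so the mask is built once
# per device and reused, instead of being created on CPU and copied
# to the GPU on every step.
_ATTN_MASK_CACHE: dict = {}

def get_attn_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    key = (seq_len, str(device))
    mask = _ATTN_MASK_CACHE.get(key)
    if mask is None:
        # Construct directly on `device`; no CPU->GPU transfer afterwards.
        mask = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
        _ATTN_MASK_CACHE[key] = mask
    return mask
```

The second call with the same shape and device returns the cached tensor, which is where the per-step copy is saved.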
--export_casted_ckpt ckpts/unifolm_wma_dual_mixbf16.ckpt \
--export_only
The stored dtypes carry over from a casted checkpoint or previous run; the diffusion backbone and
condition embedder are bfloat16 while the projector is fp32, so the casts are skipped as expected.
The "inference precision" line reflects the top-level param dtype (possibly fp32 embeddings), which
explains the apparent inconsistency with the submodule param dtypes.
Explaining the skipped cast due to preserved dtypes
The cast is skipped because checkpoint loading preserves the stored dtypes, so the weights already
match their targets (bf16 or fp32) and casting is unnecessary. The log prints the first parameter's
dtype for these checks, which can differ from the submodules' dtypes but is expected. To force a
recast, a new casted checkpoint must be exported, since repeated casts to the same dtype are no-ops.
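Exporting a casted checkpoint can be sketched like this. The function name and checkpoint layout are assumptions, not the repo's actual `--export_casted_ckpt` implementation; the idea is simply that saving the state_dict after the in-memory casts persists the mixed bf16/fp32 dtypes, so a later load needs no casting:

```python
import torch

def export_casted_ckpt(model: torch.nn.Module, out) -> None:
    # Save the state_dict as-is: each tensor keeps whatever dtype the
    # in-memory casts left it with (bf16 backbone, fp32 projector, ...).
    state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    torch.save({"state_dict": state}, out)
```

Reloading this file restores the mixed dtypes directly, which is exactly the "skip cast" situation seen in the log.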
• "skip cast" means that module's weights are already in the target dtype, so maybe_cast_module
skips the conversion entirely. In your log:
- the diffusion backbone / cond_stage_model / embedder are already torch.bfloat16
- the projector is still torch.float32 (you set projector_mode=autocast, so its weights stay fp32)
- the VAE, which you set to fp32, is already fp32
Everything already matches its target precision, hence "skip cast". This usually means the ckpt you
loaded already contains mixed-precision weights (e.g. one previously exported with --export_casted_ckpt).
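The skip logic can be illustrated with a small sketch. This is an assumption about the general pattern, not the repo's actual maybe_cast_module:

```python
import torch

def maybe_cast_module(module: torch.nn.Module,
                      target_dtype: torch.dtype) -> bool:
    """Cast `module` to `target_dtype` unless its parameters already
    match, in which case the cast is skipped. Returns True iff a cast
    was actually performed."""
    params = list(module.parameters())
    if params and all(p.dtype == target_dtype for p in params):
        return False  # "skip cast": weights already have the target dtype
    module.to(target_dtype)
    return True
```

On a checkpoint whose stored dtypes already match the targets, every call returns False, which matches the log output discussed above.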
Also, this line:
>>> inference precision: model=torch.float32
only takes the dtype of the model's first parameter (that is how log_inference_precision is
implemented). The model may still contain fp32 parameters outside the main modules (e.g. positional
embeddings), so an overall fp32 reading is normal and does not mean the diffusion backbone is not in
bf16. The per-submodule printouts that follow are more accurate.
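Why the top-level line can disagree with the submodules is easy to reproduce. The two helpers below are illustrative sketches (not the actual log_inference_precision); a stray fp32 parameter registered on the root module is enough to make the first-parameter dtype fp32 while every child is bf16:

```python
import torch

def first_param_dtype(model: torch.nn.Module) -> torch.dtype:
    # What a "model=<dtype>" log line effectively reports: the dtype of
    # the first parameter yielded by model.parameters().
    return next(model.parameters()).dtype

def submodule_dtypes(model: torch.nn.Module) -> dict:
    # The more accurate per-submodule view: first-param dtype of each
    # direct child that has parameters.
    return {name: next(child.parameters()).dtype
            for name, child in model.named_children()
            if any(True for _ in child.parameters())}
```

If the root holds an fp32 positional-embedding parameter and the backbone child is bf16, `first_param_dtype` reports fp32 while `submodule_dtypes` shows the backbone as bf16.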
If you want to confirm a fresh cast from the original fp32 ckpt, just point --ckpt_path back at the original ckpt and run once.