轻量投影/MLP BF16 微调
调整了一些参数默认值
This commit is contained in:
@@ -912,6 +912,10 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
diffusion_autocast_dtype = torch.bfloat16
|
diffusion_autocast_dtype = torch.bfloat16
|
||||||
print(">>> diffusion backbone set to bfloat16")
|
print(">>> diffusion backbone set to bfloat16")
|
||||||
|
|
||||||
|
if hasattr(model, "projector_bf16"):
|
||||||
|
model.projector_bf16 = args.projector_dtype == "bf16"
|
||||||
|
print(f">>> projector dtype set to {args.projector_dtype}")
|
||||||
|
|
||||||
log_inference_precision(model)
|
log_inference_precision(model)
|
||||||
|
|
||||||
profiler.record_memory("after_model_load")
|
profiler.record_memory("after_model_load")
|
||||||
@@ -1255,6 +1259,13 @@ def get_parser():
|
|||||||
default="fp32",
|
default="fp32",
|
||||||
help="Dtype for diffusion backbone weights and sampling autocast."
|
help="Dtype for diffusion backbone weights and sampling autocast."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--projector_dtype",
|
||||||
|
type=str,
|
||||||
|
choices=["fp32", "bf16"],
|
||||||
|
default="fp32",
|
||||||
|
help="Dtype for image/state/action projectors (autocast in forward)."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--n_action_steps",
|
"--n_action_steps",
|
||||||
type=int,
|
type=int,
|
||||||
|
|||||||
@@ -21,5 +21,6 @@ dataset="unitree_g1_pack_camera"
|
|||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae \
|
--perframe_ae \
|
||||||
--diffusion_dtype bf16
|
--diffusion_dtype bf16 \
|
||||||
|
--projector_dtype bf16
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
@@ -71,3 +71,11 @@ python3 psnr_score_for_challenge.py --gt_video unitree_g1_pack_camera/case1/unit
|
|||||||
- 优先 BF16:稳定性好于 FP16
|
- 优先 BF16:稳定性好于 FP16
|
||||||
- 只做半精度,不做 INT 量化:保持 PSNR
|
- 只做半精度,不做 INT 量化:保持 PSNR
|
||||||
- VAE 尽量 FP32:最影响画质的模块
|
- VAE 尽量 FP32:最影响画质的模块
|
||||||
|
|
||||||
|
BF16 projector比FP32 projector更准的可能原因:
|
||||||
|
- 数值路径更一致:主干在 BF16 下做 attention/MLP,projector 若是 FP32 会在进入主干前被 downcast,导致“先高精度非线性→再截断”的分布偏移;直接用 BF16 计算 projector,反而让输出分布更贴近主干的计算习惯。
|
||||||
|
- 训练分布匹配:训练时使用的是 precision:16,projector 长期在低精度环境下被优化;推理时改用 FP32 反而可能偏离训练时的统计特性。
|
||||||
|
- LayerNorm/Softmax 敏感:Resampler/MLP 里 LN/Softmax 对精度很敏感,FP32 计算后再降精度,数值边界更容易“硬截断”;BF16 全程计算可能更平滑。
|
||||||
|
|
||||||
|
这也解释了为什么你看到 BF16 projector 反而更准。
|
||||||
Reference in New Issue
Block a user