Lightweight projector/MLP BF16 fine-tuning

Adjusted some parameter defaults
This commit is contained in:
2026-01-18 18:38:47 +08:00
parent 3c0f409fcf
commit fde3c7445d
3 changed files with 21 additions and 1 deletion


@@ -912,6 +912,10 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
     diffusion_autocast_dtype = torch.bfloat16
     print(">>> diffusion backbone set to bfloat16")
+    if hasattr(model, "projector_bf16"):
+        model.projector_bf16 = args.projector_dtype == "bf16"
+        print(f">>> projector dtype set to {args.projector_dtype}")
     log_inference_precision(model)
     profiler.record_memory("after_model_load")
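The hunk above only flips a boolean on the model, so the projector itself must honor it in its forward pass. A minimal sketch of that wiring, assuming a hypothetical `Projector` module (the real class name, sizes, and call sites differ):

```python
import torch
from torch import nn


class Projector(nn.Module):
    """Hypothetical stand-in for the image/state/action projector MLP."""

    def __init__(self, d_in: int = 512, d_out: int = 1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, d_out),
            nn.GELU(),
            nn.Linear(d_out, d_out),
        )
        # Toggled from the --projector_dtype CLI flag at load time.
        self.projector_bf16 = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.projector_bf16:
            # Run the whole projector under bfloat16 autocast so its output
            # distribution matches what the BF16 backbone consumes.
            with torch.autocast(device_type=x.device.type, dtype=torch.bfloat16):
                return self.net(x)
        return self.net(x)
```

With the flag off, the module computes in FP32 as before; with it on, autocast casts the linear layers (and their output) to bfloat16 without touching the stored FP32 weights.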
@@ -1255,6 +1259,13 @@ def get_parser():
         default="fp32",
         help="Dtype for diffusion backbone weights and sampling autocast."
     )
+    parser.add_argument(
+        "--projector_dtype",
+        type=str,
+        choices=["fp32", "bf16"],
+        default="fp32",
+        help="Dtype for image/state/action projectors (autocast in forward)."
+    )
     parser.add_argument(
         "--n_action_steps",
         type=int,


@@ -21,5 +21,6 @@ dataset="unitree_g1_pack_camera"
     --timestep_spacing 'uniform_trailing' \
     --guidance_rescale 0.7 \
     --perframe_ae \
-    --diffusion_dtype bf16
+    --diffusion_dtype bf16 \
+    --projector_dtype bf16
 } 2>&1 | tee "${res_dir}/output.log"


@@ -71,3 +71,11 @@ python3 psnr_score_for_challenge.py --gt_video unitree_g1_pack_camera/case1/unit
 - Prefer BF16: it is more stable than FP16
 - Half precision only, no INT quantization: preserves PSNR
 - Keep the VAE in FP32: it is the module that most affects image quality
+Possible reasons a BF16 projector is more accurate than an FP32 projector:
+- More consistent numeric path: the backbone runs attention/MLP in BF16. An FP32 projector's output gets downcast before entering the backbone, a "high-precision nonlinearity, then truncation" pattern that shifts the distribution; computing the projector directly in BF16 keeps its output distribution closer to the backbone's numerics.
+- Matches the training distribution: training ran with precision:16, so the projector was optimized in a low-precision environment the whole time; running it in FP32 at inference can deviate from the training-time statistics.
+- LayerNorm/Softmax sensitivity: the LN/Softmax inside the Resampler/MLP are precision sensitive. Computing in FP32 and then dropping precision "hard-truncates" values at numeric boundaries, while computing in BF16 throughout may be smoother.
+This also explains why the BF16 projector appeared more accurate.
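The truncation argument can be made concrete without any framework: bfloat16 is float32 with the low 16 mantissa bits dropped, so a pure-Python sketch (an illustration, not the model's actual code path) can contrast "FP32 softmax, then one hard cast" against "BF16-resolution intermediates throughout":

```python
import math
import struct


def bf16_trunc(x: float) -> float:
    """Round a float32 value to bfloat16 precision by dropping the low 16 bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def softmax_fp32_then_cast(xs):
    # High-precision nonlinearity followed by a single truncation at the end:
    # the pattern produced by an FP32 projector feeding a BF16 backbone.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = math.fsum(exps)
    return [bf16_trunc(e / s) for e in exps]


def softmax_bf16_throughout(xs):
    # Every intermediate held at bfloat16 resolution, as when the projector
    # itself computes in BF16.
    xs = [bf16_trunc(x) for x in xs]
    m = max(xs)
    exps = [bf16_trunc(math.exp(bf16_trunc(x - m))) for x in xs]
    s = bf16_trunc(math.fsum(exps))
    return [bf16_trunc(e / s) for e in exps]
```

Because bfloat16 keeps only 7 mantissa bits, values like `1 + 2**-10` collapse to `1.0` under `bf16_trunc`; the two softmax variants land on nearby but not identical grid points, which is the distribution shift the notes describe.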