权重改为fp32, 其他部分保持bf16
目前参数--encoder_mode有三种选择: 1. fp32: 全部使用fp32, 精度最高, 适合显存充足的情况 2. autocast: 权重保持fp32, 前向使用torch.cuda.amp.autocast自动混合精度, 速度稍快但PSNR下降较多 3. bf16_full: 权重与前向全部使用bf16, 精度较高
This commit is contained in:
@@ -772,10 +772,34 @@ def image_guided_synthesis_sim_mode(
|
||||
with profiler.profile_section("synthesis/conditioning_prep"):
|
||||
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
||||
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
|
||||
embedder_ctx = nullcontext()
|
||||
if getattr(model, "encoder_bf16", False) and model.device.type == "cuda":
|
||||
embedder_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
|
||||
with embedder_ctx:
|
||||
if getattr(model, "encoder_mode", "autocast") == "autocast":
|
||||
preprocess_ctx = torch.autocast("cuda", enabled=False)
|
||||
with preprocess_ctx:
|
||||
cond_img_fp32 = cond_img.float()
|
||||
if hasattr(model.embedder, "preprocess"):
|
||||
preprocessed = model.embedder.preprocess(cond_img_fp32)
|
||||
else:
|
||||
preprocessed = cond_img_fp32
|
||||
|
||||
if hasattr(model.embedder,
|
||||
"encode_with_vision_transformer") and hasattr(
|
||||
model.embedder, "preprocess"):
|
||||
original_preprocess = model.embedder.preprocess
|
||||
try:
|
||||
model.embedder.preprocess = lambda x: x
|
||||
with torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
cond_img_emb = model.embedder.encode_with_vision_transformer(
|
||||
preprocessed)
|
||||
finally:
|
||||
model.embedder.preprocess = original_preprocess
|
||||
else:
|
||||
with torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
cond_img_emb = model.embedder(preprocessed)
|
||||
else:
|
||||
with torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
cond_img_emb = model.embedder(cond_img)
|
||||
else:
|
||||
cond_img_emb = model.embedder(cond_img)
|
||||
|
||||
if model.model.conditioning_key == 'hybrid':
|
||||
@@ -788,7 +812,11 @@ def image_guided_synthesis_sim_mode(
|
||||
|
||||
if not text_input:
|
||||
prompts = [""] * batch_size
|
||||
cond_ins_emb = model.get_learned_conditioning(prompts)
|
||||
encoder_ctx = nullcontext()
|
||||
if getattr(model, "encoder_bf16", False) and model.device.type == "cuda":
|
||||
encoder_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
|
||||
with encoder_ctx:
|
||||
cond_ins_emb = model.get_learned_conditioning(prompts)
|
||||
target_dtype = cond_ins_emb.dtype
|
||||
|
||||
cond_img_emb = model._projector_forward(model.image_proj_model,
|
||||
@@ -916,15 +944,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
diffusion_autocast_dtype = torch.bfloat16
|
||||
print(">>> diffusion backbone set to bfloat16")
|
||||
|
||||
encoder_dtype = torch.float32
|
||||
if args.encoder_dtype == "bf16":
|
||||
encoder_dtype = torch.bfloat16
|
||||
encoder_mode = args.encoder_mode
|
||||
encoder_bf16 = encoder_mode in ("autocast", "bf16_full")
|
||||
encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32
|
||||
if hasattr(model, "cond_stage_model") and model.cond_stage_model is not None:
|
||||
model.cond_stage_model.to(dtype=encoder_dtype)
|
||||
model.cond_stage_model.to(dtype=encoder_weight_dtype)
|
||||
if hasattr(model, "embedder") and model.embedder is not None:
|
||||
model.embedder.to(dtype=encoder_dtype)
|
||||
model.encoder_bf16 = args.encoder_dtype == "bf16"
|
||||
print(f">>> encoder dtype set to {args.encoder_dtype}")
|
||||
model.embedder.to(dtype=encoder_weight_dtype)
|
||||
model.encoder_bf16 = encoder_bf16
|
||||
model.encoder_mode = encoder_mode
|
||||
print(
|
||||
f">>> encoder mode set to {encoder_mode} (weights={encoder_weight_dtype})"
|
||||
)
|
||||
|
||||
if hasattr(model, "projector_bf16"):
|
||||
model.projector_bf16 = args.projector_dtype == "bf16"
|
||||
@@ -1281,11 +1312,14 @@ def get_parser():
|
||||
help="Dtype for image/state/action projectors (autocast in forward)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder_dtype",
|
||||
"--encoder_mode",
|
||||
type=str,
|
||||
choices=["fp32", "bf16"],
|
||||
choices=["fp32", "autocast", "bf16_full"],
|
||||
default="fp32",
|
||||
help="Dtype for text/image encoders (cond_stage_model/embedder)."
|
||||
help=
|
||||
"Encoder precision mode for cond_stage_model/embedder: "
|
||||
"fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, "
|
||||
"bf16_full=bf16 weights + bf16 forward."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_action_steps",
|
||||
|
||||
Binary file not shown.
@@ -23,5 +23,5 @@ dataset="unitree_g1_pack_camera"
|
||||
--perframe_ae \
|
||||
--diffusion_dtype bf16 \
|
||||
--projector_dtype bf16 \
|
||||
--encoder_dtype bf16
|
||||
--encoder_mode autocast #fp32/autocast/bf16_full
|
||||
} 2>&1 | tee "${res_dir}/output.log"
|
||||
|
||||
25
useful.sh
25
useful.sh
@@ -82,3 +82,28 @@ BF16 projector比FP32 projector更准的可能原因:
|
||||
|
||||
embedder:
|
||||
改成 autocast only(权重 FP32,预处理 FP32,仅主干 BF16)
|
||||
|
||||
- 效果差的那次:encoder autocast set to bf16 (weights=fp32)
|
||||
也就是“权重 FP32 + autocast BF16(我现在的 autocast only 方案)”
|
||||
- 效果好的那次:cond_stage_model / embedder param dtype: torch.bfloat16
|
||||
也就是“权重 BF16(全 BF16)”
|
||||
|
||||
为什么 FP32 反而掉分大？主要是“混合精度路径不一致”：
|
||||
|
||||
1. autocast only 不等价于 full BF16
|
||||
autocast 会把部分算子留在 FP32(比如 LayerNorm / Softmax),另一部分用 BF16;
|
||||
这会让 encoder 输出分布和“全 BF16”不一样,尤其当扩散主干也是 BF16 时,条件向量分布一偏就会影响 PSNR。
|
||||
2. 先 FP32 计算再转 BF16
|
||||
encoder 的输出最后还是要拼到 BF16 主干里,等于“先高精度 → 再截断”。
|
||||
训练时如果你用的是 BF16,模型更适应“低精度直接计算”的分布。
|
||||
3. 你的现象和 projector 一样
|
||||
你已经观察到 projector BF16 更准,这说明模型很可能“更偏好一致的 BF16 路径”。
|
||||
|
||||
结论:
|
||||
对你这个模型,full BF16(权重 BF16)更匹配训练分布,反而 PSNR 更高。
|
||||
|
||||
如果你要我改回“bf16=全 BF16(权重 BF16)”,并且保留现在的 autocast only 作为可选模式,我可以这样做:
|
||||
|
||||
1. 新增 --encoder_mode {fp32, autocast, bf16_full}
|
||||
2. bf16_full = 权重 BF16 + 前向 BF16
|
||||
3. autocast = 权重 FP32 + 仅主干 autocast(现在的实现)
|
||||
Reference in New Issue
Block a user