权重改为fp32, 其他部分保持bf16

目前参数--encoder_mode有三种选择:
1. fp32: 全部使用fp32, 适合显存充足的情况
2. autocast: 使用torch.cuda.amp.autocast自动混合精度, 稍微快一些, psnr下降较多
3. bf16_full: 全部使用bf16, 精度较高
This commit is contained in:
2026-01-18 20:24:37 +08:00
parent 44379f3e31
commit e1b029201e
4 changed files with 75 additions and 16 deletions

View File

@@ -772,10 +772,34 @@ def image_guided_synthesis_sim_mode(
with profiler.profile_section("synthesis/conditioning_prep"):
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
embedder_ctx = nullcontext()
if getattr(model, "encoder_bf16", False) and model.device.type == "cuda":
embedder_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
with embedder_ctx:
if getattr(model, "encoder_mode", "autocast") == "autocast":
preprocess_ctx = torch.autocast("cuda", enabled=False)
with preprocess_ctx:
cond_img_fp32 = cond_img.float()
if hasattr(model.embedder, "preprocess"):
preprocessed = model.embedder.preprocess(cond_img_fp32)
else:
preprocessed = cond_img_fp32
if hasattr(model.embedder,
"encode_with_vision_transformer") and hasattr(
model.embedder, "preprocess"):
original_preprocess = model.embedder.preprocess
try:
model.embedder.preprocess = lambda x: x
with torch.autocast("cuda", dtype=torch.bfloat16):
cond_img_emb = model.embedder.encode_with_vision_transformer(
preprocessed)
finally:
model.embedder.preprocess = original_preprocess
else:
with torch.autocast("cuda", dtype=torch.bfloat16):
cond_img_emb = model.embedder(preprocessed)
else:
with torch.autocast("cuda", dtype=torch.bfloat16):
cond_img_emb = model.embedder(cond_img)
else:
cond_img_emb = model.embedder(cond_img)
if model.model.conditioning_key == 'hybrid':
@@ -788,7 +812,11 @@ def image_guided_synthesis_sim_mode(
if not text_input:
prompts = [""] * batch_size
cond_ins_emb = model.get_learned_conditioning(prompts)
encoder_ctx = nullcontext()
if getattr(model, "encoder_bf16", False) and model.device.type == "cuda":
encoder_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
with encoder_ctx:
cond_ins_emb = model.get_learned_conditioning(prompts)
target_dtype = cond_ins_emb.dtype
cond_img_emb = model._projector_forward(model.image_proj_model,
@@ -916,15 +944,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
diffusion_autocast_dtype = torch.bfloat16
print(">>> diffusion backbone set to bfloat16")
encoder_dtype = torch.float32
if args.encoder_dtype == "bf16":
encoder_dtype = torch.bfloat16
encoder_mode = args.encoder_mode
encoder_bf16 = encoder_mode in ("autocast", "bf16_full")
encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32
if hasattr(model, "cond_stage_model") and model.cond_stage_model is not None:
model.cond_stage_model.to(dtype=encoder_dtype)
model.cond_stage_model.to(dtype=encoder_weight_dtype)
if hasattr(model, "embedder") and model.embedder is not None:
model.embedder.to(dtype=encoder_dtype)
model.encoder_bf16 = args.encoder_dtype == "bf16"
print(f">>> encoder dtype set to {args.encoder_dtype}")
model.embedder.to(dtype=encoder_weight_dtype)
model.encoder_bf16 = encoder_bf16
model.encoder_mode = encoder_mode
print(
f">>> encoder mode set to {encoder_mode} (weights={encoder_weight_dtype})"
)
if hasattr(model, "projector_bf16"):
model.projector_bf16 = args.projector_dtype == "bf16"
@@ -1281,11 +1312,14 @@ def get_parser():
help="Dtype for image/state/action projectors (autocast in forward)."
)
parser.add_argument(
"--encoder_dtype",
"--encoder_mode",
type=str,
choices=["fp32", "bf16"],
choices=["fp32", "autocast", "bf16_full"],
default="fp32",
help="Dtype for text/image encoders (cond_stage_model/embedder)."
help=
"Encoder precision mode for cond_stage_model/embedder: "
"fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, "
"bf16_full=bf16 weights + bf16 forward."
)
parser.add_argument(
"--n_action_steps",