embedder权重改成bf16

似乎因为权重的处理更慢了,整体速度反而变慢了一点点
This commit is contained in:
2026-01-18 19:03:21 +08:00
parent fde3c7445d
commit 44379f3e31
5 changed files with 28 additions and 3 deletions

View File

@@ -772,7 +772,11 @@ def image_guided_synthesis_sim_mode(
with profiler.profile_section("synthesis/conditioning_prep"):
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
cond_img_emb = model.embedder(cond_img)
embedder_ctx = nullcontext()
if getattr(model, "encoder_bf16", False) and model.device.type == "cuda":
embedder_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
with embedder_ctx:
cond_img_emb = model.embedder(cond_img)
if model.model.conditioning_key == 'hybrid':
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
@@ -912,6 +916,16 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
diffusion_autocast_dtype = torch.bfloat16
print(">>> diffusion backbone set to bfloat16")
encoder_dtype = torch.float32
if args.encoder_dtype == "bf16":
encoder_dtype = torch.bfloat16
if hasattr(model, "cond_stage_model") and model.cond_stage_model is not None:
model.cond_stage_model.to(dtype=encoder_dtype)
if hasattr(model, "embedder") and model.embedder is not None:
model.embedder.to(dtype=encoder_dtype)
model.encoder_bf16 = args.encoder_dtype == "bf16"
print(f">>> encoder dtype set to {args.encoder_dtype}")
if hasattr(model, "projector_bf16"):
model.projector_bf16 = args.projector_dtype == "bf16"
print(f">>> projector dtype set to {args.projector_dtype}")
@@ -1266,6 +1280,13 @@ def get_parser():
default="fp32",
help="Dtype for image/state/action projectors (autocast in forward)."
)
parser.add_argument(
"--encoder_dtype",
type=str,
choices=["fp32", "bf16"],
default="fp32",
help="Dtype for text/image encoders (cond_stage_model/embedder)."
)
parser.add_argument(
"--n_action_steps",
type=int,