轻量投影/MLP BF16 微调

调整了一些参数默认值
This commit is contained in:
2026-01-18 18:38:47 +08:00
parent 3c0f409fcf
commit fde3c7445d
3 changed files with 21 additions and 1 deletion

View File

@@ -912,6 +912,10 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Force the diffusion backbone to run in bfloat16 for both weights and
# the sampling autocast context (consumed later by run_inference).
diffusion_autocast_dtype = torch.bfloat16
print(">>> diffusion backbone set to bfloat16")
# NOTE(review): this is a diff rendering with leading indentation stripped;
# in the real source the assignment below is presumably the body of the
# hasattr guard — confirm against the original file.
# Opt the projectors into bf16 only when the model exposes the flag and the
# user asked for it via --projector_dtype bf16.
if hasattr(model, "projector_bf16"):
model.projector_bf16 = args.projector_dtype == "bf16"
print(f">>> projector dtype set to {args.projector_dtype}")
# Log the effective precision configuration and snapshot memory usage
# right after model load for later profiling comparison.
log_inference_precision(model)
profiler.record_memory("after_model_load")
@@ -1255,6 +1259,13 @@ def get_parser():
default="fp32",
help="Dtype for diffusion backbone weights and sampling autocast."
)
# CLI switch for the numeric precision of the image/state/action projector
# modules. The inference path maps the "bf16" choice onto a boolean
# model.projector_bf16 attribute (set in run_inference in this same commit);
# default "fp32" preserves the pre-commit behavior.
parser.add_argument(
"--projector_dtype",
type=str,
choices=["fp32", "bf16"],
default="fp32",
help="Dtype for image/state/action projectors (autocast in forward)."
)
parser.add_argument(
"--n_action_steps",
type=int,