轻量投影/MLP BF16 微调
调整了一些参数默认值
This commit is contained in:
@@ -912,6 +912,10 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
diffusion_autocast_dtype = torch.bfloat16
|
||||
print(">>> diffusion backbone set to bfloat16")
|
||||
|
||||
if hasattr(model, "projector_bf16"):
|
||||
model.projector_bf16 = args.projector_dtype == "bf16"
|
||||
print(f">>> projector dtype set to {args.projector_dtype}")
|
||||
|
||||
log_inference_precision(model)
|
||||
|
||||
profiler.record_memory("after_model_load")
|
||||
@@ -1255,6 +1259,13 @@ def get_parser():
|
||||
default="fp32",
|
||||
help="Dtype for diffusion backbone weights and sampling autocast."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--projector_dtype",
|
||||
type=str,
|
||||
choices=["fp32", "bf16"],
|
||||
default="fp32",
|
||||
help="Dtype for image/state/action projectors (autocast in forward)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_action_steps",
|
||||
type=int,
|
||||
|
||||
Reference in New Issue
Block a user