Optimize inference path: add predictor-only torch.compile with reduce-overhead

This commit is contained in:
qihuanye
2026-04-09 10:00:13 +00:00
parent f2750daace
commit 38be7d3bef

70
eval.py
View File

@@ -67,6 +67,63 @@ def get_profile_cfg(cfg):
return profile_cfg
def get_compile_cfg(cfg):
    """Return the effective ``torch.compile`` settings for inference.

    A fresh dictionary of built-in defaults is created on every call. If
    ``cfg`` holds a non-``None`` value under the ``compile`` key, that value
    is converted to a plain container via ``OmegaConf.to_container`` (with
    interpolations resolved) and laid over the defaults.

    Args:
        cfg: Configuration object exposing ``.get``; the optional
            ``compile`` entry supplies overrides.

    Returns:
        dict: Settings with at least the keys ``enabled``, ``target``,
        ``mode``, ``fullgraph``, ``dynamic`` and ``cuda_only``.
    """
    settings = dict(
        enabled=True,
        target="predictor",
        mode="reduce-overhead",
        fullgraph=False,
        dynamic=False,
        cuda_only=True,
    )

    overrides = cfg.get("compile")
    if overrides is None:
        # No user-supplied section: the defaults are the final answer.
        return settings

    settings.update(OmegaConf.to_container(overrides, resolve=True))
    return settings
def maybe_compile_inference_target(model, cfg, device):
    """Optionally wrap one attribute of ``model`` with ``torch.compile``.

    The settings come from ``get_compile_cfg(cfg)``. Compilation is skipped
    (and the model is returned untouched) when any of the following holds:

    * ``compile.enabled`` is falsy;
    * the installed ``torch`` has no ``compile`` attribute;
    * ``compile.cuda_only`` is truthy and ``str(device)`` does not start
      with ``"cuda"``;
    * ``compile.target`` is not ``"predictor"`` or ``"predict"``
      (compared case-insensitively);
    * the requested attribute is missing on ``model`` or is ``None``.

    Otherwise the attribute named by ``compile.target`` is replaced in place
    on ``model`` by its compiled counterpart.

    Args:
        model: Object whose ``predictor`` or ``predict`` attribute may be
            compiled. It is mutated in place on success.
        cfg: Configuration passed through to ``get_compile_cfg``.
        device: Device identifier; only its string form is inspected.

    Returns:
        tuple: ``(model, compile_cfg, compile_target)`` where
        ``compile_target`` is ``"disabled"`` if nothing was compiled, else
        the name of the compiled attribute.
    """
    compile_cfg = get_compile_cfg(cfg)
    not_compiled = (model, compile_cfg, "disabled")

    if not compile_cfg["enabled"]:
        return not_compiled

    if not hasattr(torch, "compile"):
        print("torch.compile is unavailable, skipping inference compilation.")
        return not_compiled

    if compile_cfg["cuda_only"] and not str(device).startswith("cuda"):
        print("Skipping torch.compile because compile.cuda_only=true and device is not CUDA.")
        return not_compiled

    target = str(compile_cfg["target"]).lower()
    if target not in ("predictor", "predict"):
        print(
            f"Unsupported compile.target={target}. Expected one of: predictor, predict."
        )
        return not_compiled

    # Use getattr(..., None) instead of hasattr: an attribute that exists but
    # is None must count as unavailable. torch.compile(None, ...) runs in
    # decorator mode and would hand back a decorator, not a compiled callable.
    original = getattr(model, target, None)
    if original is None:
        print(f"Requested compile target '{target}' is unavailable on the model.")
        return not_compiled

    compiled = torch.compile(
        original,
        mode=compile_cfg["mode"],
        fullgraph=compile_cfg["fullgraph"],
        dynamic=compile_cfg["dynamic"],
    )
    setattr(model, target, compiled)
    return model, compile_cfg, target
def get_inference_context(cfg, device):
precision = str(cfg.get("inference_precision", "fp32")).lower()
device_type = "cuda" if device.startswith("cuda") else "cpu"
@@ -182,9 +239,17 @@ def run(cfg: DictConfig):
model = model.to(device)
model = model.eval()
model.requires_grad_(False)
model, compile_cfg, compile_target = maybe_compile_inference_target(
model, cfg, device
)
print(f"model parameter dtype: {next(model.parameters()).dtype}")
inference_ctx, inference_precision = get_inference_context(cfg, device)
print(f"inference execution precision: {inference_precision}")
if compile_target != "disabled":
print(
f"inference compile target: {compile_target} "
f"(mode={compile_cfg['mode']})"
)
model.interpolate_pos_encoding = True
config = swm.PlanConfig(**cfg.plan_config)
solver = hydra.utils.instantiate(cfg.solver, model=model)
@@ -196,6 +261,8 @@ def run(cfg: DictConfig):
policy = swm.policy.RandomPolicy()
inference_ctx = nullcontext()
inference_precision = "fp32"
compile_cfg = get_compile_cfg(cfg)
compile_target = "disabled"
# Hydra switches the working directory to the per-run outputs folder.
# Keep all generated artifacts with that run instead of scattering them
@@ -274,6 +341,9 @@ def run(cfg: DictConfig):
f.write(f"metrics: {metrics}\n")
f.write(f"evaluation_time: {end_time - start_time} seconds\n")
f.write(f"inference_precision: {inference_precision}\n")
f.write(f"inference_compile_target: {compile_target}\n")
if compile_target != "disabled":
f.write(f"inference_compile_mode: {compile_cfg['mode']}\n")
if profile_cfg["enabled"]:
f.write(f"profile_dir: {profile_dir}\n")
if profile_summary_path is not None: