Optimize inference path: add predictor-only torch.compile with reduce-overhead
This commit is contained in:
70
eval.py
70
eval.py
@@ -67,6 +67,63 @@ def get_profile_cfg(cfg):
|
||||
return profile_cfg
|
||||
|
||||
|
||||
def get_compile_cfg(cfg):
    """Return the effective ``torch.compile`` settings for inference.

    Built-in defaults are overlaid with whatever is stored under the
    ``compile`` key of ``cfg``; a missing or ``None`` entry leaves the
    defaults untouched.

    Args:
        cfg: Configuration object exposing ``.get``. The ``compile`` entry may
            be an OmegaConf node or a plain ``dict``.

    Returns:
        A new plain ``dict`` containing at least the keys ``enabled``,
        ``target``, ``mode``, ``fullgraph``, ``dynamic`` and ``cuda_only``.
    """
    compile_cfg = {
        "enabled": True,
        "target": "predictor",
        "mode": "reduce-overhead",
        "fullgraph": False,
        "dynamic": False,
        "cuda_only": True,
    }

    cfg_compile = cfg.get("compile")
    if cfg_compile is None:
        return compile_cfg

    if isinstance(cfg_compile, dict):
        # OmegaConf.to_container only accepts OmegaConf containers, so plain
        # dict overrides are merged directly instead of being converted.
        overrides = cfg_compile
    else:
        overrides = OmegaConf.to_container(cfg_compile, resolve=True)

    compile_cfg.update(overrides)
    return compile_cfg
|
||||
|
||||
|
||||
def maybe_compile_inference_target(model, cfg, device):
    """Optionally wrap one inference entry point of ``model`` in ``torch.compile``.

    The settings come from ``get_compile_cfg(cfg)``. Compilation is skipped
    when it is disabled, when ``torch.compile`` does not exist, when
    ``cuda_only`` is set and ``device`` is not a CUDA device, or when the
    requested target is unknown or missing on the model.

    Args:
        model: Object whose ``predictor`` or ``predict`` attribute may be
            replaced in place by its compiled counterpart.
        cfg: Configuration passed through to ``get_compile_cfg``.
        device: Device identifier; only its string form is inspected.

    Returns:
        A ``(model, compile_cfg, compile_target)`` tuple where
        ``compile_target`` is ``"predictor"``, ``"predict"`` or
        ``"disabled"`` when nothing was compiled.
    """
    settings = get_compile_cfg(cfg)
    # Shared result for every path that leaves the model untouched.
    not_compiled = (model, settings, "disabled")

    if not settings["enabled"]:
        return not_compiled

    if not hasattr(torch, "compile"):
        print("torch.compile is unavailable, skipping inference compilation.")
        return not_compiled

    if settings["cuda_only"] and not str(device).startswith("cuda"):
        print("Skipping torch.compile because compile.cuda_only=true and device is not CUDA.")
        return not_compiled

    requested = str(settings["target"]).lower()
    options = {key: settings[key] for key in ("mode", "fullgraph", "dynamic")}

    # Supported targets, each paired with the message printed when the model
    # lacks that attribute.
    missing_messages = {
        "predictor": "Requested compile target 'predictor' is unavailable on the model.",
        "predict": "Requested compile target 'predict' is unavailable on the model.",
    }

    if requested not in missing_messages:
        print(
            f"Unsupported compile.target={requested}. Expected one of: predictor, predict."
        )
        return not_compiled

    if not hasattr(model, requested):
        print(missing_messages[requested])
        return not_compiled

    compiled = torch.compile(getattr(model, requested), **options)
    setattr(model, requested, compiled)
    return model, settings, requested
|
||||
|
||||
|
||||
def get_inference_context(cfg, device):
|
||||
precision = str(cfg.get("inference_precision", "fp32")).lower()
|
||||
device_type = "cuda" if device.startswith("cuda") else "cpu"
|
||||
@@ -182,9 +239,17 @@ def run(cfg: DictConfig):
|
||||
model = model.to(device)
|
||||
model = model.eval()
|
||||
model.requires_grad_(False)
|
||||
model, compile_cfg, compile_target = maybe_compile_inference_target(
|
||||
model, cfg, device
|
||||
)
|
||||
print(f"model parameter dtype: {next(model.parameters()).dtype}")
|
||||
inference_ctx, inference_precision = get_inference_context(cfg, device)
|
||||
print(f"inference execution precision: {inference_precision}")
|
||||
if compile_target != "disabled":
|
||||
print(
|
||||
f"inference compile target: {compile_target} "
|
||||
f"(mode={compile_cfg['mode']})"
|
||||
)
|
||||
model.interpolate_pos_encoding = True
|
||||
config = swm.PlanConfig(**cfg.plan_config)
|
||||
solver = hydra.utils.instantiate(cfg.solver, model=model)
|
||||
@@ -196,6 +261,8 @@ def run(cfg: DictConfig):
|
||||
policy = swm.policy.RandomPolicy()
|
||||
inference_ctx = nullcontext()
|
||||
inference_precision = "fp32"
|
||||
compile_cfg = get_compile_cfg(cfg)
|
||||
compile_target = "disabled"
|
||||
|
||||
# Hydra switches the working directory to the per-run outputs folder.
|
||||
# Keep all generated artifacts with that run instead of scattering them
|
||||
@@ -274,6 +341,9 @@ def run(cfg: DictConfig):
|
||||
f.write(f"metrics: {metrics}\n")
|
||||
f.write(f"evaluation_time: {end_time - start_time} seconds\n")
|
||||
f.write(f"inference_precision: {inference_precision}\n")
|
||||
f.write(f"inference_compile_target: {compile_target}\n")
|
||||
if compile_target != "disabled":
|
||||
f.write(f"inference_compile_mode: {compile_cfg['mode']}\n")
|
||||
if profile_cfg["enabled"]:
|
||||
f.write(f"profile_dir: {profile_dir}\n")
|
||||
if profile_summary_path is not None:
|
||||
|
||||
Reference in New Issue
Block a user