Optimize inference path: add predictor-only torch.compile with reduce-overhead
This commit is contained in:
70
eval.py
70
eval.py
@@ -67,6 +67,63 @@ def get_profile_cfg(cfg):
|
|||||||
return profile_cfg
|
return profile_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def get_compile_cfg(cfg):
|
||||||
|
compile_cfg = {
|
||||||
|
"enabled": True,
|
||||||
|
"target": "predictor",
|
||||||
|
"mode": "reduce-overhead",
|
||||||
|
"fullgraph": False,
|
||||||
|
"dynamic": False,
|
||||||
|
"cuda_only": True,
|
||||||
|
}
|
||||||
|
cfg_compile = cfg.get("compile")
|
||||||
|
if cfg_compile is not None:
|
||||||
|
compile_cfg.update(OmegaConf.to_container(cfg_compile, resolve=True))
|
||||||
|
return compile_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_compile_inference_target(model, cfg, device):
|
||||||
|
compile_cfg = get_compile_cfg(cfg)
|
||||||
|
compile_target = "disabled"
|
||||||
|
|
||||||
|
if not compile_cfg["enabled"]:
|
||||||
|
return model, compile_cfg, compile_target
|
||||||
|
|
||||||
|
if not hasattr(torch, "compile"):
|
||||||
|
print("torch.compile is unavailable, skipping inference compilation.")
|
||||||
|
return model, compile_cfg, compile_target
|
||||||
|
|
||||||
|
if compile_cfg["cuda_only"] and not str(device).startswith("cuda"):
|
||||||
|
print("Skipping torch.compile because compile.cuda_only=true and device is not CUDA.")
|
||||||
|
return model, compile_cfg, compile_target
|
||||||
|
|
||||||
|
target = str(compile_cfg["target"]).lower()
|
||||||
|
compile_kwargs = {
|
||||||
|
"mode": compile_cfg["mode"],
|
||||||
|
"fullgraph": compile_cfg["fullgraph"],
|
||||||
|
"dynamic": compile_cfg["dynamic"],
|
||||||
|
}
|
||||||
|
|
||||||
|
if target == "predictor":
|
||||||
|
if not hasattr(model, "predictor"):
|
||||||
|
print("Requested compile target 'predictor' is unavailable on the model.")
|
||||||
|
return model, compile_cfg, compile_target
|
||||||
|
model.predictor = torch.compile(model.predictor, **compile_kwargs)
|
||||||
|
compile_target = "predictor"
|
||||||
|
elif target == "predict":
|
||||||
|
if not hasattr(model, "predict"):
|
||||||
|
print("Requested compile target 'predict' is unavailable on the model.")
|
||||||
|
return model, compile_cfg, compile_target
|
||||||
|
model.predict = torch.compile(model.predict, **compile_kwargs)
|
||||||
|
compile_target = "predict"
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Unsupported compile.target={target}. Expected one of: predictor, predict."
|
||||||
|
)
|
||||||
|
|
||||||
|
return model, compile_cfg, compile_target
|
||||||
|
|
||||||
|
|
||||||
def get_inference_context(cfg, device):
|
def get_inference_context(cfg, device):
|
||||||
precision = str(cfg.get("inference_precision", "fp32")).lower()
|
precision = str(cfg.get("inference_precision", "fp32")).lower()
|
||||||
device_type = "cuda" if device.startswith("cuda") else "cpu"
|
device_type = "cuda" if device.startswith("cuda") else "cpu"
|
||||||
@@ -182,9 +239,17 @@ def run(cfg: DictConfig):
|
|||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
model = model.eval()
|
model = model.eval()
|
||||||
model.requires_grad_(False)
|
model.requires_grad_(False)
|
||||||
|
model, compile_cfg, compile_target = maybe_compile_inference_target(
|
||||||
|
model, cfg, device
|
||||||
|
)
|
||||||
print(f"model parameter dtype: {next(model.parameters()).dtype}")
|
print(f"model parameter dtype: {next(model.parameters()).dtype}")
|
||||||
inference_ctx, inference_precision = get_inference_context(cfg, device)
|
inference_ctx, inference_precision = get_inference_context(cfg, device)
|
||||||
print(f"inference execution precision: {inference_precision}")
|
print(f"inference execution precision: {inference_precision}")
|
||||||
|
if compile_target != "disabled":
|
||||||
|
print(
|
||||||
|
f"inference compile target: {compile_target} "
|
||||||
|
f"(mode={compile_cfg['mode']})"
|
||||||
|
)
|
||||||
model.interpolate_pos_encoding = True
|
model.interpolate_pos_encoding = True
|
||||||
config = swm.PlanConfig(**cfg.plan_config)
|
config = swm.PlanConfig(**cfg.plan_config)
|
||||||
solver = hydra.utils.instantiate(cfg.solver, model=model)
|
solver = hydra.utils.instantiate(cfg.solver, model=model)
|
||||||
@@ -196,6 +261,8 @@ def run(cfg: DictConfig):
|
|||||||
policy = swm.policy.RandomPolicy()
|
policy = swm.policy.RandomPolicy()
|
||||||
inference_ctx = nullcontext()
|
inference_ctx = nullcontext()
|
||||||
inference_precision = "fp32"
|
inference_precision = "fp32"
|
||||||
|
compile_cfg = get_compile_cfg(cfg)
|
||||||
|
compile_target = "disabled"
|
||||||
|
|
||||||
# Hydra switches the working directory to the per-run outputs folder.
|
# Hydra switches the working directory to the per-run outputs folder.
|
||||||
# Keep all generated artifacts with that run instead of scattering them
|
# Keep all generated artifacts with that run instead of scattering them
|
||||||
@@ -274,6 +341,9 @@ def run(cfg: DictConfig):
|
|||||||
f.write(f"metrics: {metrics}\n")
|
f.write(f"metrics: {metrics}\n")
|
||||||
f.write(f"evaluation_time: {end_time - start_time} seconds\n")
|
f.write(f"evaluation_time: {end_time - start_time} seconds\n")
|
||||||
f.write(f"inference_precision: {inference_precision}\n")
|
f.write(f"inference_precision: {inference_precision}\n")
|
||||||
|
f.write(f"inference_compile_target: {compile_target}\n")
|
||||||
|
if compile_target != "disabled":
|
||||||
|
f.write(f"inference_compile_mode: {compile_cfg['mode']}\n")
|
||||||
if profile_cfg["enabled"]:
|
if profile_cfg["enabled"]:
|
||||||
f.write(f"profile_dir: {profile_dir}\n")
|
f.write(f"profile_dir: {profile_dir}\n")
|
||||||
if profile_summary_path is not None:
|
if profile_summary_path is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user