def get_compile_cfg(cfg):
    """Return the effective ``torch.compile`` settings for inference.

    Built-in defaults are overlaid with whatever ``cfg`` stores under the
    ``compile`` key.

    Args:
        cfg: Object exposing ``.get(key)``. The value under ``"compile"``
            may be missing/``None`` (defaults only), a ``bool`` (shorthand
            for the ``enabled`` flag), a plain ``dict``, or any other
            object, which is handed to ``OmegaConf.to_container``.

    Returns:
        A new plain ``dict`` containing ``enabled``, ``target``, ``mode``,
        ``fullgraph``, ``dynamic`` and ``cuda_only``, plus any extra keys
        supplied by the override.
    """
    compile_cfg = {
        "enabled": True,
        "target": "predictor",
        "mode": "reduce-overhead",
        "fullgraph": False,
        "dynamic": False,
        "cuda_only": True,
    }
    overrides = cfg.get("compile")
    if overrides is None:
        return compile_cfg

    if isinstance(overrides, bool):
        # `compile: false` / `compile: true` is a scalar, not a config node;
        # OmegaConf.to_container would raise on it, so map it to `enabled`.
        compile_cfg["enabled"] = overrides
    elif isinstance(overrides, dict):
        # Already a plain container; nothing to resolve.
        compile_cfg.update(overrides)
    else:
        compile_cfg.update(OmegaConf.to_container(overrides, resolve=True))
    return compile_cfg


def maybe_compile_inference_target(model, cfg, device):
    """Wrap one attribute of ``model`` with ``torch.compile`` when allowed.

    Compilation is skipped (with a printed explanation, except when simply
    disabled by config) if it is turned off, if ``torch.compile`` does not
    exist, if ``cuda_only`` is set and ``device`` does not start with
    ``"cuda"``, if the requested target is not ``predictor``/``predict``,
    or if the model lacks that attribute.

    Args:
        model: Object whose ``predictor`` or ``predict`` attribute may be
            replaced in place by its compiled counterpart.
        cfg: Config passed through to ``get_compile_cfg``.
        device: Device identifier; only ``str(device)`` is inspected.

    Returns:
        ``(model, compile_cfg, compile_target)`` where ``compile_target``
        is ``"predictor"``, ``"predict"`` or ``"disabled"``.
    """
    compile_cfg = get_compile_cfg(cfg)
    skipped = (model, compile_cfg, "disabled")

    if not compile_cfg["enabled"]:
        return skipped

    if not hasattr(torch, "compile"):
        print("torch.compile is unavailable, skipping inference compilation.")
        return skipped

    if compile_cfg["cuda_only"] and not str(device).startswith("cuda"):
        print("Skipping torch.compile because compile.cuda_only=true and device is not CUDA.")
        return skipped

    target = str(compile_cfg["target"]).lower()
    if target not in ("predictor", "predict"):
        print(
            f"Unsupported compile.target={target}. Expected one of: predictor, predict."
        )
        return skipped

    if not hasattr(model, target):
        print(f"Requested compile target '{target}' is unavailable on the model.")
        return skipped

    compiled = torch.compile(
        getattr(model, target),
        mode=compile_cfg["mode"],
        fullgraph=compile_cfg["fullgraph"],
        dynamic=compile_cfg["dynamic"],
    )
    # Rebind on the model so existing callers transparently hit the compiled
    # version of the attribute.
    setattr(model, target, compiled)
    return model, compile_cfg, target