Optimize JEPA eval outputs and inference hot path

This commit is contained in:
qihuanye
2026-04-08 12:41:21 +00:00
parent 8b84251eb9
commit fa1c15c896
3 changed files with 164 additions and 60 deletions

20
eval.py
View File

@@ -129,14 +129,13 @@ def dump_profiler_results(profiler, profile_dir, profile_cfg):
summary_path = profile_dir / "key_averages.txt"
summary_path.write_text(table)
if profile_cfg["export_chrome_trace"]:
profiler.export_chrome_trace(str(profile_dir / "trace.json"))
if profile_cfg["export_tensorboard"]:
trace_handler = torch.profiler.tensorboard_trace_handler(
str(profile_dir), worker_name=profile_cfg["worker_name"]
)
trace_handler(profiler)
elif profile_cfg["export_chrome_trace"]:
profiler.export_chrome_trace(str(profile_dir / "trace.json"))
return summary_path
@@ -198,12 +197,11 @@ def run(cfg: DictConfig):
inference_ctx = nullcontext()
inference_precision = "fp32"
results_path = (
Path(swm.data.utils.get_cache_dir(), cfg.policy).parent
if cfg.policy != "random"
else Path(__file__).parent
)
profiler_ctx, profile_dir, profile_cfg = make_profiler(cfg, results_path)
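# Hydra switches the working directory to the per-run outputs folder
# (only when `hydra.job.chdir` is enabled; with version_base >= 1.2 it is
# off by default and Path.cwd() stays the launch directory).
# Keep all generated artifacts with that run instead of scattering them
# next to the cache or source tree.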
output_dir = Path.cwd().resolve()
profiler_ctx, profile_dir, profile_cfg = make_profiler(cfg, output_dir)
# sample the episodes and the starting indices
episode_len = get_episodes_length(dataset, ep_indices)
@@ -251,7 +249,7 @@ def run(cfg: DictConfig):
eval_budget=cfg.eval.eval_budget,
episodes_idx=eval_episodes.tolist(),
callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True),
video_path=results_path,
video_path=output_dir,
)
if torch.cuda.is_available():
torch.cuda.synchronize()
@@ -260,7 +258,7 @@ def run(cfg: DictConfig):
print(metrics)
results_path = results_path / cfg.output.filename
results_path = output_dir / cfg.output.filename
results_path.parent.mkdir(parents=True, exist_ok=True)
with results_path.open("a") as f: