Wrap eval inference in torch.inference_mode

This commit is contained in:
qihuanye
2026-04-09 09:18:35 +00:00
parent 0f85e39690
commit 9e2407cdc4
3 changed files with 791 additions and 12 deletions

25
eval.py
View File

@@ -239,18 +239,19 @@ def run(cfg: DictConfig):
if torch.cuda.is_available():
torch.cuda.synchronize()
start_time = time.time()
with profiler_ctx as profiler:
with inference_ctx:
with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
metrics = world.evaluate_from_dataset(
dataset,
start_steps=eval_start_idx.tolist(),
goal_offset_steps=cfg.eval.goal_offset_steps,
eval_budget=cfg.eval.eval_budget,
episodes_idx=eval_episodes.tolist(),
callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True),
video_path=output_dir,
)
with torch.inference_mode():
with profiler_ctx as profiler:
with inference_ctx:
with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
metrics = world.evaluate_from_dataset(
dataset,
start_steps=eval_start_idx.tolist(),
goal_offset_steps=cfg.eval.goal_offset_steps,
eval_budget=cfg.eval.eval_budget,
episodes_idx=eval_episodes.tolist(),
callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True),
video_path=output_dir,
)
if torch.cuda.is_available():
torch.cuda.synchronize()
end_time = time.time()