在正式测试前添加warm up

This commit is contained in:
qihuanye
2026-05-16 14:53:58 +00:00
parent d86aeb2df0
commit 02080e2564
5 changed files with 74 additions and 12 deletions

View File

@@ -38,6 +38,9 @@ eval:
eval_budget: 50
img_size: 224
save_video: false
compile_warmup:
enabled: true
num_eval: 1
dataset_name: ogbench/cube_single_expert
callables:
# -- set state

View File

@@ -33,6 +33,9 @@ eval:
eval_budget: 50
img_size: 224
save_video: false
compile_warmup:
enabled: true
num_eval: 1
dataset_name: pusht_expert_train
callables:
# -- set state

View File

@@ -32,6 +32,9 @@ eval:
eval_budget: 50
img_size: 224
save_video: false
compile_warmup:
enabled: true
num_eval: 1
dataset_name: dmc/reacher_random
callables:
# -- set state

View File

@@ -32,6 +32,9 @@ eval:
eval_budget: 50
img_size: 224
save_video: false
compile_warmup:
enabled: true
num_eval: 1
dataset_name: tworoom
callables:
# -- set state

74
eval.py
View File

@@ -7,6 +7,7 @@ import time
import traceback
from contextlib import nullcontext
from pathlib import Path
import tempfile
import hydra
import numpy as np
@@ -84,6 +85,17 @@ def get_compile_cfg(cfg):
return compile_cfg
def get_compile_warmup_cfg(cfg):
    """Return the compile warm-up settings as a plain dict.

    Starts from the defaults ``{"enabled": True, "num_eval": 1}`` and
    overlays, in order, an ``eval.compile_warmup`` section and a top-level
    ``compile_warmup`` section. The top-level section is applied last, so it
    wins when both define the same key.

    Args:
        cfg: Run configuration; an OmegaConf config or a plain ``dict``.

    Returns:
        A new ``dict`` containing at least ``enabled`` and ``num_eval``.
    """
    warmup_cfg = {
        "enabled": True,
        "num_eval": 1,
    }

    # NOTE(review): the eval YAML configs appear to declare `compile_warmup`
    # under `eval:`; looking only at the top level would silently ignore that
    # setting, so both locations are consulted.
    eval_section = cfg.get("eval")
    nested_warmup = (
        eval_section.get("compile_warmup") if eval_section is not None else None
    )

    for section in (nested_warmup, cfg.get("compile_warmup")):
        if section is None:
            continue
        if isinstance(section, dict):
            # Plain dicts need no OmegaConf conversion.
            warmup_cfg.update(section)
        else:
            warmup_cfg.update(OmegaConf.to_container(section, resolve=True))
    return warmup_cfg
def maybe_compile_inference_target(model, cfg, device):
compile_cfg = get_compile_cfg(cfg)
compile_target = "disabled"
@@ -363,23 +375,27 @@ def run_eval_subset(
if str(device).startswith("cuda") and torch.cuda.is_available():
torch.cuda.synchronize()
def evaluate_subset(episodes, start_indices, *, eval_cfg=local_cfg):
    """Run ``world.evaluate_from_dataset`` on the given episodes.

    Args:
        episodes: Iterable of episode indices, passed as ``episodes_idx``.
        start_indices: Iterable of start steps, passed as ``start_steps``.
        eval_cfg: Config whose ``eval`` section supplies the evaluation
            options; defaults to the enclosing ``local_cfg``.

    Returns:
        Whatever ``world.evaluate_from_dataset`` returns.
    """
    eval_section = eval_cfg.eval
    # Keyword arguments are built in the same order the call lists them.
    call_kwargs = {
        "start_steps": list(start_indices),
        "goal_offset_steps": eval_section.goal_offset_steps,
        "eval_budget": eval_section.eval_budget,
        "episodes_idx": list(episodes),
        "callables": OmegaConf.to_container(
            eval_section.get("callables"), resolve=True
        ),
        "save_video": bool(eval_section.get("save_video", False)),
        "video_path": output_dir,
    }
    return world.evaluate_from_dataset(dataset, **call_kwargs)
start_time = time.time()
with get_eval_grad_context(solver):
with profiler_ctx as profiler:
with inference_ctx:
with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
metrics = world.evaluate_from_dataset(
dataset,
start_steps=list(eval_start_idx),
goal_offset_steps=local_cfg.eval.goal_offset_steps,
eval_budget=local_cfg.eval.eval_budget,
episodes_idx=list(eval_episodes),
callables=OmegaConf.to_container(
local_cfg.eval.get("callables"), resolve=True
),
save_video=bool(local_cfg.eval.get("save_video", False)),
video_path=output_dir,
)
metrics = evaluate_subset(eval_episodes, eval_start_idx)
if str(device).startswith("cuda") and torch.cuda.is_available():
torch.cuda.synchronize()
evaluation_time = time.time() - start_time
@@ -396,6 +412,38 @@ def run_eval_subset(
}
def maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx):
    """Run a small throwaway evaluation before the real one, if enabled.

    Does nothing when warm-up is disabled, when ``multi_gpu`` is enabled
    (a notice is printed in that case), or when fewer than one warm-up
    episode would be selected. Otherwise it evaluates the first
    ``num_eval`` episodes with a copy of ``cfg`` that has video saving and
    profiling turned off, writing into a temporary directory.

    Args:
        cfg: Run configuration.
        eval_episodes: Episode indices; sliced and converted with
            ``.tolist()`` (assumes an array-like such as a NumPy array).
        eval_start_idx: Start steps matching ``eval_episodes``; handled the
            same way.
    """
    settings = get_compile_warmup_cfg(cfg)
    if not settings["enabled"]:
        return

    if get_multi_gpu_cfg(cfg)["enabled"]:
        print("Skipping compile warmup because multi_gpu.enabled=true uses spawned workers.")
        return

    # Never ask for more warm-up episodes than are available.
    n_warmup = min(int(settings["num_eval"]), len(eval_episodes))
    if n_warmup < 1:
        return

    # Work on an unresolved copy so the caller's cfg is left untouched.
    scratch_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
    scratch_cfg.eval.num_eval = n_warmup
    scratch_cfg.eval.save_video = False
    if scratch_cfg.get("profile") is not None:
        scratch_cfg.profile.enabled = False
    else:
        scratch_cfg.profile = OmegaConf.create({"enabled": False})

    with tempfile.TemporaryDirectory(prefix="lewm_compile_warmup_") as scratch_dir:
        run_eval_subset(
            scratch_cfg,
            eval_episodes[:n_warmup].tolist(),
            eval_start_idx[:n_warmup].tolist(),
            Path(scratch_dir),
            enable_profile=False,
        )
def _multi_gpu_eval_worker(
cfg_container,
eval_episodes,
@@ -515,6 +563,8 @@ def run(cfg: DictConfig):
output_dir = Path.cwd().resolve()
profile_cfg = get_profile_cfg(cfg)
maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx)
if get_multi_gpu_cfg(cfg)["enabled"]:
if profile_cfg["enabled"]:
raise ValueError("Profiling is not supported together with multi_gpu.enabled=true")