在正式测试前添加warm up

This commit is contained in:
qihuanye
2026-05-16 14:53:58 +00:00
parent d86aeb2df0
commit 02080e2564
5 changed files with 74 additions and 12 deletions

View File

@@ -38,6 +38,9 @@ eval:
eval_budget: 50 eval_budget: 50
img_size: 224 img_size: 224
save_video: false save_video: false
compile_warmup:
enabled: true
num_eval: 1
dataset_name: ogbench/cube_single_expert dataset_name: ogbench/cube_single_expert
callables: callables:
# -- set state # -- set state

View File

@@ -33,6 +33,9 @@ eval:
eval_budget: 50 eval_budget: 50
img_size: 224 img_size: 224
save_video: false save_video: false
compile_warmup:
enabled: true
num_eval: 1
dataset_name: pusht_expert_train dataset_name: pusht_expert_train
callables: callables:
# -- set state # -- set state

View File

@@ -32,6 +32,9 @@ eval:
eval_budget: 50 eval_budget: 50
img_size: 224 img_size: 224
save_video: false save_video: false
compile_warmup:
enabled: true
num_eval: 1
dataset_name: dmc/reacher_random dataset_name: dmc/reacher_random
callables: callables:
# -- set state # -- set state

View File

@@ -32,6 +32,9 @@ eval:
eval_budget: 50 eval_budget: 50
img_size: 224 img_size: 224
save_video: false save_video: false
compile_warmup:
enabled: true
num_eval: 1
dataset_name: tworoom dataset_name: tworoom
callables: callables:
# -- set state # -- set state

74
eval.py
View File

@@ -7,6 +7,7 @@ import time
import traceback import traceback
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path from pathlib import Path
import tempfile
import hydra import hydra
import numpy as np import numpy as np
@@ -84,6 +85,17 @@ def get_compile_cfg(cfg):
return compile_cfg return compile_cfg
def get_compile_warmup_cfg(cfg):
warmup_cfg = {
"enabled": True,
"num_eval": 1,
}
cfg_warmup = cfg.get("compile_warmup")
if cfg_warmup is not None:
warmup_cfg.update(OmegaConf.to_container(cfg_warmup, resolve=True))
return warmup_cfg
def maybe_compile_inference_target(model, cfg, device): def maybe_compile_inference_target(model, cfg, device):
compile_cfg = get_compile_cfg(cfg) compile_cfg = get_compile_cfg(cfg)
compile_target = "disabled" compile_target = "disabled"
@@ -363,23 +375,27 @@ def run_eval_subset(
if str(device).startswith("cuda") and torch.cuda.is_available(): if str(device).startswith("cuda") and torch.cuda.is_available():
torch.cuda.synchronize() torch.cuda.synchronize()
def evaluate_subset(episodes, start_indices, *, eval_cfg=local_cfg):
return world.evaluate_from_dataset(
dataset,
start_steps=list(start_indices),
goal_offset_steps=eval_cfg.eval.goal_offset_steps,
eval_budget=eval_cfg.eval.eval_budget,
episodes_idx=list(episodes),
callables=OmegaConf.to_container(
eval_cfg.eval.get("callables"), resolve=True
),
save_video=bool(eval_cfg.eval.get("save_video", False)),
video_path=output_dir,
)
start_time = time.time() start_time = time.time()
with get_eval_grad_context(solver): with get_eval_grad_context(solver):
with profiler_ctx as profiler: with profiler_ctx as profiler:
with inference_ctx: with inference_ctx:
with torch.profiler.record_function("eval.world_evaluate_from_dataset"): with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
metrics = world.evaluate_from_dataset( metrics = evaluate_subset(eval_episodes, eval_start_idx)
dataset,
start_steps=list(eval_start_idx),
goal_offset_steps=local_cfg.eval.goal_offset_steps,
eval_budget=local_cfg.eval.eval_budget,
episodes_idx=list(eval_episodes),
callables=OmegaConf.to_container(
local_cfg.eval.get("callables"), resolve=True
),
save_video=bool(local_cfg.eval.get("save_video", False)),
video_path=output_dir,
)
if str(device).startswith("cuda") and torch.cuda.is_available(): if str(device).startswith("cuda") and torch.cuda.is_available():
torch.cuda.synchronize() torch.cuda.synchronize()
evaluation_time = time.time() - start_time evaluation_time = time.time() - start_time
@@ -396,6 +412,38 @@ def run_eval_subset(
} }
def maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx):
warmup_cfg = get_compile_warmup_cfg(cfg)
if not warmup_cfg["enabled"]:
return
if get_multi_gpu_cfg(cfg)["enabled"]:
print("Skipping compile warmup because multi_gpu.enabled=true uses spawned workers.")
return
warmup_count = min(int(warmup_cfg["num_eval"]), len(eval_episodes))
if warmup_count < 1:
return
warmup_eval_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
warmup_eval_cfg.eval.num_eval = warmup_count
warmup_eval_cfg.eval.save_video = False
if warmup_eval_cfg.get("profile") is None:
warmup_eval_cfg.profile = OmegaConf.create({"enabled": False})
else:
warmup_eval_cfg.profile.enabled = False
with tempfile.TemporaryDirectory(prefix="lewm_compile_warmup_") as tmpdir:
run_eval_subset(
warmup_eval_cfg,
eval_episodes[:warmup_count].tolist(),
eval_start_idx[:warmup_count].tolist(),
Path(tmpdir),
enable_profile=False,
)
def _multi_gpu_eval_worker( def _multi_gpu_eval_worker(
cfg_container, cfg_container,
eval_episodes, eval_episodes,
@@ -515,6 +563,8 @@ def run(cfg: DictConfig):
output_dir = Path.cwd().resolve() output_dir = Path.cwd().resolve()
profile_cfg = get_profile_cfg(cfg) profile_cfg = get_profile_cfg(cfg)
maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx)
if get_multi_gpu_cfg(cfg)["enabled"]: if get_multi_gpu_cfg(cfg)["enabled"]:
if profile_cfg["enabled"]: if profile_cfg["enabled"]:
raise ValueError("Profiling is not supported together with multi_gpu.enabled=true") raise ValueError("Profiling is not supported together with multi_gpu.enabled=true")