add profile frame and bf16/fp16 switch
This commit is contained in:
135
eval.py
135
eval.py
@@ -3,6 +3,7 @@ import os
|
||||
os.environ["MUJOCO_GL"] = "egl"
|
||||
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
@@ -46,6 +47,99 @@ def get_dataset(cfg, dataset_name):
|
||||
)
|
||||
return dataset
|
||||
|
||||
|
||||
def get_profile_cfg(cfg):
    """Return the profiler settings as a plain ``dict``.

    The built-in defaults below are used as a base; when ``cfg`` carries a
    ``profile`` section, that section is converted to a resolved container and
    laid over the defaults, so any key it sets wins.

    Args:
        cfg: Mapping-like config exposing ``.get("profile")``.

    Returns:
        A freshly built ``dict`` of profiler options.
    """
    settings = dict(
        enabled=False,
        trace_dirname="torch_profile",
        record_shapes=True,
        profile_memory=True,
        with_stack=False,
        with_flops=False,
        row_limit=40,
        worker_name="eval",
        export_chrome_trace=True,
        export_tensorboard=True,
    )

    user_section = cfg.get("profile")
    if user_section is None:
        # Nothing configured: the defaults are the answer.
        return settings

    settings.update(OmegaConf.to_container(user_section, resolve=True))
    return settings
|
||||
|
||||
|
||||
def get_inference_context(cfg, device):
    """Pick the autocast context that matches ``cfg.inference_precision``.

    Args:
        cfg: Mapping-like config; ``inference_precision`` is read with a
            default of ``"fp32"``. Matching is case-insensitive and accepts
            ``fp32``, ``bf16``/``bfloat16`` and ``fp16``/``float16``.
        device: Target device, either a string such as ``"cuda"`` / ``"cpu"``
            or a ``torch.device``. Anything whose string form does not start
            with ``"cuda"`` is treated as CPU.

    Returns:
        ``(context_manager, effective_precision)`` where the second item is
        one of ``"fp32"``, ``"bf16"`` or ``"fp16"`` and reflects any fallback.

    Raises:
        ValueError: If the requested precision is not one of the accepted
            spellings.
    """
    requested = cfg.get("inference_precision", "fp32")
    # A key that is present but null would otherwise stringify to "none" and
    # be rejected; treat it the same as an absent key.
    precision = "fp32" if requested is None else str(requested).lower()

    # str() lets callers hand in a torch.device as well as a plain string.
    device_type = "cuda" if str(device).startswith("cuda") else "cpu"

    if precision == "fp32":
        return nullcontext(), "fp32"

    if precision in {"bf16", "bfloat16"}:
        return (
            torch.autocast(device_type=device_type, dtype=torch.bfloat16),
            "bf16",
        )

    if precision in {"fp16", "float16"}:
        if device_type != "cuda":
            # Report the fallback and return the precision actually used.
            print("fp16 inference is only supported on CUDA, falling back to fp32.")
            return nullcontext(), "fp32"
        return (
            torch.autocast(device_type=device_type, dtype=torch.float16),
            "fp16",
        )

    raise ValueError(
        f"Unsupported inference_precision={precision}. Expected one of: fp32, bf16, fp16."
    )
|
||||
|
||||
|
||||
def make_profiler(cfg, results_path):
    """Create the (optional) torch profiler for an evaluation run.

    Args:
        cfg: Config passed through to ``get_profile_cfg``.
        results_path: Path-like base directory supporting the ``/`` operator;
            the trace directory is created underneath it.

    Returns:
        ``(context_manager, trace_dir, settings)``. When profiling is
        disabled the context manager is a ``nullcontext()`` and ``trace_dir``
        is ``None``; otherwise it is a ``torch.profiler.profile`` instance and
        the (already created) output directory.
    """
    settings = get_profile_cfg(cfg)
    if not settings["enabled"]:
        return nullcontext(), None, settings

    # CPU is always traced; CUDA only when a device is actually available.
    tracked = [torch.profiler.ProfilerActivity.CPU] + (
        [torch.profiler.ProfilerActivity.CUDA] if torch.cuda.is_available() else []
    )

    trace_dir = results_path / settings["trace_dirname"]
    trace_dir.mkdir(parents=True, exist_ok=True)

    # These settings map one-to-one onto profiler keyword arguments.
    passthrough = ("record_shapes", "profile_memory", "with_stack", "with_flops")
    profiler_kwargs = {key: settings[key] for key in passthrough}

    return (
        torch.profiler.profile(activities=tracked, **profiler_kwargs),
        trace_dir,
        settings,
    )
|
||||
|
||||
|
||||
def dump_profiler_results(profiler, profile_dir, profile_cfg):
    """Write the profiler summary table and the requested trace files.

    Args:
        profiler: Finished profiler exposing ``key_averages()`` and
            ``export_chrome_trace(path)``, or ``None`` when profiling is off.
        profile_dir: ``pathlib.Path`` of the output directory, or ``None``.
        profile_cfg: Settings dict; reads ``row_limit``,
            ``export_chrome_trace``, ``export_tensorboard`` and
            ``worker_name``.

    Returns:
        Path of the written ``key_averages.txt``, or ``None`` when there is
        nothing to dump.
    """
    if profiler is None or profile_dir is None:
        return None

    import os
    import shutil
    import socket
    import time

    sort_key = (
        "self_cuda_time_total" if torch.cuda.is_available() else "self_cpu_time_total"
    )
    table = profiler.key_averages().table(
        sort_by=sort_key,
        row_limit=profile_cfg["row_limit"],
    )
    summary_path = profile_dir / "key_averages.txt"
    summary_path.write_text(table)

    want_chrome = profile_cfg["export_chrome_trace"]
    want_tensorboard = profile_cfg["export_tensorboard"]

    if want_chrome:
        chrome_path = profile_dir / "trace.json"
        profiler.export_chrome_trace(str(chrome_path))
        if want_tensorboard:
            # The TensorBoard handler is itself a chrome-trace export, and a
            # second export of the same trace is rejected by the Kineto-backed
            # profiler ("Trace is already saved"). Reuse the file just written
            # under the name the handler would have chosen.
            # NOTE(review): naming mirrors torch's tensorboard_trace_handler;
            # confirm against the installed torch version.
            worker = profile_cfg["worker_name"] or f"{socket.gethostname()}_{os.getpid()}"
            tensorboard_name = f"{worker}.{time.time_ns()}.pt.trace.json"
            shutil.copyfile(chrome_path, profile_dir / tensorboard_name)
    elif want_tensorboard:
        # Only one export is requested, so the stock handler is safe to use.
        trace_handler = torch.profiler.tensorboard_trace_handler(
            str(profile_dir), worker_name=profile_cfg["worker_name"]
        )
        trace_handler(profiler)

    return summary_path
|
||||
|
||||
@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht")
|
||||
def run(cfg: DictConfig):
|
||||
"""Run evaluation of dinowm vs random policy."""
|
||||
@@ -83,12 +177,15 @@ def run(cfg: DictConfig):
|
||||
|
||||
# -- run evaluation
|
||||
policy = cfg.get("policy", "random")
|
||||
|
||||
if policy != "random":
|
||||
model = swm.policy.AutoCostModel(cfg.policy)
|
||||
model = model.to("cuda")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = model.to(device)
|
||||
model = model.eval()
|
||||
model.requires_grad_(False)
|
||||
print(f"model parameter dtype: {next(model.parameters()).dtype}")
|
||||
inference_ctx, inference_precision = get_inference_context(cfg, device)
|
||||
print(f"inference execution precision: {inference_precision}")
|
||||
model.interpolate_pos_encoding = True
|
||||
config = swm.PlanConfig(**cfg.plan_config)
|
||||
solver = hydra.utils.instantiate(cfg.solver, model=model)
|
||||
@@ -98,12 +195,15 @@ def run(cfg: DictConfig):
|
||||
|
||||
else:
|
||||
policy = swm.policy.RandomPolicy()
|
||||
inference_ctx = nullcontext()
|
||||
inference_precision = "fp32"
|
||||
|
||||
results_path = (
|
||||
Path(swm.data.utils.get_cache_dir(), cfg.policy).parent
|
||||
if cfg.policy != "random"
|
||||
else Path(__file__).parent
|
||||
)
|
||||
profiler_ctx, profile_dir, profile_cfg = make_profiler(cfg, results_path)
|
||||
|
||||
# sample the episodes and the starting indices
|
||||
episode_len = get_episodes_length(dataset, ep_indices)
|
||||
@@ -138,17 +238,25 @@ def run(cfg: DictConfig):
|
||||
|
||||
world.set_policy(policy)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
metrics = world.evaluate_from_dataset(
|
||||
dataset,
|
||||
start_steps=eval_start_idx.tolist(),
|
||||
goal_offset_steps=cfg.eval.goal_offset_steps,
|
||||
eval_budget=cfg.eval.eval_budget,
|
||||
episodes_idx=eval_episodes.tolist(),
|
||||
callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True),
|
||||
video_path=results_path,
|
||||
)
|
||||
with profiler_ctx as profiler:
|
||||
with inference_ctx:
|
||||
with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
|
||||
metrics = world.evaluate_from_dataset(
|
||||
dataset,
|
||||
start_steps=eval_start_idx.tolist(),
|
||||
goal_offset_steps=cfg.eval.goal_offset_steps,
|
||||
eval_budget=cfg.eval.eval_budget,
|
||||
episodes_idx=eval_episodes.tolist(),
|
||||
callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True),
|
||||
video_path=results_path,
|
||||
)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
profile_summary_path = dump_profiler_results(profiler, profile_dir, profile_cfg)
|
||||
|
||||
print(metrics)
|
||||
|
||||
@@ -165,6 +273,11 @@ def run(cfg: DictConfig):
|
||||
f.write("==== RESULTS ====\n")
|
||||
f.write(f"metrics: {metrics}\n")
|
||||
f.write(f"evaluation_time: {end_time - start_time} seconds\n")
|
||||
f.write(f"inference_precision: {inference_precision}\n")
|
||||
if profile_cfg["enabled"]:
|
||||
f.write(f"profile_dir: {profile_dir}\n")
|
||||
if profile_summary_path is not None:
|
||||
f.write(f"profile_summary: {profile_summary_path}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user