add profile frame and bf15/fp16 switch

This commit is contained in:
qihuanye
2026-03-31 11:09:02 +00:00
parent ca231f9f9d
commit 8b84251eb9
4 changed files with 249 additions and 88 deletions

19
.gitignore vendored Normal file
View File

@@ -0,0 +1,19 @@
.venv/
outputs/
__pycache__/
*.py[cod]
.pytest_cache/
.mypy_cache/
.ruff_cache/
torch_profile/
trace.json
key_averages.txt
eval_tmp_*.npy
*.mp4
*.gif
.DS_Store
.idea/
.vscode/

View File

@@ -84,6 +84,33 @@ python eval.py --config-name=pusht.yaml policy=pusht/lewm
python eval.py --config-name=pusht.yaml policy=pusht/lewm_object.ckpt
```
## Profiling
`eval.py` now supports optional inference profiling with PyTorch's native profiler.
Example:
```bash
python eval.py --config-name=pusht.yaml policy=pusht/lewm \
inference_precision=bf16 \
+profile.enabled=true \
+profile.with_stack=true \
+profile.record_shapes=true \
+profile.profile_memory=true
```
Supported inference precision modes:
- `inference_precision=fp32`
- `inference_precision=bf16`
- `inference_precision=fp16`
Outputs are written next to the evaluation results:
- `torch_profile/key_averages.txt` for the aggregated operator table
- `torch_profile/trace.json` for Chrome tracing
- TensorBoard trace files under `torch_profile/`
The trace includes custom regions such as `eval.world_evaluate_from_dataset`, `lewm.get_cost`, `lewm.rollout`, and `lewm.predict` to make the planning path easier to inspect.
## Pretrained Checkpoints
Pre-trained checkpoints are available on [Google Drive](https://drive.google.com/drive/folders/1r31os0d4-rR0mdHc7OlY_e5nh3XT4r4e). Download the checkpoint archive and place the extracted files under `$STABLEWM_HOME/`.

117
eval.py
View File

@@ -3,6 +3,7 @@ import os
os.environ["MUJOCO_GL"] = "egl"
import time
from contextlib import nullcontext
from pathlib import Path
import hydra
@@ -46,6 +47,99 @@ def get_dataset(cfg, dataset_name):
)
return dataset
def get_profile_cfg(cfg):
profile_cfg = {
"enabled": False,
"trace_dirname": "torch_profile",
"record_shapes": True,
"profile_memory": True,
"with_stack": False,
"with_flops": False,
"row_limit": 40,
"worker_name": "eval",
"export_chrome_trace": True,
"export_tensorboard": True,
}
cfg_profile = cfg.get("profile")
if cfg_profile is not None:
profile_cfg.update(OmegaConf.to_container(cfg_profile, resolve=True))
return profile_cfg
def get_inference_context(cfg, device):
precision = str(cfg.get("inference_precision", "fp32")).lower()
device_type = "cuda" if device.startswith("cuda") else "cpu"
if precision == "fp32":
return nullcontext(), "fp32"
if precision in {"bf16", "bfloat16"}:
return (
torch.autocast(device_type=device_type, dtype=torch.bfloat16),
"bf16",
)
if precision in {"fp16", "float16"}:
if device_type != "cuda":
print("fp16 inference is only supported on CUDA, falling back to fp32.")
return nullcontext(), "fp32"
return (
torch.autocast(device_type=device_type, dtype=torch.float16),
"fp16",
)
raise ValueError(
f"Unsupported inference_precision={precision}. Expected one of: fp32, bf16, fp16."
)
def make_profiler(cfg, results_path):
profile_cfg = get_profile_cfg(cfg)
if not profile_cfg["enabled"]:
return nullcontext(), None, profile_cfg
activities = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
activities.append(torch.profiler.ProfilerActivity.CUDA)
profile_dir = results_path / profile_cfg["trace_dirname"]
profile_dir.mkdir(parents=True, exist_ok=True)
profiler = torch.profiler.profile(
activities=activities,
record_shapes=profile_cfg["record_shapes"],
profile_memory=profile_cfg["profile_memory"],
with_stack=profile_cfg["with_stack"],
with_flops=profile_cfg["with_flops"],
)
return profiler, profile_dir, profile_cfg
def dump_profiler_results(profiler, profile_dir, profile_cfg):
if profiler is None or profile_dir is None:
return None
has_cuda = torch.cuda.is_available()
table = profiler.key_averages().table(
sort_by="self_cuda_time_total" if has_cuda else "self_cpu_time_total",
row_limit=profile_cfg["row_limit"],
)
summary_path = profile_dir / "key_averages.txt"
summary_path.write_text(table)
if profile_cfg["export_chrome_trace"]:
profiler.export_chrome_trace(str(profile_dir / "trace.json"))
if profile_cfg["export_tensorboard"]:
trace_handler = torch.profiler.tensorboard_trace_handler(
str(profile_dir), worker_name=profile_cfg["worker_name"]
)
trace_handler(profiler)
return summary_path
@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht")
def run(cfg: DictConfig):
"""Run evaluation of dinowm vs random policy."""
@@ -83,12 +177,15 @@ def run(cfg: DictConfig):
# -- run evaluation
policy = cfg.get("policy", "random")
if policy != "random":
model = swm.policy.AutoCostModel(cfg.policy)
model = model.to("cuda")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model = model.eval()
model.requires_grad_(False)
print(f"model parameter dtype: {next(model.parameters()).dtype}")
inference_ctx, inference_precision = get_inference_context(cfg, device)
print(f"inference execution precision: {inference_precision}")
model.interpolate_pos_encoding = True
config = swm.PlanConfig(**cfg.plan_config)
solver = hydra.utils.instantiate(cfg.solver, model=model)
@@ -98,12 +195,15 @@ def run(cfg: DictConfig):
else:
policy = swm.policy.RandomPolicy()
inference_ctx = nullcontext()
inference_precision = "fp32"
results_path = (
Path(swm.data.utils.get_cache_dir(), cfg.policy).parent
if cfg.policy != "random"
else Path(__file__).parent
)
profiler_ctx, profile_dir, profile_cfg = make_profiler(cfg, results_path)
# sample the episodes and the starting indices
episode_len = get_episodes_length(dataset, ep_indices)
@@ -138,7 +238,12 @@ def run(cfg: DictConfig):
world.set_policy(policy)
if torch.cuda.is_available():
torch.cuda.synchronize()
start_time = time.time()
with profiler_ctx as profiler:
with inference_ctx:
with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
metrics = world.evaluate_from_dataset(
dataset,
start_steps=eval_start_idx.tolist(),
@@ -148,7 +253,10 @@ def run(cfg: DictConfig):
callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True),
video_path=results_path,
)
if torch.cuda.is_available():
torch.cuda.synchronize()
end_time = time.time()
profile_summary_path = dump_profiler_results(profiler, profile_dir, profile_cfg)
print(metrics)
@@ -165,6 +273,11 @@ def run(cfg: DictConfig):
f.write("==== RESULTS ====\n")
f.write(f"metrics: {metrics}\n")
f.write(f"evaluation_time: {end_time - start_time} seconds\n")
f.write(f"inference_precision: {inference_precision}\n")
if profile_cfg["enabled"]:
f.write(f"profile_dir: {profile_dir}\n")
if profile_summary_path is not None:
f.write(f"profile_summary: {profile_summary_path}\n")
if __name__ == "__main__":

View File

@@ -30,7 +30,7 @@ class JEPA(nn.Module):
"""Encode observations and actions into embeddings.
info: dict with pixels and action keys
"""
with torch.profiler.record_function("lewm.encode"):
pixels = info['pixels'].float()
b = pixels.size(0)
pixels = rearrange(pixels, "b t ... -> (b t) ...") # flatten for encoding
@@ -49,6 +49,7 @@ class JEPA(nn.Module):
emb: (B, T, D)
act_emb: (B, T, A_emb)
"""
with torch.profiler.record_function("lewm.predict"):
preds = self.predictor(emb, act_emb)
preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d"))
preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0))
@@ -65,7 +66,7 @@ class JEPA(nn.Module):
- S is the number of action plan samples
- T is the time horizon
"""
with torch.profiler.record_function("lewm.rollout"):
assert "pixels" in info, "pixels not in info_dict"
H = info["pixels"].size(2)
B, S, T = action_sequence.shape[:3]
@@ -111,6 +112,7 @@ class JEPA(nn.Module):
def criterion(self, info_dict: dict):
"""Compute the cost between predicted embeddings and goal embeddings."""
with torch.profiler.record_function("lewm.criterion"):
pred_emb = info_dict["predicted_emb"] # (B,S, T-1, dim)
goal_emb = info_dict["goal_emb"] # (B, S, T, dim)
@@ -127,7 +129,7 @@ class JEPA(nn.Module):
def get_cost(self, info_dict: dict, action_candidates: torch.Tensor):
""" Compute the cost of action candidates given an info dict with goal and initial state."""
with torch.profiler.record_function("lewm.get_cost"):
assert "goal" in info_dict, "goal not in info_dict"
device = next(self.parameters()).device