add profiling framework and bf16/fp16 switch

This commit is contained in:
qihuanye
2026-03-31 11:09:02 +00:00
parent ca231f9f9d
commit 8b84251eb9
4 changed files with 249 additions and 88 deletions

19
.gitignore vendored Normal file
View File

@@ -0,0 +1,19 @@
.venv/
outputs/
__pycache__/
*.py[cod]
.pytest_cache/
.mypy_cache/
.ruff_cache/
torch_profile/
trace.json
key_averages.txt
eval_tmp_*.npy
*.mp4
*.gif
.DS_Store
.idea/
.vscode/

View File

@@ -84,6 +84,33 @@ python eval.py --config-name=pusht.yaml policy=pusht/lewm
python eval.py --config-name=pusht.yaml policy=pusht/lewm_object.ckpt python eval.py --config-name=pusht.yaml policy=pusht/lewm_object.ckpt
``` ```
## Profiling
`eval.py` now supports optional inference profiling with PyTorch's native profiler.
Example:
```bash
python eval.py --config-name=pusht.yaml policy=pusht/lewm \
inference_precision=bf16 \
+profile.enabled=true \
+profile.with_stack=true \
+profile.record_shapes=true \
+profile.profile_memory=true
```
Supported inference precision modes:
- `inference_precision=fp32` (default when the option is not set)
- `inference_precision=bf16`
- `inference_precision=fp16` (CUDA only; on CPU the evaluation falls back to fp32)
Outputs are written next to the evaluation results:
- `torch_profile/key_averages.txt` for the aggregated operator table
- `torch_profile/trace.json` for Chrome tracing
- TensorBoard trace files under `torch_profile/`
The trace includes custom regions such as `eval.world_evaluate_from_dataset`, `lewm.get_cost`, `lewm.rollout`, and `lewm.predict` to make the planning path easier to inspect.
## Pretrained Checkpoints ## Pretrained Checkpoints
Pre-trained checkpoints are available on [Google Drive](https://drive.google.com/drive/folders/1r31os0d4-rR0mdHc7OlY_e5nh3XT4r4e). Download the checkpoint archive and place the extracted files under `$STABLEWM_HOME/`. Pre-trained checkpoints are available on [Google Drive](https://drive.google.com/drive/folders/1r31os0d4-rR0mdHc7OlY_e5nh3XT4r4e). Download the checkpoint archive and place the extracted files under `$STABLEWM_HOME/`.

117
eval.py
View File

@@ -3,6 +3,7 @@ import os
os.environ["MUJOCO_GL"] = "egl" os.environ["MUJOCO_GL"] = "egl"
import time import time
from contextlib import nullcontext
from pathlib import Path from pathlib import Path
import hydra import hydra
@@ -46,6 +47,99 @@ def get_dataset(cfg, dataset_name):
) )
return dataset return dataset
def get_profile_cfg(cfg):
    """Return the profiling settings as a plain dict.

    Built-in defaults are used for every option; any keys found under
    ``cfg["profile"]`` override (or extend) them.

    Args:
        cfg: Mapping-like config object exposing ``.get``. Its optional
            ``profile`` entry is converted with ``OmegaConf.to_container``.

    Returns:
        dict: The merged profiling options.
    """
    settings = dict(
        enabled=False,
        trace_dirname="torch_profile",
        record_shapes=True,
        profile_memory=True,
        with_stack=False,
        with_flops=False,
        row_limit=40,
        worker_name="eval",
        export_chrome_trace=True,
        export_tensorboard=True,
    )

    overrides = cfg.get("profile")
    if overrides is None:
        # No ``profile`` section supplied: the defaults stand as-is.
        return settings

    settings.update(OmegaConf.to_container(overrides, resolve=True))
    return settings
def get_inference_context(cfg, device):
    """Select the autocast context for inference and report the precision in effect.

    Args:
        cfg: Mapping-like config object exposing ``.get``. The
            ``inference_precision`` entry is read case-insensitively and
            defaults to ``"fp32"``. Accepted values: ``fp32``/``float32``,
            ``bf16``/``bfloat16``, ``fp16``/``float16``.
        device: Target device, either a string such as ``"cuda"``,
            ``"cuda:0"`` or ``"cpu"``, or a ``torch.device``.

    Returns:
        tuple: ``(context_manager, precision_name)`` where ``precision_name``
        is one of ``"fp32"``, ``"bf16"`` or ``"fp16"`` and names the precision
        that will actually be used (fp16 on a non-CUDA device falls back to
        ``"fp32"``).

    Raises:
        ValueError: If ``inference_precision`` is not a supported value.
    """
    precision = str(cfg.get("inference_precision", "fp32")).lower()
    # str() makes this work for torch.device objects as well as plain strings;
    # calling .startswith directly on a torch.device raises AttributeError.
    device_type = "cuda" if str(device).startswith("cuda") else "cpu"

    # "float32" is accepted for symmetry with the bfloat16/float16 aliases.
    if precision in {"fp32", "float32"}:
        return nullcontext(), "fp32"

    if precision in {"bf16", "bfloat16"}:
        return (
            torch.autocast(device_type=device_type, dtype=torch.bfloat16),
            "bf16",
        )

    if precision in {"fp16", "float16"}:
        if device_type != "cuda":
            print("fp16 inference is only supported on CUDA, falling back to fp32.")
            return nullcontext(), "fp32"
        return (
            torch.autocast(device_type=device_type, dtype=torch.float16),
            "fp16",
        )

    raise ValueError(
        f"Unsupported inference_precision={precision}. Expected one of: fp32, bf16, fp16."
    )
def make_profiler(cfg, results_path):
    """Create the profiler context for an evaluation run.

    Args:
        cfg: Config object passed through to ``get_profile_cfg``.
        results_path: Directory (path object supporting ``/``) under which the
            profiling output directory is created.

    Returns:
        tuple: ``(context_manager, profile_dir, profile_cfg)``. When profiling
        is disabled the context manager is a ``nullcontext`` and
        ``profile_dir`` is ``None``; otherwise it is a
        ``torch.profiler.profile`` instance and ``profile_dir`` is the
        (already created) output directory.
    """
    settings = get_profile_cfg(cfg)

    if settings["enabled"]:
        # CPU activity is always traced; CUDA is added only when a GPU is present.
        traced = [torch.profiler.ProfilerActivity.CPU]
        if torch.cuda.is_available():
            traced.append(torch.profiler.ProfilerActivity.CUDA)

        out_dir = results_path / settings["trace_dirname"]
        out_dir.mkdir(parents=True, exist_ok=True)

        forwarded = ("record_shapes", "profile_memory", "with_stack", "with_flops")
        profiler_kwargs = {key: settings[key] for key in forwarded}
        return (
            torch.profiler.profile(activities=traced, **profiler_kwargs),
            out_dir,
            settings,
        )

    return nullcontext(), None, settings
def dump_profiler_results(profiler, profile_dir, profile_cfg):
    """Write the profiler summary table and trace files into ``profile_dir``.

    Args:
        profiler: The finished profiler object, or ``None`` when profiling was
            disabled.
        profile_dir: Output directory (path object), or ``None`` when
            profiling was disabled.
        profile_cfg: Dict of profiling options. Reads ``row_limit``,
            ``export_chrome_trace``, ``export_tensorboard`` and
            ``worker_name``.

    Returns:
        The path of the written ``key_averages.txt`` summary, or ``None`` when
        there is nothing to dump.
    """
    import shutil
    import socket

    if profiler is None or profile_dir is None:
        return None

    has_cuda = torch.cuda.is_available()
    table = profiler.key_averages().table(
        sort_by="self_cuda_time_total" if has_cuda else "self_cpu_time_total",
        row_limit=profile_cfg["row_limit"],
    )
    summary_path = profile_dir / "key_averages.txt"
    summary_path.write_text(table)

    chrome_trace_path = profile_dir / "trace.json"
    if profile_cfg["export_chrome_trace"]:
        profiler.export_chrome_trace(str(chrome_trace_path))

    if profile_cfg["export_tensorboard"]:
        if profile_cfg["export_chrome_trace"]:
            # The TensorBoard handler is itself a call to export_chrome_trace,
            # and the profiler's trace is expected to be savable only once (a
            # second save fails with "Trace is already saved"). Reuse the file
            # written above under the name TensorBoard looks for instead.
            worker_name = (
                profile_cfg["worker_name"]
                or f"{socket.gethostname()}_{os.getpid()}"
            )
            tensorboard_path = (
                profile_dir / f"{worker_name}.{time.time_ns()}.pt.trace.json"
            )
            shutil.copyfile(chrome_trace_path, tensorboard_path)
        else:
            trace_handler = torch.profiler.tensorboard_trace_handler(
                str(profile_dir), worker_name=profile_cfg["worker_name"]
            )
            trace_handler(profiler)

    return summary_path
@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht") @hydra.main(version_base=None, config_path="./config/eval", config_name="pusht")
def run(cfg: DictConfig): def run(cfg: DictConfig):
"""Run evaluation of dinowm vs random policy.""" """Run evaluation of dinowm vs random policy."""
@@ -83,12 +177,15 @@ def run(cfg: DictConfig):
# -- run evaluation # -- run evaluation
policy = cfg.get("policy", "random") policy = cfg.get("policy", "random")
if policy != "random": if policy != "random":
model = swm.policy.AutoCostModel(cfg.policy) model = swm.policy.AutoCostModel(cfg.policy)
model = model.to("cuda") device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model = model.eval() model = model.eval()
model.requires_grad_(False) model.requires_grad_(False)
print(f"model parameter dtype: {next(model.parameters()).dtype}")
inference_ctx, inference_precision = get_inference_context(cfg, device)
print(f"inference execution precision: {inference_precision}")
model.interpolate_pos_encoding = True model.interpolate_pos_encoding = True
config = swm.PlanConfig(**cfg.plan_config) config = swm.PlanConfig(**cfg.plan_config)
solver = hydra.utils.instantiate(cfg.solver, model=model) solver = hydra.utils.instantiate(cfg.solver, model=model)
@@ -98,12 +195,15 @@ def run(cfg: DictConfig):
else: else:
policy = swm.policy.RandomPolicy() policy = swm.policy.RandomPolicy()
inference_ctx = nullcontext()
inference_precision = "fp32"
results_path = ( results_path = (
Path(swm.data.utils.get_cache_dir(), cfg.policy).parent Path(swm.data.utils.get_cache_dir(), cfg.policy).parent
if cfg.policy != "random" if cfg.policy != "random"
else Path(__file__).parent else Path(__file__).parent
) )
profiler_ctx, profile_dir, profile_cfg = make_profiler(cfg, results_path)
# sample the episodes and the starting indices # sample the episodes and the starting indices
episode_len = get_episodes_length(dataset, ep_indices) episode_len = get_episodes_length(dataset, ep_indices)
@@ -138,7 +238,12 @@ def run(cfg: DictConfig):
world.set_policy(policy) world.set_policy(policy)
if torch.cuda.is_available():
torch.cuda.synchronize()
start_time = time.time() start_time = time.time()
with profiler_ctx as profiler:
with inference_ctx:
with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
metrics = world.evaluate_from_dataset( metrics = world.evaluate_from_dataset(
dataset, dataset,
start_steps=eval_start_idx.tolist(), start_steps=eval_start_idx.tolist(),
@@ -148,7 +253,10 @@ def run(cfg: DictConfig):
callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True), callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True),
video_path=results_path, video_path=results_path,
) )
if torch.cuda.is_available():
torch.cuda.synchronize()
end_time = time.time() end_time = time.time()
profile_summary_path = dump_profiler_results(profiler, profile_dir, profile_cfg)
print(metrics) print(metrics)
@@ -165,6 +273,11 @@ def run(cfg: DictConfig):
f.write("==== RESULTS ====\n") f.write("==== RESULTS ====\n")
f.write(f"metrics: {metrics}\n") f.write(f"metrics: {metrics}\n")
f.write(f"evaluation_time: {end_time - start_time} seconds\n") f.write(f"evaluation_time: {end_time - start_time} seconds\n")
f.write(f"inference_precision: {inference_precision}\n")
if profile_cfg["enabled"]:
f.write(f"profile_dir: {profile_dir}\n")
if profile_summary_path is not None:
f.write(f"profile_summary: {profile_summary_path}\n")
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -30,7 +30,7 @@ class JEPA(nn.Module):
"""Encode observations and actions into embeddings. """Encode observations and actions into embeddings.
info: dict with pixels and action keys info: dict with pixels and action keys
""" """
with torch.profiler.record_function("lewm.encode"):
pixels = info['pixels'].float() pixels = info['pixels'].float()
b = pixels.size(0) b = pixels.size(0)
pixels = rearrange(pixels, "b t ... -> (b t) ...") # flatten for encoding pixels = rearrange(pixels, "b t ... -> (b t) ...") # flatten for encoding
@@ -49,6 +49,7 @@ class JEPA(nn.Module):
emb: (B, T, D) emb: (B, T, D)
act_emb: (B, T, A_emb) act_emb: (B, T, A_emb)
""" """
with torch.profiler.record_function("lewm.predict"):
preds = self.predictor(emb, act_emb) preds = self.predictor(emb, act_emb)
preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d")) preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d"))
preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0)) preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0))
@@ -65,7 +66,7 @@ class JEPA(nn.Module):
- S is the number of action plan samples - S is the number of action plan samples
- T is the time horizon - T is the time horizon
""" """
with torch.profiler.record_function("lewm.rollout"):
assert "pixels" in info, "pixels not in info_dict" assert "pixels" in info, "pixels not in info_dict"
H = info["pixels"].size(2) H = info["pixels"].size(2)
B, S, T = action_sequence.shape[:3] B, S, T = action_sequence.shape[:3]
@@ -111,6 +112,7 @@ class JEPA(nn.Module):
def criterion(self, info_dict: dict): def criterion(self, info_dict: dict):
"""Compute the cost between predicted embeddings and goal embeddings.""" """Compute the cost between predicted embeddings and goal embeddings."""
with torch.profiler.record_function("lewm.criterion"):
pred_emb = info_dict["predicted_emb"] # (B,S, T-1, dim) pred_emb = info_dict["predicted_emb"] # (B,S, T-1, dim)
goal_emb = info_dict["goal_emb"] # (B, S, T, dim) goal_emb = info_dict["goal_emb"] # (B, S, T, dim)
@@ -127,7 +129,7 @@ class JEPA(nn.Module):
def get_cost(self, info_dict: dict, action_candidates: torch.Tensor): def get_cost(self, info_dict: dict, action_candidates: torch.Tensor):
""" Compute the cost of action candidates given an info dict with goal and initial state.""" """ Compute the cost of action candidates given an info dict with goal and initial state."""
with torch.profiler.record_function("lewm.get_cost"):
assert "goal" in info_dict, "goal not in info_dict" assert "goal" in info_dict, "goal not in info_dict"
device = next(self.parameters()).device device = next(self.parameters()).device