add profile frame and bf15/fp16 switch

This commit is contained in:
qihuanye
2026-03-31 11:09:02 +00:00
parent ca231f9f9d
commit 8b84251eb9
4 changed files with 249 additions and 88 deletions

19
.gitignore vendored Normal file
View File

@@ -0,0 +1,19 @@
.venv/
outputs/
__pycache__/
*.py[cod]
.pytest_cache/
.mypy_cache/
.ruff_cache/
torch_profile/
trace.json
key_averages.txt
eval_tmp_*.npy
*.mp4
*.gif
.DS_Store
.idea/
.vscode/

View File

@@ -84,6 +84,33 @@ python eval.py --config-name=pusht.yaml policy=pusht/lewm
python eval.py --config-name=pusht.yaml policy=pusht/lewm_object.ckpt python eval.py --config-name=pusht.yaml policy=pusht/lewm_object.ckpt
``` ```
## Profiling
`eval.py` now supports optional inference profiling with PyTorch's native profiler.
Example:
```bash
python eval.py --config-name=pusht.yaml policy=pusht/lewm \
inference_precision=bf16 \
+profile.enabled=true \
+profile.with_stack=true \
+profile.record_shapes=true \
+profile.profile_memory=true
```
Supported inference precision modes:
- `inference_precision=fp32`
- `inference_precision=bf16`
- `inference_precision=fp16`
Outputs are written next to the evaluation results:
- `torch_profile/key_averages.txt` for the aggregated operator table
- `torch_profile/trace.json` for Chrome tracing
- TensorBoard trace files under `torch_profile/`
The trace includes custom regions such as `eval.world_evaluate_from_dataset`, `lewm.get_cost`, `lewm.rollout`, and `lewm.predict` to make the planning path easier to inspect.
## Pretrained Checkpoints ## Pretrained Checkpoints
Pre-trained checkpoints are available on [Google Drive](https://drive.google.com/drive/folders/1r31os0d4-rR0mdHc7OlY_e5nh3XT4r4e). Download the checkpoint archive and place the extracted files under `$STABLEWM_HOME/`. Pre-trained checkpoints are available on [Google Drive](https://drive.google.com/drive/folders/1r31os0d4-rR0mdHc7OlY_e5nh3XT4r4e). Download the checkpoint archive and place the extracted files under `$STABLEWM_HOME/`.

135
eval.py
View File

@@ -3,6 +3,7 @@ import os
os.environ["MUJOCO_GL"] = "egl" os.environ["MUJOCO_GL"] = "egl"
import time import time
from contextlib import nullcontext
from pathlib import Path from pathlib import Path
import hydra import hydra
@@ -46,6 +47,99 @@ def get_dataset(cfg, dataset_name):
) )
return dataset return dataset
def get_profile_cfg(cfg):
profile_cfg = {
"enabled": False,
"trace_dirname": "torch_profile",
"record_shapes": True,
"profile_memory": True,
"with_stack": False,
"with_flops": False,
"row_limit": 40,
"worker_name": "eval",
"export_chrome_trace": True,
"export_tensorboard": True,
}
cfg_profile = cfg.get("profile")
if cfg_profile is not None:
profile_cfg.update(OmegaConf.to_container(cfg_profile, resolve=True))
return profile_cfg
def get_inference_context(cfg, device):
precision = str(cfg.get("inference_precision", "fp32")).lower()
device_type = "cuda" if device.startswith("cuda") else "cpu"
if precision == "fp32":
return nullcontext(), "fp32"
if precision in {"bf16", "bfloat16"}:
return (
torch.autocast(device_type=device_type, dtype=torch.bfloat16),
"bf16",
)
if precision in {"fp16", "float16"}:
if device_type != "cuda":
print("fp16 inference is only supported on CUDA, falling back to fp32.")
return nullcontext(), "fp32"
return (
torch.autocast(device_type=device_type, dtype=torch.float16),
"fp16",
)
raise ValueError(
f"Unsupported inference_precision={precision}. Expected one of: fp32, bf16, fp16."
)
def make_profiler(cfg, results_path):
profile_cfg = get_profile_cfg(cfg)
if not profile_cfg["enabled"]:
return nullcontext(), None, profile_cfg
activities = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
activities.append(torch.profiler.ProfilerActivity.CUDA)
profile_dir = results_path / profile_cfg["trace_dirname"]
profile_dir.mkdir(parents=True, exist_ok=True)
profiler = torch.profiler.profile(
activities=activities,
record_shapes=profile_cfg["record_shapes"],
profile_memory=profile_cfg["profile_memory"],
with_stack=profile_cfg["with_stack"],
with_flops=profile_cfg["with_flops"],
)
return profiler, profile_dir, profile_cfg
def dump_profiler_results(profiler, profile_dir, profile_cfg):
if profiler is None or profile_dir is None:
return None
has_cuda = torch.cuda.is_available()
table = profiler.key_averages().table(
sort_by="self_cuda_time_total" if has_cuda else "self_cpu_time_total",
row_limit=profile_cfg["row_limit"],
)
summary_path = profile_dir / "key_averages.txt"
summary_path.write_text(table)
if profile_cfg["export_chrome_trace"]:
profiler.export_chrome_trace(str(profile_dir / "trace.json"))
if profile_cfg["export_tensorboard"]:
trace_handler = torch.profiler.tensorboard_trace_handler(
str(profile_dir), worker_name=profile_cfg["worker_name"]
)
trace_handler(profiler)
return summary_path
@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht") @hydra.main(version_base=None, config_path="./config/eval", config_name="pusht")
def run(cfg: DictConfig): def run(cfg: DictConfig):
"""Run evaluation of dinowm vs random policy.""" """Run evaluation of dinowm vs random policy."""
@@ -83,12 +177,15 @@ def run(cfg: DictConfig):
# -- run evaluation # -- run evaluation
policy = cfg.get("policy", "random") policy = cfg.get("policy", "random")
if policy != "random": if policy != "random":
model = swm.policy.AutoCostModel(cfg.policy) model = swm.policy.AutoCostModel(cfg.policy)
model = model.to("cuda") device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model = model.eval() model = model.eval()
model.requires_grad_(False) model.requires_grad_(False)
print(f"model parameter dtype: {next(model.parameters()).dtype}")
inference_ctx, inference_precision = get_inference_context(cfg, device)
print(f"inference execution precision: {inference_precision}")
model.interpolate_pos_encoding = True model.interpolate_pos_encoding = True
config = swm.PlanConfig(**cfg.plan_config) config = swm.PlanConfig(**cfg.plan_config)
solver = hydra.utils.instantiate(cfg.solver, model=model) solver = hydra.utils.instantiate(cfg.solver, model=model)
@@ -98,12 +195,15 @@ def run(cfg: DictConfig):
else: else:
policy = swm.policy.RandomPolicy() policy = swm.policy.RandomPolicy()
inference_ctx = nullcontext()
inference_precision = "fp32"
results_path = ( results_path = (
Path(swm.data.utils.get_cache_dir(), cfg.policy).parent Path(swm.data.utils.get_cache_dir(), cfg.policy).parent
if cfg.policy != "random" if cfg.policy != "random"
else Path(__file__).parent else Path(__file__).parent
) )
profiler_ctx, profile_dir, profile_cfg = make_profiler(cfg, results_path)
# sample the episodes and the starting indices # sample the episodes and the starting indices
episode_len = get_episodes_length(dataset, ep_indices) episode_len = get_episodes_length(dataset, ep_indices)
@@ -138,17 +238,25 @@ def run(cfg: DictConfig):
world.set_policy(policy) world.set_policy(policy)
if torch.cuda.is_available():
torch.cuda.synchronize()
start_time = time.time() start_time = time.time()
metrics = world.evaluate_from_dataset( with profiler_ctx as profiler:
dataset, with inference_ctx:
start_steps=eval_start_idx.tolist(), with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
goal_offset_steps=cfg.eval.goal_offset_steps, metrics = world.evaluate_from_dataset(
eval_budget=cfg.eval.eval_budget, dataset,
episodes_idx=eval_episodes.tolist(), start_steps=eval_start_idx.tolist(),
callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True), goal_offset_steps=cfg.eval.goal_offset_steps,
video_path=results_path, eval_budget=cfg.eval.eval_budget,
) episodes_idx=eval_episodes.tolist(),
callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True),
video_path=results_path,
)
if torch.cuda.is_available():
torch.cuda.synchronize()
end_time = time.time() end_time = time.time()
profile_summary_path = dump_profiler_results(profiler, profile_dir, profile_cfg)
print(metrics) print(metrics)
@@ -165,6 +273,11 @@ def run(cfg: DictConfig):
f.write("==== RESULTS ====\n") f.write("==== RESULTS ====\n")
f.write(f"metrics: {metrics}\n") f.write(f"metrics: {metrics}\n")
f.write(f"evaluation_time: {end_time - start_time} seconds\n") f.write(f"evaluation_time: {end_time - start_time} seconds\n")
f.write(f"inference_precision: {inference_precision}\n")
if profile_cfg["enabled"]:
f.write(f"profile_dir: {profile_dir}\n")
if profile_summary_path is not None:
f.write(f"profile_summary: {profile_summary_path}\n")
if __name__ == "__main__": if __name__ == "__main__":

154
jepa.py
View File

@@ -30,29 +30,30 @@ class JEPA(nn.Module):
"""Encode observations and actions into embeddings. """Encode observations and actions into embeddings.
info: dict with pixels and action keys info: dict with pixels and action keys
""" """
with torch.profiler.record_function("lewm.encode"):
pixels = info['pixels'].float()
b = pixels.size(0)
pixels = rearrange(pixels, "b t ... -> (b t) ...") # flatten for encoding
output = self.encoder(pixels, interpolate_pos_encoding=True)
pixels_emb = output.last_hidden_state[:, 0] # cls token
emb = self.projector(pixels_emb)
info["emb"] = rearrange(emb, "(b t) d -> b t d", b=b)
pixels = info['pixels'].float() if "action" in info:
b = pixels.size(0) info["act_emb"] = self.action_encoder(info["action"])
pixels = rearrange(pixels, "b t ... -> (b t) ...") # flatten for encoding
output = self.encoder(pixels, interpolate_pos_encoding=True)
pixels_emb = output.last_hidden_state[:, 0] # cls token
emb = self.projector(pixels_emb)
info["emb"] = rearrange(emb, "(b t) d -> b t d", b=b)
if "action" in info: return info
info["act_emb"] = self.action_encoder(info["action"])
return info
def predict(self, emb, act_emb): def predict(self, emb, act_emb):
"""Predict next state embedding """Predict next state embedding
emb: (B, T, D) emb: (B, T, D)
act_emb: (B, T, A_emb) act_emb: (B, T, A_emb)
""" """
preds = self.predictor(emb, act_emb) with torch.profiler.record_function("lewm.predict"):
preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d")) preds = self.predictor(emb, act_emb)
preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0)) preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d"))
return preds preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0))
return preds
#################### ####################
## Inference only ## ## Inference only ##
@@ -65,89 +66,90 @@ class JEPA(nn.Module):
- S is the number of action plan samples - S is the number of action plan samples
- T is the time horizon - T is the time horizon
""" """
with torch.profiler.record_function("lewm.rollout"):
assert "pixels" in info, "pixels not in info_dict"
H = info["pixels"].size(2)
B, S, T = action_sequence.shape[:3]
act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2)
info["action"] = act_0
n_steps = T - H
assert "pixels" in info, "pixels not in info_dict" # copy and encode initial info dict
H = info["pixels"].size(2) _init = {k: v[:, 0] for k, v in info.items() if torch.is_tensor(v)}
B, S, T = action_sequence.shape[:3] _init = self.encode(_init)
act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2) emb = info["emb"] = _init["emb"].unsqueeze(1).expand(B, S, -1, -1)
info["action"] = act_0 _init = {k: detach_clone(v) for k, v in _init.items()}
n_steps = T - H
# copy and encode initial info dict # flatten batch and sample dimensions for rollout
_init = {k: v[:, 0] for k, v in info.items() if torch.is_tensor(v)} emb = rearrange(emb, "b s ... -> (b s) ...").clone()
_init = self.encode(_init) act = rearrange(act_0, "b s ... -> (b s) ...")
emb = info["emb"] = _init["emb"].unsqueeze(1).expand(B, S, -1, -1) act_future = rearrange(act_future, "b s ... -> (b s) ...")
_init = {k: detach_clone(v) for k, v in _init.items()}
# flatten batch and sample dimensions for rollout # rollout predictor autoregressively for n_steps
emb = rearrange(emb, "b s ... -> (b s) ...").clone() HS = history_size
act = rearrange(act_0, "b s ... -> (b s) ...") for t in range(n_steps):
act_future = rearrange(act_future, "b s ... -> (b s) ...") act_emb = self.action_encoder(act)
emb_trunc = emb[:, -HS:] # (BS, HS, D)
act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb)
pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D)
emb = torch.cat([emb, pred_emb], dim=1) # (BS, T+1, D)
# rollout predictor autoregressively for n_steps next_act = act_future[:, t : t + 1, :] # (BS, 1, action_dim)
HS = history_size act = torch.cat([act, next_act], dim=1) # (BS, T+1, action_dim)
for t in range(n_steps):
act_emb = self.action_encoder(act) # predict the last state
act_emb = self.action_encoder(act) # (BS, T, A_emb)
emb_trunc = emb[:, -HS:] # (BS, HS, D) emb_trunc = emb[:, -HS:] # (BS, HS, D)
act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb) act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb)
pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D) pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D)
emb = torch.cat([emb, pred_emb], dim=1) # (BS, T+1, D) emb = torch.cat([emb, pred_emb], dim=1)
next_act = act_future[:, t : t + 1, :] # (BS, 1, action_dim) # unflatten batch and sample dimensions
act = torch.cat([act, next_act], dim=1) # (BS, T+1, action_dim) pred_rollout = rearrange(emb, "(b s) ... -> b s ...", b=B, s=S)
info["predicted_emb"] = pred_rollout
# predict the last state return info
act_emb = self.action_encoder(act) # (BS, T, A_emb)
emb_trunc = emb[:, -HS:] # (BS, HS, D)
act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb)
pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D)
emb = torch.cat([emb, pred_emb], dim=1)
# unflatten batch and sample dimensions
pred_rollout = rearrange(emb, "(b s) ... -> b s ...", b=B, s=S)
info["predicted_emb"] = pred_rollout
return info
def criterion(self, info_dict: dict): def criterion(self, info_dict: dict):
"""Compute the cost between predicted embeddings and goal embeddings.""" """Compute the cost between predicted embeddings and goal embeddings."""
pred_emb = info_dict["predicted_emb"] # (B,S, T-1, dim) with torch.profiler.record_function("lewm.criterion"):
goal_emb = info_dict["goal_emb"] # (B, S, T, dim) pred_emb = info_dict["predicted_emb"] # (B,S, T-1, dim)
goal_emb = info_dict["goal_emb"] # (B, S, T, dim)
goal_emb = goal_emb[..., -1:, :].expand_as(pred_emb) goal_emb = goal_emb[..., -1:, :].expand_as(pred_emb)
# return last-step cost per action candidate # return last-step cost per action candidate
cost = F.mse_loss( cost = F.mse_loss(
pred_emb[..., -1:, :], pred_emb[..., -1:, :],
goal_emb[..., -1:, :].detach(), goal_emb[..., -1:, :].detach(),
reduction="none", reduction="none",
).sum(dim=tuple(range(2, pred_emb.ndim))) # (B, S) ).sum(dim=tuple(range(2, pred_emb.ndim))) # (B, S)
return cost return cost
def get_cost(self, info_dict: dict, action_candidates: torch.Tensor): def get_cost(self, info_dict: dict, action_candidates: torch.Tensor):
""" Compute the cost of action candidates given an info dict with goal and initial state.""" """ Compute the cost of action candidates given an info dict with goal and initial state."""
with torch.profiler.record_function("lewm.get_cost"):
assert "goal" in info_dict, "goal not in info_dict"
assert "goal" in info_dict, "goal not in info_dict" device = next(self.parameters()).device
for k in list(info_dict.keys()):
if torch.is_tensor(info_dict[k]):
info_dict[k] = info_dict[k].to(device)
device = next(self.parameters()).device goal = {k: v[:, 0] for k, v in info_dict.items() if torch.is_tensor(v)}
for k in list(info_dict.keys()): goal["pixels"] = goal["goal"]
if torch.is_tensor(info_dict[k]):
info_dict[k] = info_dict[k].to(device)
goal = {k: v[:, 0] for k, v in info_dict.items() if torch.is_tensor(v)} for k in info_dict:
goal["pixels"] = goal["goal"] if k.startswith("goal_"):
goal[k[len("goal_") :]] = goal.pop(k)
for k in info_dict: goal.pop("action")
if k.startswith("goal_"): goal = self.encode(goal)
goal[k[len("goal_") :]] = goal.pop(k)
goal.pop("action") info_dict["goal_emb"] = goal["emb"]
goal = self.encode(goal) info_dict = self.rollout(info_dict, action_candidates)
info_dict["goal_emb"] = goal["emb"] cost = self.criterion(info_dict)
info_dict = self.rollout(info_dict, action_candidates)
cost = self.criterion(info_dict) return cost
return cost