From 8b84251eb9d0e624cb87ea1794d9aa71e86dc6d9 Mon Sep 17 00:00:00 2001 From: qihuanye Date: Tue, 31 Mar 2026 11:09:02 +0000 Subject: [PATCH] add profile frame and bf15/fp16 switch --- .gitignore | 19 +++++++ README.md | 27 ++++++++++ eval.py | 135 ++++++++++++++++++++++++++++++++++++++++++---- jepa.py | 156 +++++++++++++++++++++++++++-------------------------- 4 files changed, 249 insertions(+), 88 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..716a108 --- /dev/null +++ b/.gitignore @@ -0,0 +1,19 @@ +.venv/ +outputs/ +__pycache__/ +*.py[cod] + +.pytest_cache/ +.mypy_cache/ +.ruff_cache/ + +torch_profile/ +trace.json +key_averages.txt +eval_tmp_*.npy +*.mp4 +*.gif + +.DS_Store +.idea/ +.vscode/ diff --git a/README.md b/README.md index 6bc5813..fb1132c 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,33 @@ python eval.py --config-name=pusht.yaml policy=pusht/lewm python eval.py --config-name=pusht.yaml policy=pusht/lewm_object.ckpt ``` +## Profiling + +`eval.py` now supports optional inference profiling with PyTorch's native profiler. + +Example: + +```bash +python eval.py --config-name=pusht.yaml policy=pusht/lewm \ + inference_precision=bf16 \ + +profile.enabled=true \ + +profile.with_stack=true \ + +profile.record_shapes=true \ + +profile.profile_memory=true +``` + +Supported inference precision modes: +- `inference_precision=fp32` +- `inference_precision=bf16` +- `inference_precision=fp16` + +Outputs are written next to the evaluation results: +- `torch_profile/key_averages.txt` for the aggregated operator table +- `torch_profile/trace.json` for Chrome tracing +- TensorBoard trace files under `torch_profile/` + +The trace includes custom regions such as `eval.world_evaluate_from_dataset`, `lewm.get_cost`, `lewm.rollout`, and `lewm.predict` to make the planning path easier to inspect. + ## Pretrained Checkpoints Pre-trained checkpoints are available on [Google Drive](https://drive.google.com/drive/folders/1r31os0d4-rR0mdHc7OlY_e5nh3XT4r4e). Download the checkpoint archive and place the extracted files under `$STABLEWM_HOME/`. diff --git a/eval.py b/eval.py index 859afd1..9cbc068 100644 --- a/eval.py +++ b/eval.py @@ -3,6 +3,7 @@ import os os.environ["MUJOCO_GL"] = "egl" import time +from contextlib import nullcontext from pathlib import Path import hydra @@ -46,6 +47,99 @@ def get_dataset(cfg, dataset_name): ) return dataset + +def get_profile_cfg(cfg): + profile_cfg = { + "enabled": False, + "trace_dirname": "torch_profile", + "record_shapes": True, + "profile_memory": True, + "with_stack": False, + "with_flops": False, + "row_limit": 40, + "worker_name": "eval", + "export_chrome_trace": True, + "export_tensorboard": True, + } + cfg_profile = cfg.get("profile") + if cfg_profile is not None: + profile_cfg.update(OmegaConf.to_container(cfg_profile, resolve=True)) + return profile_cfg + + +def get_inference_context(cfg, device): + precision = str(cfg.get("inference_precision", "fp32")).lower() + device_type = "cuda" if device.startswith("cuda") else "cpu" + + if precision == "fp32": + return nullcontext(), "fp32" + + if precision in {"bf16", "bfloat16"}: + return ( + torch.autocast(device_type=device_type, dtype=torch.bfloat16), + "bf16", + ) + + if precision in {"fp16", "float16"}: + if device_type != "cuda": + print("fp16 inference is only supported on CUDA, falling back to fp32.") + return nullcontext(), "fp32" + return ( + torch.autocast(device_type=device_type, dtype=torch.float16), + "fp16", + ) + + raise ValueError( + f"Unsupported inference_precision={precision}. Expected one of: fp32, bf16, fp16." + ) + + +def make_profiler(cfg, results_path): + profile_cfg = get_profile_cfg(cfg) + if not profile_cfg["enabled"]: + return nullcontext(), None, profile_cfg + + activities = [torch.profiler.ProfilerActivity.CPU] + if torch.cuda.is_available(): + activities.append(torch.profiler.ProfilerActivity.CUDA) + + profile_dir = results_path / profile_cfg["trace_dirname"] + profile_dir.mkdir(parents=True, exist_ok=True) + + profiler = torch.profiler.profile( + activities=activities, + record_shapes=profile_cfg["record_shapes"], + profile_memory=profile_cfg["profile_memory"], + with_stack=profile_cfg["with_stack"], + with_flops=profile_cfg["with_flops"], + ) + return profiler, profile_dir, profile_cfg + + +def dump_profiler_results(profiler, profile_dir, profile_cfg): + if profiler is None or profile_dir is None: + return None + + has_cuda = torch.cuda.is_available() + table = profiler.key_averages().table( + sort_by="self_cuda_time_total" if has_cuda else "self_cpu_time_total", + row_limit=profile_cfg["row_limit"], + ) + + summary_path = profile_dir / "key_averages.txt" + summary_path.write_text(table) + + if profile_cfg["export_chrome_trace"]: + profiler.export_chrome_trace(str(profile_dir / "trace.json")) + + if profile_cfg["export_tensorboard"]: + trace_handler = torch.profiler.tensorboard_trace_handler( + str(profile_dir), worker_name=profile_cfg["worker_name"] + ) + trace_handler(profiler) + + return summary_path + @hydra.main(version_base=None, config_path="./config/eval", config_name="pusht") def run(cfg: DictConfig): """Run evaluation of dinowm vs random policy.""" @@ -83,12 +177,15 @@ def run(cfg: DictConfig): # -- run evaluation policy = cfg.get("policy", "random") - if policy != "random": model = swm.policy.AutoCostModel(cfg.policy) - model = model.to("cuda") + device = "cuda" if torch.cuda.is_available() else "cpu" + model = model.to(device) model = model.eval() model.requires_grad_(False) + print(f"model parameter dtype: {next(model.parameters()).dtype}") + inference_ctx, inference_precision = get_inference_context(cfg, device) + print(f"inference execution precision: {inference_precision}") model.interpolate_pos_encoding = True config = swm.PlanConfig(**cfg.plan_config) solver = hydra.utils.instantiate(cfg.solver, model=model) @@ -98,12 +195,15 @@ def run(cfg: DictConfig): else: policy = swm.policy.RandomPolicy() + inference_ctx = nullcontext() + inference_precision = "fp32" results_path = ( Path(swm.data.utils.get_cache_dir(), cfg.policy).parent if cfg.policy != "random" else Path(__file__).parent ) + profiler_ctx, profile_dir, profile_cfg = make_profiler(cfg, results_path) # sample the episodes and the starting indices episode_len = get_episodes_length(dataset, ep_indices) @@ -138,17 +238,25 @@ def run(cfg: DictConfig): world.set_policy(policy) + if torch.cuda.is_available(): + torch.cuda.synchronize() start_time = time.time() - metrics = world.evaluate_from_dataset( - dataset, - start_steps=eval_start_idx.tolist(), - goal_offset_steps=cfg.eval.goal_offset_steps, - eval_budget=cfg.eval.eval_budget, - episodes_idx=eval_episodes.tolist(), - callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True), - video_path=results_path, - ) + with profiler_ctx as profiler: + with inference_ctx: + with torch.profiler.record_function("eval.world_evaluate_from_dataset"): + metrics = world.evaluate_from_dataset( + dataset, + start_steps=eval_start_idx.tolist(), + goal_offset_steps=cfg.eval.goal_offset_steps, + eval_budget=cfg.eval.eval_budget, + episodes_idx=eval_episodes.tolist(), + callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True), + video_path=results_path, + ) + if torch.cuda.is_available(): + torch.cuda.synchronize() end_time = time.time() + profile_summary_path = dump_profiler_results(profiler, profile_dir, profile_cfg) print(metrics) @@ -165,6 +273,11 @@ def run(cfg: DictConfig): f.write("==== RESULTS ====\n") f.write(f"metrics: {metrics}\n") f.write(f"evaluation_time: {end_time - start_time} seconds\n") + f.write(f"inference_precision: {inference_precision}\n") + if profile_cfg["enabled"]: + f.write(f"profile_dir: {profile_dir}\n") + if profile_summary_path is not None: + f.write(f"profile_summary: {profile_summary_path}\n") if __name__ == "__main__": diff --git a/jepa.py b/jepa.py index 486fe93..a47131a 100644 --- a/jepa.py +++ b/jepa.py @@ -30,29 +30,30 @@ class JEPA(nn.Module): """Encode observations and actions into embeddings. info: dict with pixels and action keys """ + with torch.profiler.record_function("lewm.encode"): + pixels = info['pixels'].float() + b = pixels.size(0) + pixels = rearrange(pixels, "b t ... -> (b t) ...") # flatten for encoding + output = self.encoder(pixels, interpolate_pos_encoding=True) + pixels_emb = output.last_hidden_state[:, 0] # cls token + emb = self.projector(pixels_emb) + info["emb"] = rearrange(emb, "(b t) d -> b t d", b=b) - pixels = info['pixels'].float() - b = pixels.size(0) - pixels = rearrange(pixels, "b t ... -> (b t) ...") # flatten for encoding - output = self.encoder(pixels, interpolate_pos_encoding=True) - pixels_emb = output.last_hidden_state[:, 0] # cls token - emb = self.projector(pixels_emb) - info["emb"] = rearrange(emb, "(b t) d -> b t d", b=b) + if "action" in info: + info["act_emb"] = self.action_encoder(info["action"]) - if "action" in info: - info["act_emb"] = self.action_encoder(info["action"]) - - return info + return info def predict(self, emb, act_emb): """Predict next state embedding emb: (B, T, D) act_emb: (B, T, A_emb) """ - preds = self.predictor(emb, act_emb) - preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d")) - preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0)) - return preds + with torch.profiler.record_function("lewm.predict"): + preds = self.predictor(emb, act_emb) + preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d")) + preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0)) + return preds #################### ## Inference only ## @@ -65,89 +66,90 @@ class JEPA(nn.Module): - S is the number of action plan samples - T is the time horizon """ + with torch.profiler.record_function("lewm.rollout"): + assert "pixels" in info, "pixels not in info_dict" + H = info["pixels"].size(2) + B, S, T = action_sequence.shape[:3] + act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2) + info["action"] = act_0 + n_steps = T - H - assert "pixels" in info, "pixels not in info_dict" - H = info["pixels"].size(2) - B, S, T = action_sequence.shape[:3] - act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2) - info["action"] = act_0 - n_steps = T - H + # copy and encode initial info dict + _init = {k: v[:, 0] for k, v in info.items() if torch.is_tensor(v)} + _init = self.encode(_init) + emb = info["emb"] = _init["emb"].unsqueeze(1).expand(B, S, -1, -1) + _init = {k: detach_clone(v) for k, v in _init.items()} - # copy and encode initial info dict - _init = {k: v[:, 0] for k, v in info.items() if torch.is_tensor(v)} - _init = self.encode(_init) - emb = info["emb"] = _init["emb"].unsqueeze(1).expand(B, S, -1, -1) - _init = {k: detach_clone(v) for k, v in _init.items()} + # flatten batch and sample dimensions for rollout + emb = rearrange(emb, "b s ... -> (b s) ...").clone() + act = rearrange(act_0, "b s ... -> (b s) ...") + act_future = rearrange(act_future, "b s ... -> (b s) ...") - # flatten batch and sample dimensions for rollout - emb = rearrange(emb, "b s ... -> (b s) ...").clone() - act = rearrange(act_0, "b s ... -> (b s) ...") - act_future = rearrange(act_future, "b s ... -> (b s) ...") + # rollout predictor autoregressively for n_steps + HS = history_size + for t in range(n_steps): + act_emb = self.action_encoder(act) + emb_trunc = emb[:, -HS:] # (BS, HS, D) + act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb) + pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D) + emb = torch.cat([emb, pred_emb], dim=1) # (BS, T+1, D) - # rollout predictor autoregressively for n_steps - HS = history_size - for t in range(n_steps): - act_emb = self.action_encoder(act) + next_act = act_future[:, t : t + 1, :] # (BS, 1, action_dim) + act = torch.cat([act, next_act], dim=1) # (BS, T+1, action_dim) + + # predict the last state + act_emb = self.action_encoder(act) # (BS, T, A_emb) emb_trunc = emb[:, -HS:] # (BS, HS, D) act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb) pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D) - emb = torch.cat([emb, pred_emb], dim=1) # (BS, T+1, D) + emb = torch.cat([emb, pred_emb], dim=1) - next_act = act_future[:, t : t + 1, :] # (BS, 1, action_dim) - act = torch.cat([act, next_act], dim=1) # (BS, T+1, action_dim) + # unflatten batch and sample dimensions + pred_rollout = rearrange(emb, "(b s) ... -> b s ...", b=B, s=S) + info["predicted_emb"] = pred_rollout - # predict the last state - act_emb = self.action_encoder(act) # (BS, T, A_emb) - emb_trunc = emb[:, -HS:] # (BS, HS, D) - act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb) - pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D) - emb = torch.cat([emb, pred_emb], dim=1) - - # unflatten batch and sample dimensions - pred_rollout = rearrange(emb, "(b s) ... -> b s ...", b=B, s=S) - info["predicted_emb"] = pred_rollout - - return info + return info def criterion(self, info_dict: dict): """Compute the cost between predicted embeddings and goal embeddings.""" - pred_emb = info_dict["predicted_emb"] # (B,S, T-1, dim) - goal_emb = info_dict["goal_emb"] # (B, S, T, dim) + with torch.profiler.record_function("lewm.criterion"): + pred_emb = info_dict["predicted_emb"] # (B,S, T-1, dim) + goal_emb = info_dict["goal_emb"] # (B, S, T, dim) - goal_emb = goal_emb[..., -1:, :].expand_as(pred_emb) + goal_emb = goal_emb[..., -1:, :].expand_as(pred_emb) - # return last-step cost per action candidate - cost = F.mse_loss( - pred_emb[..., -1:, :], - goal_emb[..., -1:, :].detach(), - reduction="none", - ).sum(dim=tuple(range(2, pred_emb.ndim))) # (B, S) + # return last-step cost per action candidate + cost = F.mse_loss( + pred_emb[..., -1:, :], + goal_emb[..., -1:, :].detach(), + reduction="none", + ).sum(dim=tuple(range(2, pred_emb.ndim))) # (B, S) - return cost + return cost def get_cost(self, info_dict: dict, action_candidates: torch.Tensor): """ Compute the cost of action candidates given an info dict with goal and initial state.""" + with torch.profiler.record_function("lewm.get_cost"): + assert "goal" in info_dict, "goal not in info_dict" - assert "goal" in info_dict, "goal not in info_dict" + device = next(self.parameters()).device + for k in list(info_dict.keys()): + if torch.is_tensor(info_dict[k]): + info_dict[k] = info_dict[k].to(device) - device = next(self.parameters()).device - for k in list(info_dict.keys()): - if torch.is_tensor(info_dict[k]): - info_dict[k] = info_dict[k].to(device) + goal = {k: v[:, 0] for k, v in info_dict.items() if torch.is_tensor(v)} + goal["pixels"] = goal["goal"] - goal = {k: v[:, 0] for k, v in info_dict.items() if torch.is_tensor(v)} - goal["pixels"] = goal["goal"] + for k in info_dict: + if k.startswith("goal_"): + goal[k[len("goal_") :]] = goal.pop(k) - for k in info_dict: - if k.startswith("goal_"): - goal[k[len("goal_") :]] = goal.pop(k) + goal.pop("action") + goal = self.encode(goal) - goal.pop("action") - goal = self.encode(goal) + info_dict["goal_emb"] = goal["emb"] + info_dict = self.rollout(info_dict, action_candidates) - info_dict["goal_emb"] = goal["emb"] - info_dict = self.rollout(info_dict, action_candidates) - - cost = self.criterion(info_dict) - - return cost + cost = self.criterion(info_dict) + + return cost