diff --git a/eval.py b/eval.py index 9cbc068..4e63a1f 100644 --- a/eval.py +++ b/eval.py @@ -129,14 +129,13 @@ def dump_profiler_results(profiler, profile_dir, profile_cfg): summary_path = profile_dir / "key_averages.txt" summary_path.write_text(table) - if profile_cfg["export_chrome_trace"]: - profiler.export_chrome_trace(str(profile_dir / "trace.json")) - if profile_cfg["export_tensorboard"]: trace_handler = torch.profiler.tensorboard_trace_handler( str(profile_dir), worker_name=profile_cfg["worker_name"] ) trace_handler(profiler) + elif profile_cfg["export_chrome_trace"]: + profiler.export_chrome_trace(str(profile_dir / "trace.json")) return summary_path @@ -198,12 +197,11 @@ def run(cfg: DictConfig): inference_ctx = nullcontext() inference_precision = "fp32" - results_path = ( - Path(swm.data.utils.get_cache_dir(), cfg.policy).parent - if cfg.policy != "random" - else Path(__file__).parent - ) - profiler_ctx, profile_dir, profile_cfg = make_profiler(cfg, results_path) + # Hydra switches the working directory to the per-run outputs folder. + # Keep all generated artifacts with that run instead of scattering them + # next to the cache or source tree. + output_dir = Path.cwd().resolve() + profiler_ctx, profile_dir, profile_cfg = make_profiler(cfg, output_dir) # sample the episodes and the starting indices episode_len = get_episodes_length(dataset, ep_indices) @@ -251,7 +249,7 @@ def run(cfg: DictConfig): eval_budget=cfg.eval.eval_budget, episodes_idx=eval_episodes.tolist(), callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True), - video_path=results_path, + video_path=output_dir, ) if torch.cuda.is_available(): torch.cuda.synchronize() @@ -260,7 +258,7 @@ def run(cfg: DictConfig): print(metrics) - results_path = results_path / cfg.output.filename + results_path = output_dir / cfg.output.filename results_path.parent.mkdir(parents=True, exist_ok=True) with results_path.open("a") as f: diff --git a/jepa.py b/jepa.py index a47131a..971c995 100644 --- a/jepa.py +++ b/jepa.py @@ -5,9 +5,6 @@ import torch.nn.functional as F from einops import rearrange from torch import nn -def detach_clone(v): - return v.detach().clone() if torch.is_tensor(v) else v - class JEPA(nn.Module): def __init__( @@ -25,6 +22,76 @@ class JEPA(nn.Module): self.action_encoder = action_encoder self.projector = projector or nn.Identity() self.pred_proj = pred_proj or nn.Identity() + self._cached_device_tensors = {} + self._cached_init_signature = None + self._cached_init_emb = None + self._cached_goal_signature = None + self._cached_goal_emb = None + + def _ensure_runtime_caches(self): + if not hasattr(self, "_cached_device_tensors"): + self._cached_device_tensors = {} + if not hasattr(self, "_cached_init_signature"): + self._cached_init_signature = None + if not hasattr(self, "_cached_init_emb"): + self._cached_init_emb = None + if not hasattr(self, "_cached_goal_signature"): + self._cached_goal_signature = None + if not hasattr(self, "_cached_goal_emb"): + self._cached_goal_emb = None + + @staticmethod + def _tensor_signature(tensor: torch.Tensor): + try: + version = tensor._version + except RuntimeError: + version = None + return ( + str(tensor.device), + tensor.dtype, + tuple(tensor.shape), + tensor.data_ptr(), + version, + ) + + def _get_cached_device_tensor(self, key: str, tensor: torch.Tensor, device: torch.device): + self._ensure_runtime_caches() + signature = (self._tensor_signature(tensor), str(device)) + cached = self._cached_device_tensors.get(key) + if cached is None or cached[0] != signature: + self._cached_device_tensors[key] = ( + signature, + tensor.to(device, non_blocking=True), + ) + return self._cached_device_tensors[key][1] + + def _ensure_info_device(self, info_dict: dict, device: torch.device): + for key, value in list(info_dict.items()): + if key.startswith("_lewm_"): + continue + if torch.is_tensor(value) and value.device != device: + info_dict[key] = self._get_cached_device_tensor(key, value, device) + return info_dict + + def _get_cached_init_emb(self, info_dict: dict): + self._ensure_runtime_caches() + pixels = info_dict["pixels"] + signature = self._tensor_signature(pixels) + if self._cached_init_signature != signature: + init_info = {"pixels": pixels[:, 0]} + self._cached_init_emb = self.encode(init_info)["emb"].detach() + self._cached_init_signature = signature + return self._cached_init_emb + + def _get_cached_goal_emb(self, info_dict: dict): + self._ensure_runtime_caches() + goal = info_dict["goal"] + signature = self._tensor_signature(goal) + if self._cached_goal_signature != signature: + goal_info = {"pixels": goal[:, 0]} + self._cached_goal_emb = self.encode(goal_info)["emb"][:, -1:, :].detach() + self._cached_goal_signature = signature + return self._cached_goal_emb def encode(self, info): """Encode observations and actions into embeddings. @@ -71,42 +138,33 @@ class JEPA(nn.Module): H = info["pixels"].size(2) B, S, T = action_sequence.shape[:3] act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2) - info["action"] = act_0 - n_steps = T - H - # copy and encode initial info dict - _init = {k: v[:, 0] for k, v in info.items() if torch.is_tensor(v)} - _init = self.encode(_init) - emb = info["emb"] = _init["emb"].unsqueeze(1).expand(B, S, -1, -1) - _init = {k: detach_clone(v) for k, v in _init.items()} + # Cache the encoded initial state across solver iterations. + init_emb = self._get_cached_init_emb(info) + HS = history_size + emb_hist = init_emb.unsqueeze(1).expand(B, S, -1, -1) + emb_hist = rearrange(emb_hist[..., -HS:, :], "b s ... -> (b s) ...") - # flatten batch and sample dimensions for rollout - emb = rearrange(emb, "b s ... -> (b s) ...").clone() - act = rearrange(act_0, "b s ... -> (b s) ...") + act_hist = rearrange(act_0[..., -HS:, :], "b s ... -> (b s) ...") + act_emb_hist = self.action_encoder(act_hist) act_future = rearrange(act_future, "b s ... -> (b s) ...") - # rollout predictor autoregressively for n_steps - HS = history_size - for t in range(n_steps): - act_emb = self.action_encoder(act) - emb_trunc = emb[:, -HS:] # (BS, HS, D) - act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb) - pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D) - emb = torch.cat([emb, pred_emb], dim=1) # (BS, T+1, D) + for t in range(act_future.size(1)): + pred_emb = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:] + if HS > 1: + emb_hist = torch.cat([emb_hist[:, -HS + 1 :], pred_emb], dim=1) + else: + emb_hist = pred_emb - next_act = act_future[:, t : t + 1, :] # (BS, 1, action_dim) - act = torch.cat([act, next_act], dim=1) # (BS, T+1, action_dim) + next_act = act_future[:, t : t + 1, :] + next_act_emb = self.action_encoder(next_act) + if HS > 1: + act_emb_hist = torch.cat([act_emb_hist[:, -HS + 1 :], next_act_emb], dim=1) + else: + act_emb_hist = next_act_emb - # predict the last state - act_emb = self.action_encoder(act) # (BS, T, A_emb) - emb_trunc = emb[:, -HS:] # (BS, HS, D) - act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb) - pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D) - emb = torch.cat([emb, pred_emb], dim=1) - - # unflatten batch and sample dimensions - pred_rollout = rearrange(emb, "(b s) ... -> b s ...", b=B, s=S) - info["predicted_emb"] = pred_rollout + pred_rollout = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:] + info["predicted_emb"] = rearrange(pred_rollout, "(b s) ... -> b s ...", b=B, s=S) return info @@ -115,8 +173,8 @@ class JEPA(nn.Module): with torch.profiler.record_function("lewm.criterion"): pred_emb = info_dict["predicted_emb"] # (B,S, T-1, dim) goal_emb = info_dict["goal_emb"] # (B, S, T, dim) - - goal_emb = goal_emb[..., -1:, :].expand_as(pred_emb) + if goal_emb.ndim == pred_emb.ndim - 1: + goal_emb = goal_emb.unsqueeze(1) # return last-step cost per action candidate cost = F.mse_loss( @@ -132,22 +190,13 @@ class JEPA(nn.Module): with torch.profiler.record_function("lewm.get_cost"): assert "goal" in info_dict, "goal not in info_dict" + self._ensure_runtime_caches() device = next(self.parameters()).device - for k in list(info_dict.keys()): - if torch.is_tensor(info_dict[k]): - info_dict[k] = info_dict[k].to(device) + info_dict = self._ensure_info_device(info_dict, device) + if action_candidates.device != device: + action_candidates = action_candidates.to(device, non_blocking=True) - goal = {k: v[:, 0] for k, v in info_dict.items() if torch.is_tensor(v)} - goal["pixels"] = goal["goal"] - - for k in info_dict: - if k.startswith("goal_"): - goal[k[len("goal_") :]] = goal.pop(k) - - goal.pop("action") - goal = self.encode(goal) - - info_dict["goal_emb"] = goal["emb"] + info_dict["goal_emb"] = self._get_cached_goal_emb(info_dict) info_dict = self.rollout(info_dict, action_candidates) cost = self.criterion(info_dict) diff --git a/tworoom_results.txt b/tworoom_results.txt new file mode 100644 index 0000000..1770acb --- /dev/null +++ b/tworoom_results.txt @@ -0,0 +1,57 @@ + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.CEMSolver + model: ??? + batch_size: 1 + num_samples: 300 + var_scale: 1.0 + n_steps: 30 + topk: 30 + device: cuda + seed: ${seed} +world: + env_name: swm/TwoRoom-v1 + num_envs: ${eval.num_eval} + max_episode_steps: 100 + history_size: 1 + frame_skip: 1 +seed: 42 +policy: two-room/tworoom/lejepa +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: tworoom + callables: + - method: _set_state + args: + state: + value: proprio + - method: _set_goal_state + args: + goal_state: + value: goal_proprio +output: + filename: tworoom_results.txt + +==== RESULTS ==== +metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False, + True, True, True, True, True, True, True, True, True, + True, True, True, False, True, True, True, True, True, + True, True, True, True, False, True, True, True, True, + True, True, False, True, True, True, True, True, True, + True, True, True, True, True]), 'seeds': None} +evaluation_time: 133.1857841014862 seconds +inference_precision: fp32