"""JEPA Implementation"""

import torch
import torch.nn.functional as F
from torch import nn


class JEPA(nn.Module):
    """JEPA world model: observation encoder, action encoder and latent predictor.

    Besides ``encode``/``predict``, the class exposes an inference-only
    planning interface (``rollout``, ``criterion``, ``get_cost``). It caches
    device copies of input tensors and the encoded initial/goal observations
    across repeated calls with the same inputs.

    Cache validity is decided by ``_tensor_signature``: device, dtype, shape,
    stride, storage offset, data pointer and version counter. Every cache
    entry also keeps a strong reference to the tensor it was built from.
    While that tensor is alive, its memory cannot be recycled for a different
    allocation, so a matching data pointer really means "same storage" and
    not "new tensor that happens to reuse a freed address". The caches do not
    track updates to the module's parameters.
    """

    # Runtime cache attributes (neither parameters nor buffers), created empty.
    # ``_ensure_runtime_caches`` recreates any that are missing, presumably for
    # instances that did not go through this ``__init__`` (e.g. unpickled from
    # an older version of the class).
    _RUNTIME_DICT_CACHES = ("_cached_device_tensors", "_cached_device_sources")
    _RUNTIME_SCALAR_CACHES = (
        "_cached_init_signature",
        "_cached_init_emb",
        "_cached_init_source",
        "_cached_goal_signature",
        "_cached_goal_emb",
        "_cached_goal_source",
    )

    def __init__(
        self,
        encoder,
        predictor,
        action_encoder,
        projector=None,
        pred_proj=None,
    ):
        """Build the model.

        Args:
            encoder: called as ``encoder(pixels, interpolate_pos_encoding=True)``.
                Its output must expose ``last_hidden_state``; token 0 of that
                tensor is used as the observation embedding.
            predictor: called as ``predictor(emb, act_emb)``.
            action_encoder: maps raw actions to action embeddings.
            projector: optional module applied to encoder embeddings
                (``nn.Identity`` when None).
            pred_proj: optional module applied to predictor outputs
                (``nn.Identity`` when None).
        """
        super().__init__()
        self.encoder = encoder
        self.predictor = predictor
        self.action_encoder = action_encoder
        # Explicit None checks: nn.Module truthiness depends on __len__
        # (an empty nn.Sequential is falsy), which is not what we mean here.
        self.projector = projector if projector is not None else nn.Identity()
        self.pred_proj = pred_proj if pred_proj is not None else nn.Identity()
        self._ensure_runtime_caches()

    def _ensure_runtime_caches(self):
        """Create any missing runtime cache attribute with its empty value."""
        for name in self._RUNTIME_DICT_CACHES:
            if not hasattr(self, name):
                setattr(self, name, {})
        for name in self._RUNTIME_SCALAR_CACHES:
            if not hasattr(self, name):
                setattr(self, name, None)

    @staticmethod
    def _tensor_signature(tensor: torch.Tensor):
        """Return a hashable fingerprint of ``tensor``'s identity and layout.

        The version counter changes on in-place modification. Tensors that do
        not expose it raise ``RuntimeError`` on ``_version`` (e.g. inference
        tensors); for those the version is recorded as None.
        """
        try:
            version = tensor._version
        except RuntimeError:
            version = None
        return (
            str(tensor.device),
            tensor.dtype,
            tuple(tensor.shape),
            tuple(tensor.stride()),
            tensor.storage_offset(),
            tensor.data_ptr(),
            version,
        )

    def _get_cached_device_tensor(
        self,
        key: str,
        tensor: torch.Tensor,
        device: torch.device,
        *,
        ensure_contiguous: bool = False,
    ):
        """Return ``tensor`` on ``device``, reusing a cached copy when possible.

        A tensor already on ``device`` that satisfies the contiguity
        requirement is returned as-is. Otherwise the copy made for ``key`` is
        reused if the source signature is unchanged, and rebuilt if not.

        Args:
            key: cache slot name (one slot per key).
            tensor: source tensor.
            device: target device.
            ensure_contiguous: also make the returned tensor contiguous.
        """
        self._ensure_runtime_caches()
        if tensor.device == device and (not ensure_contiguous or tensor.is_contiguous()):
            return tensor

        signature = (self._tensor_signature(tensor), str(device), ensure_contiguous)
        cached = self._cached_device_tensors.get(key)
        source = self._cached_device_sources.get(key)
        # A missing source (e.g. a cache populated by an older version) is a miss.
        if cached is None or source is None or cached[0] != signature:
            prepared = tensor.to(device, non_blocking=True)
            if ensure_contiguous and not prepared.is_contiguous():
                prepared = prepared.contiguous()
            self._cached_device_tensors[key] = (signature, prepared)
            # Pin the source so its address cannot be reused by another tensor
            # while this entry is live (that would fake a signature match).
            self._cached_device_sources[key] = tensor
        return self._cached_device_tensors[key][1]

    def _ensure_info_device(self, info_dict: dict, device: torch.device):
        """Move every tensor value of ``info_dict`` to ``device`` (in place).

        Values become contiguous device tensors via the per-key device cache.
        Keys starting with ``_lewm_`` are skipped; this class uses that prefix
        for its own cache slot (``_lewm_action_candidates``). Returns the same,
        mutated dict.
        """
        for key, value in list(info_dict.items()):
            if key.startswith("_lewm_"):
                continue
            if torch.is_tensor(value):
                info_dict[key] = self._get_cached_device_tensor(
                    key,
                    value,
                    device,
                    ensure_contiguous=True,
                )
        return info_dict

    def _get_cached_init_emb(self, info_dict: dict):
        """Return the (detached) embedding of ``info_dict["pixels"][:, 0]``.

        Recomputed only when the pixels tensor's signature changes.
        """
        self._ensure_runtime_caches()
        pixels = info_dict["pixels"]
        signature = self._tensor_signature(pixels)
        if self._cached_init_source is None or self._cached_init_signature != signature:
            init_info = {"pixels": pixels[:, 0]}
            self._cached_init_emb = self.encode(init_info)["emb"].detach()
            self._cached_init_signature = signature
            # Keep the source alive so a freed-and-reused address cannot alias it.
            self._cached_init_source = pixels
        return self._cached_init_emb

    def _get_cached_goal_emb(self, info_dict: dict):
        """Return the (detached) last-step embedding of ``info_dict["goal"][:, 0]``.

        The result keeps only the final time step, shape (B, 1, D). It is
        recomputed only when the goal tensor's signature changes.
        """
        self._ensure_runtime_caches()
        goal = info_dict["goal"]
        signature = self._tensor_signature(goal)
        if self._cached_goal_source is None or self._cached_goal_signature != signature:
            goal_info = {"pixels": goal[:, 0]}
            self._cached_goal_emb = self.encode(goal_info)["emb"][:, -1:, :].detach()
            self._cached_goal_signature = signature
            self._cached_goal_source = goal
        return self._cached_goal_emb

    def encode(self, info):
        """Encode observations and actions into embeddings.

        info: dict with a ``pixels`` tensor whose first two dims are (b, t),
        and optionally an ``action`` tensor. Adds ``emb`` (b, t, D) and, when
        actions are present, ``act_emb``. Returns the same dict.
        """
        with torch.profiler.record_function("lewm.encode"):
            pixels = info["pixels"].float()
            b, t = pixels.shape[:2]
            pixels = pixels.reshape(b * t, *pixels.shape[2:])  # flatten for encoding
            output = self.encoder(pixels, interpolate_pos_encoding=True)
            pixels_emb = output.last_hidden_state[:, 0]  # cls token
            emb = self.projector(pixels_emb)
            info["emb"] = emb.reshape(b, t, -1)

            if "action" in info:
                info["act_emb"] = self.action_encoder(info["action"])

            return info

    def predict(self, emb, act_emb):
        """Predict next state embedding.

        emb: (B, T, D)
        act_emb: (B, T, A_emb)
        """
        with torch.profiler.record_function("lewm.predict"):
            preds = self.predictor(emb, act_emb)
            preds = self.pred_proj(preds)
        return preds

    ####################
    ## Inference only ##
    ####################

    def rollout(self, info, action_sequence, history_size: int = 3):
        """Rollout the model given an initial info dict and action sequence.

        pixels: (B, S, T, C, H, W)
        action_sequence: (B, S, T, action_dim)
        - S is the number of action plan samples
        - T is the time horizon

        The initial state is encoded from the first sample only
        (``pixels[:, 0]``) and shared by all S samples. The first
        ``pixels.size(2)`` actions pair with the observed history; the rest are
        applied autoregressively. At most ``history_size`` steps of context are
        fed to the predictor. Stores the final predicted embedding as
        ``info["predicted_emb"]`` with shape (B, S, 1, D) and returns ``info``.

        Raises:
            ValueError: if ``history_size < 1`` or the action horizon is
                shorter than the observed history.
        """
        with torch.profiler.record_function("lewm.rollout"):
            assert "pixels" in info, "pixels not in info_dict"
            if history_size < 1:
                raise ValueError("history_size must be >= 1")

            H = info["pixels"].size(2)
            B, S, T = action_sequence.shape[:3]
            if T < H:
                raise ValueError(
                    f"action_sequence horizon ({T}) must be >= history length ({H})"
                )

            # Cache the encoded initial state across solver iterations.
            init_emb = self._get_cached_init_emb(info)
            HS = history_size
            hist_len = min(HS, init_emb.size(1), H)
            if hist_len < 1:
                raise ValueError("rollout requires at least one history step")

            # (B, hist_len, D) -> (B * S, hist_len, D): same history for every sample.
            init_hist = init_emb[:, -hist_len:]
            init_hist = init_hist.unsqueeze(1).expand(-1, S, -1, -1)
            init_hist = init_hist.reshape(B * S, hist_len, init_hist.size(-1)).contiguous()

            flat_actions = action_sequence.contiguous().view(B * S, T, -1)
            action_emb = self.action_encoder(flat_actions)
            act_hist = action_emb[:, H - hist_len : H]  # actions aligned with the history
            act_future = action_emb[:, H:]  # actions applied autoregressively

            if HS == 1:
                pred_rollout = self._rollout_single_step(init_hist, act_hist, act_future)
            else:
                pred_rollout = self._rollout_windowed(init_hist, act_hist, act_future, HS)

            info["predicted_emb"] = pred_rollout.reshape(B, S, *pred_rollout.shape[1:])
            return info

    def _rollout_single_step(self, init_hist, act_hist, act_future):
        """Autoregressive rollout that conditions only on the latest step."""
        emb_hist = init_hist[:, -1:]
        act_emb_hist = act_hist[:, -1:]
        for t in range(act_future.size(1)):
            emb_hist = self.predict(emb_hist, act_emb_hist)[:, -1:]
            act_emb_hist = act_future[:, t : t + 1]
        return self.predict(emb_hist, act_emb_hist)[:, -1:]

    def _rollout_windowed(self, init_hist, act_hist, act_future, history_size: int):
        """Autoregressive rollout over a fixed-size ring buffer of past steps.

        The buffers hold ``history_size`` slots. New predictions overwrite the
        oldest slot, so there is no per-step concatenation or reallocation.
        """
        HS = history_size
        n, hist_len, emb_dim = init_hist.shape
        emb_hist = init_hist.new_empty((n, HS, emb_dim))
        act_emb_hist = act_hist.new_empty((n, HS, act_hist.size(-1)))
        emb_hist[:, :hist_len].copy_(init_hist)
        act_emb_hist[:, :hist_len].copy_(act_hist)

        # history_order[k]: slot indices oldest -> newest when slot k is the
        # next one to be overwritten (i.e. slot k currently holds the oldest step).
        slots = torch.arange(HS, device=act_hist.device)
        history_order = torch.stack([(slots + offset) % HS for offset in range(HS)])

        filled = hist_len
        next_slot = hist_len % HS
        for t in range(act_future.size(1)):
            emb_view, act_view = self._history_view(
                emb_hist, act_emb_hist, filled, next_slot, history_order
            )
            pred_emb = self.predict(emb_view, act_view)[:, -1:]
            emb_hist[:, next_slot : next_slot + 1].copy_(pred_emb)
            act_emb_hist[:, next_slot : next_slot + 1].copy_(act_future[:, t : t + 1])
            filled = min(filled + 1, HS)
            next_slot = (next_slot + 1) % HS

        emb_view, act_view = self._history_view(
            emb_hist, act_emb_hist, filled, next_slot, history_order
        )
        return self.predict(emb_view, act_view)[:, -1:]

    @staticmethod
    def _history_view(emb_hist, act_emb_hist, filled, next_slot, history_order):
        """Return the ring buffers' valid contents ordered oldest -> newest."""
        if filled < emb_hist.size(1):
            # Still filling sequentially from slot 0: prefix is already ordered.
            return emb_hist[:, :filled], act_emb_hist[:, :filled]
        if next_slot == 0:
            # Full buffer whose oldest entry sits in slot 0: no reordering needed.
            return emb_hist, act_emb_hist
        order = history_order[next_slot]
        return emb_hist.index_select(1, order), act_emb_hist.index_select(1, order)

    def criterion(self, info_dict: dict):
        """Compute the cost between predicted embeddings and goal embeddings.

        ``predicted_emb`` is (B, S, T', D) (``rollout`` produces T' == 1).
        ``goal_emb`` is either (B, T'', D), in which case it is broadcast over
        S, or (B, S, T'', D). Only the last time step of each is compared.
        Returns the summed squared error per candidate, shape (B, S).
        """
        with torch.profiler.record_function("lewm.criterion"):
            pred_emb = info_dict["predicted_emb"]  # (B, S, T', D)
            goal_emb = info_dict["goal_emb"]  # (B, T'', D) or (B, S, T'', D)

            if goal_emb.ndim == pred_emb.ndim - 1:
                goal_emb = goal_emb.unsqueeze(1)  # broadcast over samples

            # return last-step cost per action candidate
            cost = F.mse_loss(
                pred_emb[..., -1:, :],
                goal_emb[..., -1:, :].detach(),
                reduction="none",
            ).sum(dim=tuple(range(2, pred_emb.ndim)))  # (B, S)

            return cost

    def get_cost(self, info_dict: dict, action_candidates: torch.Tensor):
        """Compute the cost of action candidates given an info dict with goal and initial state.

        Moves tensors in ``info_dict`` and ``action_candidates`` to the
        model's device (cached). It then encodes the goal (cached), rolls out
        the candidates and returns the (B, S) cost from ``criterion``.
        ``info_dict`` is mutated in place: device-moved tensors, ``goal_emb``
        and ``predicted_emb`` are written to it.
        """
        with torch.profiler.record_function("lewm.get_cost"):
            assert "goal" in info_dict, "goal not in info_dict"
            self._ensure_runtime_caches()

            device = next(self.parameters()).device
            info_dict = self._ensure_info_device(info_dict, device)
            action_candidates = self._get_cached_device_tensor(
                "_lewm_action_candidates",
                action_candidates,
                device,
                ensure_contiguous=True,
            )

            info_dict["goal_emb"] = self._get_cached_goal_emb(info_dict)
            info_dict = self.rollout(info_dict, action_candidates)

            cost = self.criterion(info_dict)

            return cost