"""JEPA Implementation""" import torch import torch.nn.functional as F from torch import nn class JEPA(nn.Module): def __init__( self, encoder, predictor, action_encoder, projector=None, pred_proj=None, ): super().__init__() self.encoder = encoder self.predictor = predictor self.action_encoder = action_encoder self.projector = projector or nn.Identity() self.pred_proj = pred_proj or nn.Identity() self._cached_device_tensors = {} self._cached_init_signature = None self._cached_init_emb = None self._cached_goal_signature = None self._cached_goal_emb = None def _ensure_runtime_caches(self): if not hasattr(self, "_cached_device_tensors"): self._cached_device_tensors = {} if not hasattr(self, "_cached_init_signature"): self._cached_init_signature = None if not hasattr(self, "_cached_init_emb"): self._cached_init_emb = None if not hasattr(self, "_cached_goal_signature"): self._cached_goal_signature = None if not hasattr(self, "_cached_goal_emb"): self._cached_goal_emb = None @staticmethod def _tensor_signature(tensor: torch.Tensor): try: version = tensor._version except RuntimeError: version = None return ( str(tensor.device), tensor.dtype, tuple(tensor.shape), tensor.data_ptr(), version, ) def _get_cached_device_tensor(self, key: str, tensor: torch.Tensor, device: torch.device): self._ensure_runtime_caches() signature = (self._tensor_signature(tensor), str(device)) cached = self._cached_device_tensors.get(key) if cached is None or cached[0] != signature: self._cached_device_tensors[key] = ( signature, tensor.to(device, non_blocking=True), ) return self._cached_device_tensors[key][1] def _ensure_info_device(self, info_dict: dict, device: torch.device): for key, value in list(info_dict.items()): if key.startswith("_lewm_"): continue if torch.is_tensor(value) and value.device != device: info_dict[key] = self._get_cached_device_tensor(key, value, device) return info_dict def _get_cached_init_emb(self, info_dict: dict): self._ensure_runtime_caches() pixels = info_dict["pixels"] signature = self._tensor_signature(pixels) if self._cached_init_signature != signature: init_info = {"pixels": pixels[:, 0]} self._cached_init_emb = self.encode(init_info)["emb"].detach() self._cached_init_signature = signature return self._cached_init_emb def _get_cached_goal_emb(self, info_dict: dict): self._ensure_runtime_caches() goal = info_dict["goal"] signature = self._tensor_signature(goal) if self._cached_goal_signature != signature: goal_info = {"pixels": goal[:, 0]} self._cached_goal_emb = self.encode(goal_info)["emb"][:, -1:, :].detach() self._cached_goal_signature = signature return self._cached_goal_emb def encode(self, info): """Encode observations and actions into embeddings. info: dict with pixels and action keys """ with torch.profiler.record_function("lewm.encode"): pixels = info['pixels'].float() b, t = pixels.shape[:2] pixels = pixels.reshape(b * t, *pixels.shape[2:]) # flatten for encoding output = self.encoder(pixels, interpolate_pos_encoding=True) pixels_emb = output.last_hidden_state[:, 0] # cls token emb = self.projector(pixels_emb) info["emb"] = emb.reshape(b, t, -1) if "action" in info: info["act_emb"] = self.action_encoder(info["action"]) return info def predict(self, emb, act_emb): """Predict next state embedding emb: (B, T, D) act_emb: (B, T, A_emb) """ with torch.profiler.record_function("lewm.predict"): preds = self.predictor(emb, act_emb) preds = self.pred_proj(preds) return preds #################### ## Inference only ## #################### def rollout(self, info, action_sequence, history_size: int = 3): """Rollout the model given an initial info dict and action sequence. pixels: (B, S, T, C, H, W) action_sequence: (B, S, T, action_dim) - S is the number of action plan samples - T is the time horizon """ with torch.profiler.record_function("lewm.rollout"): assert "pixels" in info, "pixels not in info_dict" H = info["pixels"].size(2) B, S, T = action_sequence.shape[:3] act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2) # Cache the encoded initial state across solver iterations. init_emb = self._get_cached_init_emb(info) HS = history_size emb_hist = init_emb.unsqueeze(1).expand(B, S, -1, -1) emb_hist = emb_hist[..., -HS:, :].reshape(B * S, min(HS, init_emb.size(1)), -1) act_hist = act_0[..., -HS:, :].reshape(B * S, min(HS, act_0.size(2)), -1) act_emb_hist = self.action_encoder(act_hist) act_future = act_future.reshape(B * S, act_future.size(2), -1) for t in range(act_future.size(1)): pred_emb = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:] if HS > 1: emb_hist = torch.cat([emb_hist[:, -HS + 1 :], pred_emb], dim=1) else: emb_hist = pred_emb next_act = act_future[:, t : t + 1, :] next_act_emb = self.action_encoder(next_act) if HS > 1: act_emb_hist = torch.cat([act_emb_hist[:, -HS + 1 :], next_act_emb], dim=1) else: act_emb_hist = next_act_emb pred_rollout = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:] info["predicted_emb"] = pred_rollout.reshape(B, S, *pred_rollout.shape[1:]) return info def criterion(self, info_dict: dict): """Compute the cost between predicted embeddings and goal embeddings.""" with torch.profiler.record_function("lewm.criterion"): pred_emb = info_dict["predicted_emb"] # (B,S, T-1, dim) goal_emb = info_dict["goal_emb"] # (B, S, T, dim) if goal_emb.ndim == pred_emb.ndim - 1: goal_emb = goal_emb.unsqueeze(1) # return last-step cost per action candidate cost = F.mse_loss( pred_emb[..., -1:, :], goal_emb[..., -1:, :].detach(), reduction="none", ).sum(dim=tuple(range(2, pred_emb.ndim))) # (B, S) return cost def get_cost(self, info_dict: dict, action_candidates: torch.Tensor): """ Compute the cost of action candidates given an info dict with goal and initial state.""" with torch.profiler.record_function("lewm.get_cost"): assert "goal" in info_dict, "goal not in info_dict" self._ensure_runtime_caches() device = next(self.parameters()).device info_dict = self._ensure_info_device(info_dict, device) if action_candidates.device != device: action_candidates = action_candidates.to(device, non_blocking=True) info_dict["goal_emb"] = self._get_cached_goal_emb(info_dict) info_dict = self.rollout(info_dict, action_candidates) cost = self.criterion(info_dict) return cost