Optimize JEPA eval outputs and inference hot path

This commit is contained in:
qihuanye
2026-04-08 12:41:21 +00:00
parent 8b84251eb9
commit fa1c15c896
3 changed files with 164 additions and 60 deletions

147
jepa.py
View File

@@ -5,9 +5,6 @@ import torch.nn.functional as F
from einops import rearrange
from torch import nn
def detach_clone(v):
    """Return a detached copy of ``v`` if it is a tensor, else ``v`` itself."""
    if not torch.is_tensor(v):
        # Non-tensor values are handed back untouched (same object).
        return v
    return v.detach().clone()
class JEPA(nn.Module):
def __init__(
@@ -25,6 +22,76 @@ class JEPA(nn.Module):
self.action_encoder = action_encoder
self.projector = projector or nn.Identity()
self.pred_proj = pred_proj or nn.Identity()
self._cached_device_tensors = {}
self._cached_init_signature = None
self._cached_init_emb = None
self._cached_goal_signature = None
self._cached_goal_emb = None
def _ensure_runtime_caches(self):
if not hasattr(self, "_cached_device_tensors"):
self._cached_device_tensors = {}
if not hasattr(self, "_cached_init_signature"):
self._cached_init_signature = None
if not hasattr(self, "_cached_init_emb"):
self._cached_init_emb = None
if not hasattr(self, "_cached_goal_signature"):
self._cached_goal_signature = None
if not hasattr(self, "_cached_goal_emb"):
self._cached_goal_emb = None
@staticmethod
def _tensor_signature(tensor: torch.Tensor):
try:
version = tensor._version
except RuntimeError:
version = None
return (
str(tensor.device),
tensor.dtype,
tuple(tensor.shape),
tensor.data_ptr(),
version,
)
def _get_cached_device_tensor(self, key: str, tensor: torch.Tensor, device: torch.device):
self._ensure_runtime_caches()
signature = (self._tensor_signature(tensor), str(device))
cached = self._cached_device_tensors.get(key)
if cached is None or cached[0] != signature:
self._cached_device_tensors[key] = (
signature,
tensor.to(device, non_blocking=True),
)
return self._cached_device_tensors[key][1]
def _ensure_info_device(self, info_dict: dict, device: torch.device):
for key, value in list(info_dict.items()):
if key.startswith("_lewm_"):
continue
if torch.is_tensor(value) and value.device != device:
info_dict[key] = self._get_cached_device_tensor(key, value, device)
return info_dict
def _get_cached_init_emb(self, info_dict: dict):
self._ensure_runtime_caches()
pixels = info_dict["pixels"]
signature = self._tensor_signature(pixels)
if self._cached_init_signature != signature:
init_info = {"pixels": pixels[:, 0]}
self._cached_init_emb = self.encode(init_info)["emb"].detach()
self._cached_init_signature = signature
return self._cached_init_emb
def _get_cached_goal_emb(self, info_dict: dict):
self._ensure_runtime_caches()
goal = info_dict["goal"]
signature = self._tensor_signature(goal)
if self._cached_goal_signature != signature:
goal_info = {"pixels": goal[:, 0]}
self._cached_goal_emb = self.encode(goal_info)["emb"][:, -1:, :].detach()
self._cached_goal_signature = signature
return self._cached_goal_emb
def encode(self, info):
"""Encode observations and actions into embeddings.
@@ -71,42 +138,33 @@ class JEPA(nn.Module):
H = info["pixels"].size(2)
B, S, T = action_sequence.shape[:3]
act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2)
info["action"] = act_0
n_steps = T - H
# copy and encode initial info dict
_init = {k: v[:, 0] for k, v in info.items() if torch.is_tensor(v)}
_init = self.encode(_init)
emb = info["emb"] = _init["emb"].unsqueeze(1).expand(B, S, -1, -1)
_init = {k: detach_clone(v) for k, v in _init.items()}
# Cache the encoded initial state across solver iterations.
init_emb = self._get_cached_init_emb(info)
HS = history_size
emb_hist = init_emb.unsqueeze(1).expand(B, S, -1, -1)
emb_hist = rearrange(emb_hist[..., -HS:, :], "b s ... -> (b s) ...")
# flatten batch and sample dimensions for rollout
emb = rearrange(emb, "b s ... -> (b s) ...").clone()
act = rearrange(act_0, "b s ... -> (b s) ...")
act_hist = rearrange(act_0[..., -HS:, :], "b s ... -> (b s) ...")
act_emb_hist = self.action_encoder(act_hist)
act_future = rearrange(act_future, "b s ... -> (b s) ...")
# rollout predictor autoregressively for n_steps
HS = history_size
for t in range(n_steps):
act_emb = self.action_encoder(act)
emb_trunc = emb[:, -HS:] # (BS, HS, D)
act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb)
pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D)
emb = torch.cat([emb, pred_emb], dim=1) # (BS, T+1, D)
for t in range(act_future.size(1)):
pred_emb = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:]
if HS > 1:
emb_hist = torch.cat([emb_hist[:, -HS + 1 :], pred_emb], dim=1)
else:
emb_hist = pred_emb
next_act = act_future[:, t : t + 1, :] # (BS, 1, action_dim)
act = torch.cat([act, next_act], dim=1) # (BS, T+1, action_dim)
next_act = act_future[:, t : t + 1, :]
next_act_emb = self.action_encoder(next_act)
if HS > 1:
act_emb_hist = torch.cat([act_emb_hist[:, -HS + 1 :], next_act_emb], dim=1)
else:
act_emb_hist = next_act_emb
# predict the last state
act_emb = self.action_encoder(act) # (BS, T, A_emb)
emb_trunc = emb[:, -HS:] # (BS, HS, D)
act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb)
pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D)
emb = torch.cat([emb, pred_emb], dim=1)
# unflatten batch and sample dimensions
pred_rollout = rearrange(emb, "(b s) ... -> b s ...", b=B, s=S)
info["predicted_emb"] = pred_rollout
pred_rollout = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:]
info["predicted_emb"] = rearrange(pred_rollout, "(b s) ... -> b s ...", b=B, s=S)
return info
@@ -115,8 +173,8 @@ class JEPA(nn.Module):
with torch.profiler.record_function("lewm.criterion"):
pred_emb = info_dict["predicted_emb"] # (B,S, T-1, dim)
goal_emb = info_dict["goal_emb"] # (B, S, T, dim)
goal_emb = goal_emb[..., -1:, :].expand_as(pred_emb)
if goal_emb.ndim == pred_emb.ndim - 1:
goal_emb = goal_emb.unsqueeze(1)
# return last-step cost per action candidate
cost = F.mse_loss(
@@ -132,22 +190,13 @@ class JEPA(nn.Module):
with torch.profiler.record_function("lewm.get_cost"):
assert "goal" in info_dict, "goal not in info_dict"
self._ensure_runtime_caches()
device = next(self.parameters()).device
for k in list(info_dict.keys()):
if torch.is_tensor(info_dict[k]):
info_dict[k] = info_dict[k].to(device)
info_dict = self._ensure_info_device(info_dict, device)
if action_candidates.device != device:
action_candidates = action_candidates.to(device, non_blocking=True)
goal = {k: v[:, 0] for k, v in info_dict.items() if torch.is_tensor(v)}
goal["pixels"] = goal["goal"]
for k in info_dict:
if k.startswith("goal_"):
goal[k[len("goal_") :]] = goal.pop(k)
goal.pop("action")
goal = self.encode(goal)
info_dict["goal_emb"] = goal["emb"]
info_dict["goal_emb"] = self._get_cached_goal_emb(info_dict)
info_dict = self.rollout(info_dict, action_candidates)
cost = self.criterion(info_dict)