From 006102d00c61583dd8156dddc7b679afa26a2c5f Mon Sep 17 00:00:00 2001 From: qihuanye Date: Thu, 9 Apr 2026 10:14:58 +0000 Subject: [PATCH] =?UTF-8?q?=E5=87=8F=E5=B0=91=E5=BE=AA=E7=8E=AF=E9=87=8C?= =?UTF-8?q?=E7=9A=84=E5=BC=A0=E9=87=8F=E5=BD=A2=E7=8A=B6=E9=87=8D=E6=8E=92?= =?UTF-8?q?=E5=92=8C=E4=B8=B4=E6=97=B6=E5=AF=B9=E8=B1=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- jepa.py | 18 ++++++++---------- module.py | 8 ++++++-- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/jepa.py b/jepa.py index 971c995..6c0b947 100644 --- a/jepa.py +++ b/jepa.py @@ -2,7 +2,6 @@ import torch import torch.nn.functional as F -from einops import rearrange from torch import nn class JEPA(nn.Module): @@ -99,12 +98,12 @@ class JEPA(nn.Module): """ with torch.profiler.record_function("lewm.encode"): pixels = info['pixels'].float() - b = pixels.size(0) - pixels = rearrange(pixels, "b t ... -> (b t) ...") # flatten for encoding + b, t = pixels.shape[:2] + pixels = pixels.reshape(b * t, *pixels.shape[2:]) # flatten for encoding output = self.encoder(pixels, interpolate_pos_encoding=True) pixels_emb = output.last_hidden_state[:, 0] # cls token emb = self.projector(pixels_emb) - info["emb"] = rearrange(emb, "(b t) d -> b t d", b=b) + info["emb"] = emb.reshape(b, t, -1) if "action" in info: info["act_emb"] = self.action_encoder(info["action"]) @@ -118,8 +117,7 @@ class JEPA(nn.Module): """ with torch.profiler.record_function("lewm.predict"): preds = self.predictor(emb, act_emb) - preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d")) - preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0)) + preds = self.pred_proj(preds) return preds #################### @@ -143,11 +141,11 @@ class JEPA(nn.Module): init_emb = self._get_cached_init_emb(info) HS = history_size emb_hist = init_emb.unsqueeze(1).expand(B, S, -1, -1) - emb_hist = rearrange(emb_hist[..., -HS:, :], "b s ... -> (b s) ...") + emb_hist = emb_hist[..., -HS:, :].reshape(B * S, min(HS, init_emb.size(1)), -1) - act_hist = rearrange(act_0[..., -HS:, :], "b s ... -> (b s) ...") + act_hist = act_0[..., -HS:, :].reshape(B * S, min(HS, act_0.size(2)), -1) act_emb_hist = self.action_encoder(act_hist) - act_future = rearrange(act_future, "b s ... -> (b s) ...") + act_future = act_future.reshape(B * S, act_future.size(2), -1) for t in range(act_future.size(1)): pred_emb = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:] @@ -164,7 +162,7 @@ class JEPA(nn.Module): act_emb_hist = next_act_emb pred_rollout = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:] - info["predicted_emb"] = rearrange(pred_rollout, "(b s) ... -> b s ...", b=B, s=S) + info["predicted_emb"] = pred_rollout.reshape(B, S, *pred_rollout.shape[1:]) return info diff --git a/module.py b/module.py index 16c4907..acc6949 100644 --- a/module.py +++ b/module.py @@ -236,9 +236,13 @@ class MLP(nn.Module): def forward(self, x): """ - x: (B*T, D) + x: (..., D) """ - return self.net(x) + if x.ndim <= 2: + return self.net(x) + + output = self.net(x.reshape(-1, x.size(-1))) + return output.reshape(*x.shape[:-1], output.size(-1)) class ARPredictor(nn.Module):