减少循环里的张量形状重排和临时对象

This commit is contained in:
qihuanye
2026-04-09 10:14:58 +00:00
parent 3a94829eac
commit 006102d00c
2 changed files with 14 additions and 12 deletions

18
jepa.py
View File

@@ -2,7 +2,6 @@
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
class JEPA(nn.Module):
@@ -99,12 +98,12 @@ class JEPA(nn.Module):
"""
with torch.profiler.record_function("lewm.encode"):
pixels = info['pixels'].float()
b = pixels.size(0)
pixels = rearrange(pixels, "b t ... -> (b t) ...") # flatten for encoding
b, t = pixels.shape[:2]
pixels = pixels.reshape(b * t, *pixels.shape[2:]) # flatten for encoding
output = self.encoder(pixels, interpolate_pos_encoding=True)
pixels_emb = output.last_hidden_state[:, 0] # cls token
emb = self.projector(pixels_emb)
info["emb"] = rearrange(emb, "(b t) d -> b t d", b=b)
info["emb"] = emb.reshape(b, t, -1)
if "action" in info:
info["act_emb"] = self.action_encoder(info["action"])
@@ -118,8 +117,7 @@ class JEPA(nn.Module):
"""
with torch.profiler.record_function("lewm.predict"):
preds = self.predictor(emb, act_emb)
preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d"))
preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0))
preds = self.pred_proj(preds)
return preds
####################
@@ -143,11 +141,11 @@ class JEPA(nn.Module):
init_emb = self._get_cached_init_emb(info)
HS = history_size
emb_hist = init_emb.unsqueeze(1).expand(B, S, -1, -1)
emb_hist = rearrange(emb_hist[..., -HS:, :], "b s ... -> (b s) ...")
emb_hist = emb_hist[..., -HS:, :].reshape(B * S, min(HS, init_emb.size(1)), -1)
act_hist = rearrange(act_0[..., -HS:, :], "b s ... -> (b s) ...")
act_hist = act_0[..., -HS:, :].reshape(B * S, min(HS, act_0.size(2)), -1)
act_emb_hist = self.action_encoder(act_hist)
act_future = rearrange(act_future, "b s ... -> (b s) ...")
act_future = act_future.reshape(B * S, act_future.size(2), -1)
for t in range(act_future.size(1)):
pred_emb = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:]
@@ -164,7 +162,7 @@ class JEPA(nn.Module):
act_emb_hist = next_act_emb
pred_rollout = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:]
info["predicted_emb"] = rearrange(pred_rollout, "(b s) ... -> b s ...", b=B, s=S)
info["predicted_emb"] = pred_rollout.reshape(B, S, *pred_rollout.shape[1:])
return info

View File

@@ -236,9 +236,13 @@ class MLP(nn.Module):
def forward(self, x):
    """Apply ``self.net`` along the last dimension of ``x``.

    Args:
        x: Tensor of shape ``(..., D)``.

    Returns:
        The output of ``self.net``. For inputs with more than two
        dimensions the leading dimensions of ``x`` are preserved and only
        the last dimension takes the size produced by ``self.net``.
    """
    # Inputs with at most two dimensions are handed to the network as-is.
    if x.ndim <= 2:
        return self.net(x)
    # Collapse every leading dimension into a single batch axis so the
    # network always sees a 2-D input, then restore the leading dimensions.
    output = self.net(x.reshape(-1, x.size(-1)))
    return output.reshape(*x.shape[:-1], output.size(-1))
class ARPredictor(nn.Module):