减少循环里的张量形状重排和临时对象

This commit is contained in:
qihuanye
2026-04-09 10:14:58 +00:00
parent 3a94829eac
commit 006102d00c
2 changed files with 14 additions and 12 deletions

18
jepa.py
View File

@@ -2,7 +2,6 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange
from torch import nn from torch import nn
class JEPA(nn.Module): class JEPA(nn.Module):
@@ -99,12 +98,12 @@ class JEPA(nn.Module):
""" """
with torch.profiler.record_function("lewm.encode"): with torch.profiler.record_function("lewm.encode"):
pixels = info['pixels'].float() pixels = info['pixels'].float()
b = pixels.size(0) b, t = pixels.shape[:2]
pixels = rearrange(pixels, "b t ... -> (b t) ...") # flatten for encoding pixels = pixels.reshape(b * t, *pixels.shape[2:]) # flatten for encoding
output = self.encoder(pixels, interpolate_pos_encoding=True) output = self.encoder(pixels, interpolate_pos_encoding=True)
pixels_emb = output.last_hidden_state[:, 0] # cls token pixels_emb = output.last_hidden_state[:, 0] # cls token
emb = self.projector(pixels_emb) emb = self.projector(pixels_emb)
info["emb"] = rearrange(emb, "(b t) d -> b t d", b=b) info["emb"] = emb.reshape(b, t, -1)
if "action" in info: if "action" in info:
info["act_emb"] = self.action_encoder(info["action"]) info["act_emb"] = self.action_encoder(info["action"])
@@ -118,8 +117,7 @@ class JEPA(nn.Module):
""" """
with torch.profiler.record_function("lewm.predict"): with torch.profiler.record_function("lewm.predict"):
preds = self.predictor(emb, act_emb) preds = self.predictor(emb, act_emb)
preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d")) preds = self.pred_proj(preds)
preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0))
return preds return preds
#################### ####################
@@ -143,11 +141,11 @@ class JEPA(nn.Module):
init_emb = self._get_cached_init_emb(info) init_emb = self._get_cached_init_emb(info)
HS = history_size HS = history_size
emb_hist = init_emb.unsqueeze(1).expand(B, S, -1, -1) emb_hist = init_emb.unsqueeze(1).expand(B, S, -1, -1)
emb_hist = rearrange(emb_hist[..., -HS:, :], "b s ... -> (b s) ...") emb_hist = emb_hist[..., -HS:, :].reshape(B * S, min(HS, init_emb.size(1)), -1)
act_hist = rearrange(act_0[..., -HS:, :], "b s ... -> (b s) ...") act_hist = act_0[..., -HS:, :].reshape(B * S, min(HS, act_0.size(2)), -1)
act_emb_hist = self.action_encoder(act_hist) act_emb_hist = self.action_encoder(act_hist)
act_future = rearrange(act_future, "b s ... -> (b s) ...") act_future = act_future.reshape(B * S, act_future.size(2), -1)
for t in range(act_future.size(1)): for t in range(act_future.size(1)):
pred_emb = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:] pred_emb = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:]
@@ -164,7 +162,7 @@ class JEPA(nn.Module):
act_emb_hist = next_act_emb act_emb_hist = next_act_emb
pred_rollout = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:] pred_rollout = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:]
info["predicted_emb"] = rearrange(pred_rollout, "(b s) ... -> b s ...", b=B, s=S) info["predicted_emb"] = pred_rollout.reshape(B, S, *pred_rollout.shape[1:])
return info return info

View File

@@ -236,9 +236,13 @@ class MLP(nn.Module):
def forward(self, x): def forward(self, x):
""" """
x: (B*T, D) x: (..., D)
""" """
return self.net(x) if x.ndim <= 2:
return self.net(x)
output = self.net(x.reshape(-1, x.size(-1)))
return output.reshape(*x.shape[:-1], output.size(-1))
class ARPredictor(nn.Module): class ARPredictor(nn.Module):