减少循环里的张量形状重排和临时对象 (Reduce tensor reshaping and temporary objects inside loops)
This commit is contained in:
18
jepa.py
18
jepa.py
@@ -2,7 +2,6 @@
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
|
||||
class JEPA(nn.Module):
|
||||
@@ -99,12 +98,12 @@ class JEPA(nn.Module):
|
||||
"""
|
||||
with torch.profiler.record_function("lewm.encode"):
|
||||
pixels = info['pixels'].float()
|
||||
b = pixels.size(0)
|
||||
pixels = rearrange(pixels, "b t ... -> (b t) ...") # flatten for encoding
|
||||
b, t = pixels.shape[:2]
|
||||
pixels = pixels.reshape(b * t, *pixels.shape[2:]) # flatten for encoding
|
||||
output = self.encoder(pixels, interpolate_pos_encoding=True)
|
||||
pixels_emb = output.last_hidden_state[:, 0] # cls token
|
||||
emb = self.projector(pixels_emb)
|
||||
info["emb"] = rearrange(emb, "(b t) d -> b t d", b=b)
|
||||
info["emb"] = emb.reshape(b, t, -1)
|
||||
|
||||
if "action" in info:
|
||||
info["act_emb"] = self.action_encoder(info["action"])
|
||||
@@ -118,8 +117,7 @@ class JEPA(nn.Module):
|
||||
"""
|
||||
with torch.profiler.record_function("lewm.predict"):
|
||||
preds = self.predictor(emb, act_emb)
|
||||
preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d"))
|
||||
preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0))
|
||||
preds = self.pred_proj(preds)
|
||||
return preds
|
||||
|
||||
####################
|
||||
@@ -143,11 +141,11 @@ class JEPA(nn.Module):
|
||||
init_emb = self._get_cached_init_emb(info)
|
||||
HS = history_size
|
||||
emb_hist = init_emb.unsqueeze(1).expand(B, S, -1, -1)
|
||||
emb_hist = rearrange(emb_hist[..., -HS:, :], "b s ... -> (b s) ...")
|
||||
emb_hist = emb_hist[..., -HS:, :].reshape(B * S, min(HS, init_emb.size(1)), -1)
|
||||
|
||||
act_hist = rearrange(act_0[..., -HS:, :], "b s ... -> (b s) ...")
|
||||
act_hist = act_0[..., -HS:, :].reshape(B * S, min(HS, act_0.size(2)), -1)
|
||||
act_emb_hist = self.action_encoder(act_hist)
|
||||
act_future = rearrange(act_future, "b s ... -> (b s) ...")
|
||||
act_future = act_future.reshape(B * S, act_future.size(2), -1)
|
||||
|
||||
for t in range(act_future.size(1)):
|
||||
pred_emb = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:]
|
||||
@@ -164,7 +162,7 @@ class JEPA(nn.Module):
|
||||
act_emb_hist = next_act_emb
|
||||
|
||||
pred_rollout = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:]
|
||||
info["predicted_emb"] = rearrange(pred_rollout, "(b s) ... -> b s ...", b=B, s=S)
|
||||
info["predicted_emb"] = pred_rollout.reshape(B, S, *pred_rollout.shape[1:])
|
||||
|
||||
return info
|
||||
|
||||
|
||||
@@ -236,9 +236,13 @@ class MLP(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: (B*T, D)
|
||||
x: (..., D)
|
||||
"""
|
||||
return self.net(x)
|
||||
if x.ndim <= 2:
|
||||
return self.net(x)
|
||||
|
||||
output = self.net(x.reshape(-1, x.size(-1)))
|
||||
return output.reshape(*x.shape[:-1], output.size(-1))
|
||||
|
||||
|
||||
class ARPredictor(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user