优化 jepa.py 中通用 rollout 热路径:批量预编码动作、移除循环内 torch.cat,并为 history_size==1 与环形缓冲区更新添加更轻量实现;收益不大
This commit is contained in:
qihuanye
2026-04-09 11:57:09 +00:00
parent cd03a0d5cb
commit 995cd8cfec
3 changed files with 597 additions and 197 deletions

92
jepa.py
View File

@@ -133,35 +133,89 @@ class JEPA(nn.Module):
"""
with torch.profiler.record_function("lewm.rollout"):
assert "pixels" in info, "pixels not in info_dict"
if history_size < 1:
raise ValueError("history_size must be >= 1")
H = info["pixels"].size(2)
B, S, T = action_sequence.shape[:3]
act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2)
if T < H:
raise ValueError(
f"action_sequence horizon ({T}) must be >= history length ({H})"
)
# Cache the encoded initial state across solver iterations.
init_emb = self._get_cached_init_emb(info)
HS = history_size
emb_hist = init_emb.unsqueeze(1).expand(B, S, -1, -1)
emb_hist = emb_hist[..., -HS:, :].reshape(B * S, min(HS, init_emb.size(1)), -1)
hist_len = min(HS, init_emb.size(1), H)
if hist_len < 1:
raise ValueError("rollout requires at least one history step")
act_hist = act_0[..., -HS:, :].reshape(B * S, min(HS, act_0.size(2)), -1)
act_emb_hist = self.action_encoder(act_hist)
act_future = act_future.reshape(B * S, act_future.size(2), -1)
init_hist = init_emb[:, -hist_len:]
init_hist = init_hist.unsqueeze(1).expand(-1, S, -1, -1)
init_hist = init_hist.reshape(B * S, hist_len, init_hist.size(-1))
for t in range(act_future.size(1)):
pred_emb = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:]
if HS > 1:
emb_hist = torch.cat([emb_hist[:, -HS + 1 :], pred_emb], dim=1)
flat_actions = action_sequence.reshape(B * S, T, -1)
action_emb = self.action_encoder(flat_actions)
act_hist = action_emb[:, H - hist_len : H]
act_future = action_emb[:, H:]
if HS == 1:
emb_hist = init_hist[:, -1:]
act_emb_hist = act_hist[:, -1:]
for t in range(act_future.size(1)):
emb_hist = self.predict(emb_hist, act_emb_hist)[:, -1:]
act_emb_hist = act_future[:, t : t + 1]
pred_rollout = self.predict(emb_hist, act_emb_hist)[:, -1:]
else:
emb_hist = init_hist.new_empty((B * S, HS, init_hist.size(-1)))
act_emb_hist = action_emb.new_empty((B * S, HS, action_emb.size(-1)))
emb_hist[:, :hist_len].copy_(init_hist)
act_emb_hist[:, :hist_len].copy_(act_hist)
history_order = torch.stack(
[
(torch.arange(HS, device=action_emb.device) + offset) % HS
for offset in range(HS)
]
)
filled = hist_len
next_slot = hist_len % HS
for t in range(act_future.size(1)):
if filled < HS:
emb_view = emb_hist[:, :filled]
act_view = act_emb_hist[:, :filled]
elif next_slot == 0:
emb_view = emb_hist
act_view = act_emb_hist
else:
order = history_order[next_slot]
emb_view = emb_hist.index_select(1, order)
act_view = act_emb_hist.index_select(1, order)
pred_emb = self.predict(emb_view, act_view)[:, -1:]
next_act_emb = act_future[:, t : t + 1]
emb_hist[:, next_slot : next_slot + 1].copy_(pred_emb)
act_emb_hist[:, next_slot : next_slot + 1].copy_(next_act_emb)
if filled < HS:
filled += 1
next_slot = (next_slot + 1) % HS
if filled < HS:
emb_view = emb_hist[:, :filled]
act_view = act_emb_hist[:, :filled]
elif next_slot == 0:
emb_view = emb_hist
act_view = act_emb_hist
else:
emb_hist = pred_emb
order = history_order[next_slot]
emb_view = emb_hist.index_select(1, order)
act_view = act_emb_hist.index_select(1, order)
next_act = act_future[:, t : t + 1, :]
next_act_emb = self.action_encoder(next_act)
if HS > 1:
act_emb_hist = torch.cat([act_emb_hist[:, -HS + 1 :], next_act_emb], dim=1)
else:
act_emb_hist = next_act_emb
pred_rollout = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:]
pred_rollout = self.predict(emb_view, act_view)[:, -1:]
info["predicted_emb"] = pred_rollout.reshape(B, S, *pred_rollout.shape[1:])
return info