优化 jepa.py 中通用 rollout 热路径:批量预编码动
作、移除循环内 torch.cat,并为 history_size==1 与环形缓冲区更新 添加更轻量实现; 收益不大
This commit is contained in:
92
jepa.py
92
jepa.py
@@ -133,35 +133,89 @@ class JEPA(nn.Module):
|
||||
"""
|
||||
with torch.profiler.record_function("lewm.rollout"):
|
||||
assert "pixels" in info, "pixels not in info_dict"
|
||||
if history_size < 1:
|
||||
raise ValueError("history_size must be >= 1")
|
||||
|
||||
H = info["pixels"].size(2)
|
||||
B, S, T = action_sequence.shape[:3]
|
||||
act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2)
|
||||
if T < H:
|
||||
raise ValueError(
|
||||
f"action_sequence horizon ({T}) must be >= history length ({H})"
|
||||
)
|
||||
|
||||
# Cache the encoded initial state across solver iterations.
|
||||
init_emb = self._get_cached_init_emb(info)
|
||||
HS = history_size
|
||||
emb_hist = init_emb.unsqueeze(1).expand(B, S, -1, -1)
|
||||
emb_hist = emb_hist[..., -HS:, :].reshape(B * S, min(HS, init_emb.size(1)), -1)
|
||||
hist_len = min(HS, init_emb.size(1), H)
|
||||
if hist_len < 1:
|
||||
raise ValueError("rollout requires at least one history step")
|
||||
|
||||
act_hist = act_0[..., -HS:, :].reshape(B * S, min(HS, act_0.size(2)), -1)
|
||||
act_emb_hist = self.action_encoder(act_hist)
|
||||
act_future = act_future.reshape(B * S, act_future.size(2), -1)
|
||||
init_hist = init_emb[:, -hist_len:]
|
||||
init_hist = init_hist.unsqueeze(1).expand(-1, S, -1, -1)
|
||||
init_hist = init_hist.reshape(B * S, hist_len, init_hist.size(-1))
|
||||
|
||||
for t in range(act_future.size(1)):
|
||||
pred_emb = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:]
|
||||
if HS > 1:
|
||||
emb_hist = torch.cat([emb_hist[:, -HS + 1 :], pred_emb], dim=1)
|
||||
flat_actions = action_sequence.reshape(B * S, T, -1)
|
||||
action_emb = self.action_encoder(flat_actions)
|
||||
act_hist = action_emb[:, H - hist_len : H]
|
||||
act_future = action_emb[:, H:]
|
||||
|
||||
if HS == 1:
|
||||
emb_hist = init_hist[:, -1:]
|
||||
act_emb_hist = act_hist[:, -1:]
|
||||
|
||||
for t in range(act_future.size(1)):
|
||||
emb_hist = self.predict(emb_hist, act_emb_hist)[:, -1:]
|
||||
act_emb_hist = act_future[:, t : t + 1]
|
||||
|
||||
pred_rollout = self.predict(emb_hist, act_emb_hist)[:, -1:]
|
||||
else:
|
||||
emb_hist = init_hist.new_empty((B * S, HS, init_hist.size(-1)))
|
||||
act_emb_hist = action_emb.new_empty((B * S, HS, action_emb.size(-1)))
|
||||
emb_hist[:, :hist_len].copy_(init_hist)
|
||||
act_emb_hist[:, :hist_len].copy_(act_hist)
|
||||
|
||||
history_order = torch.stack(
|
||||
[
|
||||
(torch.arange(HS, device=action_emb.device) + offset) % HS
|
||||
for offset in range(HS)
|
||||
]
|
||||
)
|
||||
filled = hist_len
|
||||
next_slot = hist_len % HS
|
||||
|
||||
for t in range(act_future.size(1)):
|
||||
if filled < HS:
|
||||
emb_view = emb_hist[:, :filled]
|
||||
act_view = act_emb_hist[:, :filled]
|
||||
elif next_slot == 0:
|
||||
emb_view = emb_hist
|
||||
act_view = act_emb_hist
|
||||
else:
|
||||
order = history_order[next_slot]
|
||||
emb_view = emb_hist.index_select(1, order)
|
||||
act_view = act_emb_hist.index_select(1, order)
|
||||
|
||||
pred_emb = self.predict(emb_view, act_view)[:, -1:]
|
||||
next_act_emb = act_future[:, t : t + 1]
|
||||
emb_hist[:, next_slot : next_slot + 1].copy_(pred_emb)
|
||||
act_emb_hist[:, next_slot : next_slot + 1].copy_(next_act_emb)
|
||||
|
||||
if filled < HS:
|
||||
filled += 1
|
||||
next_slot = (next_slot + 1) % HS
|
||||
|
||||
if filled < HS:
|
||||
emb_view = emb_hist[:, :filled]
|
||||
act_view = act_emb_hist[:, :filled]
|
||||
elif next_slot == 0:
|
||||
emb_view = emb_hist
|
||||
act_view = act_emb_hist
|
||||
else:
|
||||
emb_hist = pred_emb
|
||||
order = history_order[next_slot]
|
||||
emb_view = emb_hist.index_select(1, order)
|
||||
act_view = act_emb_hist.index_select(1, order)
|
||||
|
||||
next_act = act_future[:, t : t + 1, :]
|
||||
next_act_emb = self.action_encoder(next_act)
|
||||
if HS > 1:
|
||||
act_emb_hist = torch.cat([act_emb_hist[:, -HS + 1 :], next_act_emb], dim=1)
|
||||
else:
|
||||
act_emb_hist = next_act_emb
|
||||
|
||||
pred_rollout = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:]
|
||||
pred_rollout = self.predict(emb_view, act_view)[:, -1:]
|
||||
info["predicted_emb"] = pred_rollout.reshape(B, S, *pred_rollout.shape[1:])
|
||||
|
||||
return info
|
||||
|
||||
Reference in New Issue
Block a user