更改求解器 step=150时 成功率更高 step=125时 速度更快 成功率持平

This commit is contained in:
qihuanye
2026-05-04 07:55:13 +00:00
parent 4c3fdbcce6
commit cf43af0729
8 changed files with 558 additions and 3 deletions

21
jepa.py
View File

@@ -189,6 +189,27 @@ class JEPA(nn.Module):
pred_rollout = self.predict(emb_hist, act_emb_hist)[:, -1:]
else:
if torch.is_grad_enabled() and action_sequence.requires_grad:
emb_slots = init_hist.split(1, dim=1)
act_slots = act_hist.split(1, dim=1)
for t in range(act_future.size(1)):
emb_view = torch.cat(emb_slots[-HS:], dim=1)
act_view = torch.cat(act_slots[-HS:], dim=1)
pred_emb = self.predict(emb_view, act_view)[:, -1:]
next_act_emb = act_future[:, t : t + 1]
emb_slots = (*emb_slots[-(HS - 1) :], pred_emb)
act_slots = (*act_slots[-(HS - 1) :], next_act_emb)
emb_view = torch.cat(emb_slots[-HS:], dim=1)
act_view = torch.cat(act_slots[-HS:], dim=1)
pred_rollout = self.predict(emb_view, act_view)[:, -1:]
info["predicted_emb"] = pred_rollout.reshape(
B, S, *pred_rollout.shape[1:]
)
return info
emb_hist = init_hist.new_empty((B * S, HS, init_hist.size(-1)))
act_emb_hist = action_emb.new_empty((B * S, HS, action_emb.size(-1)))
emb_hist[:, :hist_len].copy_(init_hist)