更改求解器 step=150时 成功率更高 step=125时 速度更快 成功率持平
This commit is contained in:
21
jepa.py
21
jepa.py
@@ -189,6 +189,27 @@ class JEPA(nn.Module):
|
||||
|
||||
pred_rollout = self.predict(emb_hist, act_emb_hist)[:, -1:]
|
||||
else:
|
||||
if torch.is_grad_enabled() and action_sequence.requires_grad:
|
||||
emb_slots = init_hist.split(1, dim=1)
|
||||
act_slots = act_hist.split(1, dim=1)
|
||||
|
||||
for t in range(act_future.size(1)):
|
||||
emb_view = torch.cat(emb_slots[-HS:], dim=1)
|
||||
act_view = torch.cat(act_slots[-HS:], dim=1)
|
||||
pred_emb = self.predict(emb_view, act_view)[:, -1:]
|
||||
next_act_emb = act_future[:, t : t + 1]
|
||||
|
||||
emb_slots = (*emb_slots[-(HS - 1) :], pred_emb)
|
||||
act_slots = (*act_slots[-(HS - 1) :], next_act_emb)
|
||||
|
||||
emb_view = torch.cat(emb_slots[-HS:], dim=1)
|
||||
act_view = torch.cat(act_slots[-HS:], dim=1)
|
||||
pred_rollout = self.predict(emb_view, act_view)[:, -1:]
|
||||
info["predicted_emb"] = pred_rollout.reshape(
|
||||
B, S, *pred_rollout.shape[1:]
|
||||
)
|
||||
return info
|
||||
|
||||
emb_hist = init_hist.new_empty((B * S, HS, init_hist.size(-1)))
|
||||
act_emb_hist = action_emb.new_empty((B * S, HS, action_emb.size(-1)))
|
||||
emb_hist[:, :hist_len].copy_(init_hist)
|
||||
|
||||
Reference in New Issue
Block a user