继续做了通用性能优化,重点从 jepa.py 热路径转到实际的 stable_worldmodel

solver/policy 边界:去掉 CEM 每轮 cpu().tolist() 和结果过早回 CPU,把
  plan/warm-start 保持在 GPU,只在 env.step 前最后一步转成 numpy,同时补
  了输入张量的 contiguous 处理;
This commit is contained in:
qihuanye
2026-04-09 12:33:50 +00:00
parent 995cd8cfec
commit 25e4ddb628
4 changed files with 432 additions and 29 deletions

42
jepa.py
View File

@@ -49,18 +49,33 @@ class JEPA(nn.Module):
str(tensor.device),
tensor.dtype,
tuple(tensor.shape),
tuple(tensor.stride()),
tensor.storage_offset(),
tensor.data_ptr(),
version,
)
def _get_cached_device_tensor(self, key: str, tensor: torch.Tensor, device: torch.device):
def _get_cached_device_tensor(
self,
key: str,
tensor: torch.Tensor,
device: torch.device,
*,
ensure_contiguous: bool = False,
):
self._ensure_runtime_caches()
signature = (self._tensor_signature(tensor), str(device))
if tensor.device == device and (not ensure_contiguous or tensor.is_contiguous()):
return tensor
signature = (self._tensor_signature(tensor), str(device), ensure_contiguous)
cached = self._cached_device_tensors.get(key)
if cached is None or cached[0] != signature:
prepared = tensor.to(device, non_blocking=True)
if ensure_contiguous and not prepared.is_contiguous():
prepared = prepared.contiguous()
self._cached_device_tensors[key] = (
signature,
tensor.to(device, non_blocking=True),
prepared,
)
return self._cached_device_tensors[key][1]
@@ -68,8 +83,13 @@ class JEPA(nn.Module):
for key, value in list(info_dict.items()):
if key.startswith("_lewm_"):
continue
if torch.is_tensor(value) and value.device != device:
info_dict[key] = self._get_cached_device_tensor(key, value, device)
if torch.is_tensor(value):
info_dict[key] = self._get_cached_device_tensor(
key,
value,
device,
ensure_contiguous=True,
)
return info_dict
def _get_cached_init_emb(self, info_dict: dict):
@@ -152,9 +172,9 @@ class JEPA(nn.Module):
init_hist = init_emb[:, -hist_len:]
init_hist = init_hist.unsqueeze(1).expand(-1, S, -1, -1)
init_hist = init_hist.reshape(B * S, hist_len, init_hist.size(-1))
init_hist = init_hist.reshape(B * S, hist_len, init_hist.size(-1)).contiguous()
flat_actions = action_sequence.reshape(B * S, T, -1)
flat_actions = action_sequence.contiguous().view(B * S, T, -1)
action_emb = self.action_encoder(flat_actions)
act_hist = action_emb[:, H - hist_len : H]
act_future = action_emb[:, H:]
@@ -245,8 +265,12 @@ class JEPA(nn.Module):
self._ensure_runtime_caches()
device = next(self.parameters()).device
info_dict = self._ensure_info_device(info_dict, device)
if action_candidates.device != device:
action_candidates = action_candidates.to(device, non_blocking=True)
action_candidates = self._get_cached_device_tensor(
"_lewm_action_candidates",
action_candidates,
device,
ensure_contiguous=True,
)
info_dict["goal_emb"] = self._get_cached_goal_emb(info_dict)
info_dict = self.rollout(info_dict, action_candidates)