继续做了通用性能优化,重点从 jepa.py 热路径转到实际的 stable_worldmodel
solver/policy 边界:去掉 CEM 每轮 cpu().tolist() 和结果过早回 CPU,把 plan/warm-start 保持在 GPU,只在 env.step 前最后一步转成 numpy,同时补了输入张量的 contiguous 处理;
This commit is contained in:
42
jepa.py
42
jepa.py
@@ -49,18 +49,33 @@ class JEPA(nn.Module):
|
||||
str(tensor.device),
|
||||
tensor.dtype,
|
||||
tuple(tensor.shape),
|
||||
tuple(tensor.stride()),
|
||||
tensor.storage_offset(),
|
||||
tensor.data_ptr(),
|
||||
version,
|
||||
)
|
||||
|
||||
def _get_cached_device_tensor(self, key: str, tensor: torch.Tensor, device: torch.device):
|
||||
def _get_cached_device_tensor(
|
||||
self,
|
||||
key: str,
|
||||
tensor: torch.Tensor,
|
||||
device: torch.device,
|
||||
*,
|
||||
ensure_contiguous: bool = False,
|
||||
):
|
||||
self._ensure_runtime_caches()
|
||||
signature = (self._tensor_signature(tensor), str(device))
|
||||
if tensor.device == device and (not ensure_contiguous or tensor.is_contiguous()):
|
||||
return tensor
|
||||
|
||||
signature = (self._tensor_signature(tensor), str(device), ensure_contiguous)
|
||||
cached = self._cached_device_tensors.get(key)
|
||||
if cached is None or cached[0] != signature:
|
||||
prepared = tensor.to(device, non_blocking=True)
|
||||
if ensure_contiguous and not prepared.is_contiguous():
|
||||
prepared = prepared.contiguous()
|
||||
self._cached_device_tensors[key] = (
|
||||
signature,
|
||||
tensor.to(device, non_blocking=True),
|
||||
prepared,
|
||||
)
|
||||
return self._cached_device_tensors[key][1]
|
||||
|
||||
@@ -68,8 +83,13 @@ class JEPA(nn.Module):
|
||||
for key, value in list(info_dict.items()):
|
||||
if key.startswith("_lewm_"):
|
||||
continue
|
||||
if torch.is_tensor(value) and value.device != device:
|
||||
info_dict[key] = self._get_cached_device_tensor(key, value, device)
|
||||
if torch.is_tensor(value):
|
||||
info_dict[key] = self._get_cached_device_tensor(
|
||||
key,
|
||||
value,
|
||||
device,
|
||||
ensure_contiguous=True,
|
||||
)
|
||||
return info_dict
|
||||
|
||||
def _get_cached_init_emb(self, info_dict: dict):
|
||||
@@ -152,9 +172,9 @@ class JEPA(nn.Module):
|
||||
|
||||
init_hist = init_emb[:, -hist_len:]
|
||||
init_hist = init_hist.unsqueeze(1).expand(-1, S, -1, -1)
|
||||
init_hist = init_hist.reshape(B * S, hist_len, init_hist.size(-1))
|
||||
init_hist = init_hist.reshape(B * S, hist_len, init_hist.size(-1)).contiguous()
|
||||
|
||||
flat_actions = action_sequence.reshape(B * S, T, -1)
|
||||
flat_actions = action_sequence.contiguous().view(B * S, T, -1)
|
||||
action_emb = self.action_encoder(flat_actions)
|
||||
act_hist = action_emb[:, H - hist_len : H]
|
||||
act_future = action_emb[:, H:]
|
||||
@@ -245,8 +265,12 @@ class JEPA(nn.Module):
|
||||
self._ensure_runtime_caches()
|
||||
device = next(self.parameters()).device
|
||||
info_dict = self._ensure_info_device(info_dict, device)
|
||||
if action_candidates.device != device:
|
||||
action_candidates = action_candidates.to(device, non_blocking=True)
|
||||
action_candidates = self._get_cached_device_tensor(
|
||||
"_lewm_action_candidates",
|
||||
action_candidates,
|
||||
device,
|
||||
ensure_contiguous=True,
|
||||
)
|
||||
|
||||
info_dict["goal_emb"] = self._get_cached_goal_emb(info_dict)
|
||||
info_dict = self.rollout(info_dict, action_candidates)
|
||||
|
||||
Reference in New Issue
Block a user