继续做了通用性能优化,重点从 jepa.py 热路径转到实际的 stable_worldmodel
solver/policy 边界:去掉 CEM 每轮 cpu().tolist() 和结果过早回 CPU,把 plan/warm-start 保持在 GPU,只在 env.step 前最后一步转成 numpy,同时补 了输入张量的 contiguous 处理;
This commit is contained in:
@@ -122,8 +122,12 @@ class BasePolicy:
|
||||
) -> dict[str, Any]:
|
||||
target = torch.device(device)
|
||||
for k, v in info_dict.items():
|
||||
if torch.is_tensor(v) and v.device != target:
|
||||
info_dict[k] = v.to(target, non_blocking=True)
|
||||
if torch.is_tensor(v):
|
||||
if v.device != target:
|
||||
v = v.to(target, non_blocking=True)
|
||||
if not v.is_contiguous():
|
||||
v = v.contiguous()
|
||||
info_dict[k] = v
|
||||
return info_dict
|
||||
|
||||
def _prepare_info(self, info_dict: dict) -> dict[str, torch.Tensor]:
|
||||
@@ -415,18 +419,21 @@ class WorldModelPolicy(BasePolicy):
|
||||
keep_horizon = self.cfg.receding_horizon
|
||||
plan = actions[:, :keep_horizon]
|
||||
rest = actions[:, keep_horizon:]
|
||||
self._next_init = rest if self.cfg.warm_start else None
|
||||
self._next_init = rest.contiguous() if self.cfg.warm_start else None
|
||||
|
||||
# frameskip back to timestep
|
||||
plan = plan.reshape(
|
||||
self.env.num_envs, self.flatten_receding_horizon, -1
|
||||
)
|
||||
).contiguous()
|
||||
|
||||
self._action_buffer.extend(plan.transpose(0, 1))
|
||||
self._action_buffer.extend(plan.transpose(0, 1).unbind(0))
|
||||
|
||||
action = self._action_buffer.popleft()
|
||||
action = action.reshape(*self.env.action_space.shape)
|
||||
action = action.numpy()
|
||||
if torch.is_tensor(action):
|
||||
action = action.detach().cpu().numpy()
|
||||
else:
|
||||
action = np.asarray(action)
|
||||
|
||||
# post-process action
|
||||
if 'action' in self.process:
|
||||
|
||||
@@ -80,14 +80,24 @@ class CEMSolver:
|
||||
self, actions: torch.Tensor | None = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Initialize the action distribution parameters (mean and variance)."""
|
||||
var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim])
|
||||
mean = torch.zeros([self.n_envs, 0, self.action_dim]) if actions is None else actions
|
||||
device = torch.device(self.device)
|
||||
var = self.var_scale * torch.ones(
|
||||
[self.n_envs, self.horizon, self.action_dim],
|
||||
device=device,
|
||||
)
|
||||
mean = (
|
||||
torch.zeros([self.n_envs, 0, self.action_dim], device=device)
|
||||
if actions is None
|
||||
else actions
|
||||
)
|
||||
|
||||
remaining = self.horizon - mean.shape[1]
|
||||
if remaining > 0:
|
||||
device = mean.device
|
||||
new_mean = torch.zeros([self.n_envs, remaining, self.action_dim])
|
||||
mean = torch.cat([mean, new_mean], dim=1).to(device)
|
||||
new_mean = torch.zeros(
|
||||
[self.n_envs, remaining, self.action_dim],
|
||||
device=mean.device,
|
||||
)
|
||||
mean = torch.cat([mean, new_mean], dim=1)
|
||||
|
||||
return mean, var
|
||||
|
||||
@@ -105,8 +115,10 @@ class CEMSolver:
|
||||
|
||||
# -- initialize the action distribution globally
|
||||
mean, var = self.init_action_distrib(init_action)
|
||||
mean = mean.to(self.device)
|
||||
var = var.to(self.device)
|
||||
if mean.device != torch.device(self.device):
|
||||
mean = mean.to(self.device, non_blocking=True)
|
||||
if var.device != torch.device(self.device):
|
||||
var = var.to(self.device, non_blocking=True)
|
||||
|
||||
total_envs = self.n_envs
|
||||
|
||||
@@ -138,6 +150,7 @@ class CEMSolver:
|
||||
|
||||
# Optimization Loop
|
||||
final_batch_cost = None
|
||||
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1)
|
||||
|
||||
for step in range(self.n_steps):
|
||||
# Sample action sequences: (Batch, Num_Samples, Horizon, Dim)
|
||||
@@ -172,8 +185,6 @@ class CEMSolver:
|
||||
|
||||
# Gather Top-K Candidates
|
||||
# We need to select the specific candidates corresponding to topk_inds
|
||||
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk)
|
||||
|
||||
# Indexing: candidates[batch_idx, sample_idx]
|
||||
# Result shape: (Batch, K, Horizon, Dim)
|
||||
topk_candidates = candidates[batch_indices, topk_inds]
|
||||
@@ -184,18 +195,19 @@ class CEMSolver:
|
||||
|
||||
# Update final cost for logging
|
||||
# We average the cost of the top elites
|
||||
final_batch_cost = topk_vals.mean(dim=1).cpu().tolist()
|
||||
final_batch_cost = topk_vals.mean(dim=1).detach()
|
||||
|
||||
# Write results back to global storage
|
||||
mean[start_idx:end_idx] = batch_mean
|
||||
var[start_idx:end_idx] = batch_var
|
||||
|
||||
# Store history/metadata
|
||||
outputs["costs"].extend(final_batch_cost)
|
||||
outputs["costs"].append(final_batch_cost)
|
||||
|
||||
outputs["actions"] = mean.detach().cpu()
|
||||
outputs["mean"] = [mean.detach().cpu()]
|
||||
outputs["var"] = [var.detach().cpu()]
|
||||
outputs["costs"] = torch.cat(outputs["costs"]).cpu().tolist()
|
||||
outputs["actions"] = mean.detach()
|
||||
outputs["mean"] = [mean.detach()]
|
||||
outputs["var"] = [var.detach()]
|
||||
|
||||
print(f"CEM solve time: {time.time() - start_time:.4f} seconds")
|
||||
return outputs
|
||||
|
||||
Reference in New Issue
Block a user