继续做了通用性能优化,重点从 jepa.py 热路径转到实际的 stable_worldmodel

solver/policy 边界:去掉 CEM 每轮 cpu().tolist() 和结果过早回 CPU,把
  plan/warm-start 保持在 GPU,只在 env.step 前最后一步转成 numpy,同时补
  了输入张量的 contiguous 处理;
This commit is contained in:
qihuanye
2026-04-09 12:33:50 +00:00
parent 995cd8cfec
commit 25e4ddb628
4 changed files with 432 additions and 29 deletions

View File

@@ -80,14 +80,24 @@ class CEMSolver:
self, actions: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""Initialize the action distribution parameters (mean and variance)."""
var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim])
mean = torch.zeros([self.n_envs, 0, self.action_dim]) if actions is None else actions
device = torch.device(self.device)
var = self.var_scale * torch.ones(
[self.n_envs, self.horizon, self.action_dim],
device=device,
)
mean = (
torch.zeros([self.n_envs, 0, self.action_dim], device=device)
if actions is None
else actions
)
remaining = self.horizon - mean.shape[1]
if remaining > 0:
device = mean.device
new_mean = torch.zeros([self.n_envs, remaining, self.action_dim])
mean = torch.cat([mean, new_mean], dim=1).to(device)
new_mean = torch.zeros(
[self.n_envs, remaining, self.action_dim],
device=mean.device,
)
mean = torch.cat([mean, new_mean], dim=1)
return mean, var
@@ -105,8 +115,10 @@ class CEMSolver:
# -- initialize the action distribution globally
mean, var = self.init_action_distrib(init_action)
mean = mean.to(self.device)
var = var.to(self.device)
if mean.device != torch.device(self.device):
mean = mean.to(self.device, non_blocking=True)
if var.device != torch.device(self.device):
var = var.to(self.device, non_blocking=True)
total_envs = self.n_envs
@@ -138,6 +150,7 @@ class CEMSolver:
# Optimization Loop
final_batch_cost = None
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1)
for step in range(self.n_steps):
# Sample action sequences: (Batch, Num_Samples, Horizon, Dim)
@@ -172,8 +185,6 @@ class CEMSolver:
# Gather Top-K Candidates
# We need to select the specific candidates corresponding to topk_inds
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk)
# Indexing: candidates[batch_idx, sample_idx]
# Result shape: (Batch, K, Horizon, Dim)
topk_candidates = candidates[batch_indices, topk_inds]
@@ -184,18 +195,19 @@ class CEMSolver:
# Update final cost for logging
# We average the cost of the top elites
final_batch_cost = topk_vals.mean(dim=1).cpu().tolist()
final_batch_cost = topk_vals.mean(dim=1).detach()
# Write results back to global storage
mean[start_idx:end_idx] = batch_mean
var[start_idx:end_idx] = batch_var
# Store history/metadata
outputs["costs"].extend(final_batch_cost)
outputs["costs"].append(final_batch_cost)
outputs["actions"] = mean.detach().cpu()
outputs["mean"] = [mean.detach().cpu()]
outputs["var"] = [var.detach().cpu()]
outputs["costs"] = torch.cat(outputs["costs"]).cpu().tolist()
outputs["actions"] = mean.detach()
outputs["mean"] = [mean.detach()]
outputs["var"] = [var.detach()]
print(f"CEM solve time: {time.time() - start_time:.4f} seconds")
return outputs