加入一个提前停止的机制 还有减少环境步中间步骤传递至cpu

This commit is contained in:
qihuanye
2026-05-18 00:48:59 +08:00
parent 113e591899
commit 28f2fba0e8
5 changed files with 138 additions and 16 deletions

View File

@@ -76,6 +76,24 @@ class CEMSolver:
"""Make solver callable, forwarding to solve()."""
return self.solve(*args, **kwargs)
@staticmethod
def _normalize_active_mask(
active_mask: torch.Tensor | np.ndarray | None,
n_envs: int,
device: torch.device,
) -> torch.Tensor | None:
if active_mask is None:
return None
if not torch.is_tensor(active_mask):
active_mask = torch.as_tensor(active_mask, dtype=torch.bool, device=device)
else:
active_mask = active_mask.to(device=device, dtype=torch.bool)
if active_mask.ndim != 1 or active_mask.shape[0] != n_envs:
raise ValueError(
f"active_mask must have shape ({n_envs},), got {tuple(active_mask.shape)}"
)
return active_mask
def init_action_distrib(
self, actions: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -103,7 +121,10 @@ class CEMSolver:
@torch.inference_mode()
def solve(
self, info_dict: dict, init_action: torch.Tensor | None = None
self,
info_dict: dict,
init_action: torch.Tensor | None = None,
active_mask: torch.Tensor | np.ndarray | None = None,
) -> dict:
"""Solve the planning problem using Cross Entropy Method."""
start_time = time.time()
@@ -119,6 +140,17 @@ class CEMSolver:
mean = mean.to(self.device, non_blocking=True)
if var.device != torch.device(self.device):
var = var.to(self.device, non_blocking=True)
active_mask = self._normalize_active_mask(
active_mask, self.n_envs, torch.device(self.device)
)
if active_mask is not None and not torch.any(active_mask):
return {
"costs": [],
"actions": mean.detach(),
"mean": [mean.detach()],
"var": [var.detach()],
}
total_envs = self.n_envs
@@ -148,6 +180,25 @@ class CEMSolver:
v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1)
expanded_infos[k] = v_batch
if active_mask is not None:
batch_mask = active_mask[start_idx:end_idx]
if not torch.any(batch_mask):
outputs["costs"].append(
torch.full((current_bs,), float("nan"), device=self.device)
)
continue
active_local = torch.nonzero(batch_mask, as_tuple=False).squeeze(1)
active_local_np = active_local.detach().cpu().numpy()
batch_mean = batch_mean[active_local]
batch_var = batch_var[active_local]
expanded_infos = {
k: (v[active_local] if torch.is_tensor(v) else v[active_local_np])
for k, v in expanded_infos.items()
}
current_bs = int(active_local.numel())
else:
active_local = None
# Optimization Loop
final_batch_cost = None
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1)
@@ -198,13 +249,26 @@ class CEMSolver:
final_batch_cost = topk_vals.mean(dim=1).detach()
# Write results back to global storage
mean[start_idx:end_idx] = batch_mean
var[start_idx:end_idx] = batch_var
if active_mask is not None:
global_indices = start_idx + active_local
mean[global_indices] = batch_mean
var[global_indices] = batch_var
batch_costs = torch.full(
(end_idx - start_idx,), float("nan"), device=self.device
)
batch_costs[active_local] = final_batch_cost
else:
mean[start_idx:end_idx] = batch_mean
var[start_idx:end_idx] = batch_var
batch_costs = final_batch_cost
# Store history/metadata
outputs["costs"].append(final_batch_cost)
outputs["costs"].append(batch_costs)
outputs["costs"] = torch.cat(outputs["costs"]).cpu().tolist()
if outputs["costs"]:
outputs["costs"] = torch.cat(outputs["costs"]).cpu().tolist()
else:
outputs["costs"] = []
outputs["actions"] = mean.detach()
outputs["mean"] = [mean.detach()]
outputs["var"] = [var.detach()]

View File

@@ -1,6 +1,7 @@
from typing import Any, Protocol, runtime_checkable
import gymnasium as gym
import numpy as np
import torch
@@ -61,7 +62,12 @@ class Solver(Protocol):
"""Planning horizon length in timesteps."""
...
def solve(self, info_dict: dict, init_action: torch.Tensor | None = None) -> dict:
def solve(
self,
info_dict: dict,
init_action: torch.Tensor | None = None,
active_mask: torch.Tensor | np.ndarray | None = None,
) -> dict:
"""Solve the planning optimization problem to find optimal actions.
Args: