加入一个提前停止的机制 还有减少环境步中间步骤传递至cpu
This commit is contained in:
@@ -76,6 +76,24 @@ class CEMSolver:
|
||||
"""Make solver callable, forwarding to solve()."""
|
||||
return self.solve(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_active_mask(
|
||||
active_mask: torch.Tensor | np.ndarray | None,
|
||||
n_envs: int,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor | None:
|
||||
if active_mask is None:
|
||||
return None
|
||||
if not torch.is_tensor(active_mask):
|
||||
active_mask = torch.as_tensor(active_mask, dtype=torch.bool, device=device)
|
||||
else:
|
||||
active_mask = active_mask.to(device=device, dtype=torch.bool)
|
||||
if active_mask.ndim != 1 or active_mask.shape[0] != n_envs:
|
||||
raise ValueError(
|
||||
f"active_mask must have shape ({n_envs},), got {tuple(active_mask.shape)}"
|
||||
)
|
||||
return active_mask
|
||||
|
||||
def init_action_distrib(
|
||||
self, actions: torch.Tensor | None = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -103,7 +121,10 @@ class CEMSolver:
|
||||
|
||||
@torch.inference_mode()
|
||||
def solve(
|
||||
self, info_dict: dict, init_action: torch.Tensor | None = None
|
||||
self,
|
||||
info_dict: dict,
|
||||
init_action: torch.Tensor | None = None,
|
||||
active_mask: torch.Tensor | np.ndarray | None = None,
|
||||
) -> dict:
|
||||
"""Solve the planning problem using Cross Entropy Method."""
|
||||
start_time = time.time()
|
||||
@@ -119,6 +140,17 @@ class CEMSolver:
|
||||
mean = mean.to(self.device, non_blocking=True)
|
||||
if var.device != torch.device(self.device):
|
||||
var = var.to(self.device, non_blocking=True)
|
||||
active_mask = self._normalize_active_mask(
|
||||
active_mask, self.n_envs, torch.device(self.device)
|
||||
)
|
||||
|
||||
if active_mask is not None and not torch.any(active_mask):
|
||||
return {
|
||||
"costs": [],
|
||||
"actions": mean.detach(),
|
||||
"mean": [mean.detach()],
|
||||
"var": [var.detach()],
|
||||
}
|
||||
|
||||
total_envs = self.n_envs
|
||||
|
||||
@@ -148,6 +180,25 @@ class CEMSolver:
|
||||
v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1)
|
||||
expanded_infos[k] = v_batch
|
||||
|
||||
if active_mask is not None:
|
||||
batch_mask = active_mask[start_idx:end_idx]
|
||||
if not torch.any(batch_mask):
|
||||
outputs["costs"].append(
|
||||
torch.full((current_bs,), float("nan"), device=self.device)
|
||||
)
|
||||
continue
|
||||
active_local = torch.nonzero(batch_mask, as_tuple=False).squeeze(1)
|
||||
active_local_np = active_local.detach().cpu().numpy()
|
||||
batch_mean = batch_mean[active_local]
|
||||
batch_var = batch_var[active_local]
|
||||
expanded_infos = {
|
||||
k: (v[active_local] if torch.is_tensor(v) else v[active_local_np])
|
||||
for k, v in expanded_infos.items()
|
||||
}
|
||||
current_bs = int(active_local.numel())
|
||||
else:
|
||||
active_local = None
|
||||
|
||||
# Optimization Loop
|
||||
final_batch_cost = None
|
||||
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1)
|
||||
@@ -198,13 +249,26 @@ class CEMSolver:
|
||||
final_batch_cost = topk_vals.mean(dim=1).detach()
|
||||
|
||||
# Write results back to global storage
|
||||
mean[start_idx:end_idx] = batch_mean
|
||||
var[start_idx:end_idx] = batch_var
|
||||
if active_mask is not None:
|
||||
global_indices = start_idx + active_local
|
||||
mean[global_indices] = batch_mean
|
||||
var[global_indices] = batch_var
|
||||
batch_costs = torch.full(
|
||||
(end_idx - start_idx,), float("nan"), device=self.device
|
||||
)
|
||||
batch_costs[active_local] = final_batch_cost
|
||||
else:
|
||||
mean[start_idx:end_idx] = batch_mean
|
||||
var[start_idx:end_idx] = batch_var
|
||||
batch_costs = final_batch_cost
|
||||
|
||||
# Store history/metadata
|
||||
outputs["costs"].append(final_batch_cost)
|
||||
outputs["costs"].append(batch_costs)
|
||||
|
||||
outputs["costs"] = torch.cat(outputs["costs"]).cpu().tolist()
|
||||
if outputs["costs"]:
|
||||
outputs["costs"] = torch.cat(outputs["costs"]).cpu().tolist()
|
||||
else:
|
||||
outputs["costs"] = []
|
||||
outputs["actions"] = mean.detach()
|
||||
outputs["mean"] = [mean.detach()]
|
||||
outputs["var"] = [var.detach()]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
@@ -61,7 +62,12 @@ class Solver(Protocol):
|
||||
"""Planning horizon length in timesteps."""
|
||||
...
|
||||
|
||||
def solve(self, info_dict: dict, init_action: torch.Tensor | None = None) -> dict:
|
||||
def solve(
|
||||
self,
|
||||
info_dict: dict,
|
||||
init_action: torch.Tensor | None = None,
|
||||
active_mask: torch.Tensor | np.ndarray | None = None,
|
||||
) -> dict:
|
||||
"""Solve the planning optimization problem to find optimal actions.
|
||||
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user