from typing import Any, Protocol, runtime_checkable import gymnasium as gym import numpy as np import torch class Costable(Protocol): """Protocol for world model cost functions.""" def criterion(self, info_dict: dict, action_candidates: torch.Tensor) -> torch.Tensor: """Compute the cost criterion for action candidates. Args: info_dict: Dictionary containing environment state information. action_candidates: Tensor of proposed actions. Returns: A tensor of cost values for each action candidate. """ ... def get_cost(self, info_dict: dict, action_candidates: torch.Tensor) -> torch.Tensor: # pragma: no cover """Compute cost for given action candidates based on info dictionary. Args: info_dict: Dictionary containing environment state information. action_candidates: Tensor of proposed actions. Returns: A tensor of cost values for each action candidate. """ ... @runtime_checkable class Solver(Protocol): """Protocol for model-based planning solvers.""" def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None: """Configure the solver with environment and planning specifications. Args: action_space: The action space of the environment. n_envs: Number of parallel environments. config: Planning configuration object. """ ... @property def action_dim(self) -> int: """Flattened action dimension including action_block grouping.""" ... @property def n_envs(self) -> int: """Number of parallel environments being planned for.""" ... @property def horizon(self) -> int: """Planning horizon length in timesteps.""" ... def solve( self, info_dict: dict, init_action: torch.Tensor | None = None, active_mask: torch.Tensor | np.ndarray | None = None, ) -> dict: """Solve the planning optimization problem to find optimal actions. Args: info_dict: Dictionary containing environment state information. init_action: Optional initial action sequence to warm-start the solver. Returns: Dictionary containing optimized actions and other solver-specific info. """ ...