Files
lewm/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/solver.py

81 lines
2.4 KiB
Python

from typing import Any, Protocol, runtime_checkable
import gymnasium as gym
import numpy as np
import torch
class Costable(Protocol):
    """Structural interface for world-model objects that score action candidates.

    Any object providing ``get_cost`` and ``criterion`` with the signatures
    below satisfies this protocol for static type checking. Unlike ``Solver``,
    this protocol is not decorated with ``runtime_checkable``. As a result,
    ``isinstance(obj, Costable)`` raises ``TypeError`` instead of performing
    a structural check.
    """

    def get_cost(self, info_dict: dict, action_candidates: torch.Tensor) -> torch.Tensor:  # pragma: no cover
        """Return a cost for each proposed action candidate.

        Args:
            info_dict: Environment state information the costs are
                evaluated against.
            action_candidates: The proposed actions to be scored.

        Returns:
            A tensor holding the cost of every action candidate.
        """
        ...

    def criterion(self, info_dict: dict, action_candidates: torch.Tensor) -> torch.Tensor:
        """Return the cost criterion evaluated for the proposed action candidates.

        Args:
            info_dict: Environment state information the criterion is
                evaluated against.
            action_candidates: The proposed actions to be scored.

        Returns:
            A tensor holding one criterion value per action candidate.
        """
        ...
@runtime_checkable
class Solver(Protocol):
    """Protocol for model-based planning solvers.

    The class is decorated with ``runtime_checkable``, so
    ``isinstance(obj, Solver)`` performs a structural check. That check only
    verifies that the members below exist. It does not verify their
    signatures or return types.
    """

    def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
        """Configure the solver with environment and planning specifications.

        All arguments are keyword-only.

        Args:
            action_space: The action space of the environment.
            n_envs: Number of parallel environments.
            config: Planning configuration object.
        """
        ...

    @property
    def action_dim(self) -> int:
        """Flattened action dimension including action_block grouping."""
        ...

    @property
    def n_envs(self) -> int:
        """Number of parallel environments being planned for."""
        ...

    @property
    def horizon(self) -> int:
        """Planning horizon length in timesteps."""
        ...

    def solve(
        self,
        info_dict: dict,
        init_action: torch.Tensor | None = None,
        active_mask: torch.Tensor | np.ndarray | None = None,
    ) -> dict:
        """Solve the planning optimization problem to find optimal actions.

        Args:
            info_dict: Dictionary containing environment state information.
            init_action: Optional initial action sequence to warm-start the
                solver. Defaults to ``None`` (no warm start supplied).
            active_mask: Optional mask, given as a ``torch.Tensor`` or a
                ``np.ndarray``, that restricts which environments are planned
                for. Defaults to ``None``.
                NOTE(review): presumably one entry per environment (length
                ``n_envs``), with ``None`` meaning every environment is
                active. Confirm against concrete solver implementations.

        Returns:
            Dictionary containing optimized actions and other solver-specific info.
        """
        ...