diff --git a/.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py b/.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py index 63c408a..de45f89 100644 --- a/.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py +++ b/.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py @@ -1,5 +1,6 @@ from collections import deque from dataclasses import dataclass +import inspect from pathlib import Path from typing import Any, Protocol from collections.abc import Callable @@ -130,7 +131,37 @@ class BasePolicy: info_dict[k] = v return info_dict - def _prepare_info(self, info_dict: dict) -> dict[str, torch.Tensor]: + def _prepare_info_for_device( + self, + info_dict: dict, + device: torch.device | str | None = None, + ) -> dict[str, torch.Tensor]: + if device is None: + return self._prepare_info(info_dict) + + prepare_info = self._prepare_info + try: + signature = inspect.signature(prepare_info) + except (TypeError, ValueError): + return prepare_info(info_dict) + + accepts_device = any( + param.kind == inspect.Parameter.VAR_KEYWORD + or (param.name == "device" and param.kind in { + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + }) + for param in signature.parameters.values() + ) + if accepts_device: + return prepare_info(info_dict, device=device) + return prepare_info(info_dict) + + def _prepare_info( + self, + info_dict: dict, + device: torch.device | str | None = None, + ) -> dict[str, torch.Tensor]: """Pre-process and transform observations. Applies preprocessing (via `self.process`) and transformations (via `self.transform`) @@ -145,6 +176,7 @@ class BasePolicy: Raises: ValueError: If an expected numpy array is missing for processing. """ + target_device = torch.device(device) if device is not None else None for k, v in info_dict.items(): is_numpy = isinstance(v, (np.ndarray | np.generic)) @@ -178,10 +210,20 @@ class BasePolicy: ) v = torch.from_numpy(v) is_numpy = False + moved_for_transform = False + if target_device is not None and target_device.type != "cpu": + if v.device != target_device: + v = v.to(target_device, non_blocking=True) + moved_for_transform = True if k.startswith('pixels') or k.startswith('goal'): # Vectorized image transform on the full batch. v = v.permute(0, 3, 1, 2) - v = self.transform[k](v) + try: + v = self.transform[k](v) + except (NotImplementedError, RuntimeError): + if not moved_for_transform: + raise + v = self.transform[k](v.cpu()) if shape is not None: v = v.reshape(*shape[:2], *v.shape[1:]) @@ -189,7 +231,12 @@ class BasePolicy: if is_numpy and v.dtype.kind not in 'USO': v = torch.from_numpy(v) - if torch.is_tensor(v) and v.device.type == "cpu" and not v.is_pinned(): + if ( + torch.cuda.is_available() + and torch.is_tensor(v) + and v.device.type == "cpu" + and not v.is_pinned() + ): v = v.pin_memory() info_dict[k] = v @@ -313,7 +360,9 @@ class FeedForwardPolicy(BasePolicy): assert 'goal' in info_dict, "'goal' must be provided in info_dict" # Prepare the info dict (transforms and normalizes inputs) - info_dict = self._prepare_info(info_dict) + info_dict = self._prepare_info_for_device( + info_dict, device=next(self.model.parameters()).device + ) # Add goal_pixels key for GCBC model if 'goal' in info_dict: @@ -422,7 +471,9 @@ class WorldModelPolicy(BasePolicy): assert 'pixels' in info_dict, "'pixels' must be provided in info_dict" assert 'goal' in info_dict, "'goal' must be provided in info_dict" - info_dict = self._prepare_info(info_dict) + info_dict = self._prepare_info_for_device( + info_dict, device=self.solver.device + ) info_dict = self._move_info_to_device(info_dict, self.solver.device) outputs = self.solver( diff --git a/eval.py b/eval.py index 75fc2c3..4c83bb1 100644 --- a/eval.py +++ b/eval.py @@ -91,6 +91,10 @@ def get_compile_warmup_cfg(cfg): "num_eval": 1, } cfg_warmup = cfg.get("compile_warmup") + if cfg_warmup is None: + cfg_eval = cfg.get("eval") + if cfg_eval is not None: + cfg_warmup = cfg_eval.get("compile_warmup") if cfg_warmup is not None: warmup_cfg.update(OmegaConf.to_container(cfg_warmup, resolve=True)) return warmup_cfg @@ -464,6 +468,7 @@ def run_eval_subset( *, device_override: str | None = None, enable_profile: bool = True, + enable_compile_warmup: bool = False, before_evaluate=None, ): local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False)) @@ -545,6 +550,8 @@ def run_eval_subset( with get_eval_grad_context(solver): with profiler_ctx as profiler: with inference_ctx: + if enable_compile_warmup: + maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx) with torch.profiler.record_function("eval.world_evaluate_from_dataset"): metrics = evaluate_subset(eval_episodes, eval_start_idx) if str(device).startswith("cuda") and torch.cuda.is_available(): @@ -596,8 +603,8 @@ def maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx): with tempfile.TemporaryDirectory(prefix="lewm_compile_warmup_") as tmpdir: run_eval_subset( warmup_eval_cfg, - eval_episodes[:warmup_count].tolist(), - eval_start_idx[:warmup_count].tolist(), + list(eval_episodes[:warmup_count]), + list(eval_start_idx[:warmup_count]), Path(tmpdir), device_override=device_override, enable_profile=False,