没啥用

This commit is contained in:
qihuanye
2026-05-18 02:09:19 +08:00
parent 28f2fba0e8
commit a639fdefca
2 changed files with 65 additions and 7 deletions

View File

@@ -1,5 +1,6 @@
from collections import deque from collections import deque
from dataclasses import dataclass from dataclasses import dataclass
import inspect
from pathlib import Path from pathlib import Path
from typing import Any, Protocol from typing import Any, Protocol
from collections.abc import Callable from collections.abc import Callable
@@ -130,7 +131,37 @@ class BasePolicy:
info_dict[k] = v info_dict[k] = v
return info_dict return info_dict
def _prepare_info(self, info_dict: dict) -> dict[str, torch.Tensor]: def _prepare_info_for_device(
self,
info_dict: dict,
device: torch.device | str | None = None,
) -> dict[str, torch.Tensor]:
if device is None:
return self._prepare_info(info_dict)
prepare_info = self._prepare_info
try:
signature = inspect.signature(prepare_info)
except (TypeError, ValueError):
return prepare_info(info_dict)
accepts_device = any(
param.kind == inspect.Parameter.VAR_KEYWORD
or (param.name == "device" and param.kind in {
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
})
for param in signature.parameters.values()
)
if accepts_device:
return prepare_info(info_dict, device=device)
return prepare_info(info_dict)
def _prepare_info(
self,
info_dict: dict,
device: torch.device | str | None = None,
) -> dict[str, torch.Tensor]:
"""Pre-process and transform observations. """Pre-process and transform observations.
Applies preprocessing (via `self.process`) and transformations (via `self.transform`) Applies preprocessing (via `self.process`) and transformations (via `self.transform`)
@@ -145,6 +176,7 @@ class BasePolicy:
Raises: Raises:
ValueError: If an expected numpy array is missing for processing. ValueError: If an expected numpy array is missing for processing.
""" """
target_device = torch.device(device) if device is not None else None
for k, v in info_dict.items(): for k, v in info_dict.items():
is_numpy = isinstance(v, (np.ndarray | np.generic)) is_numpy = isinstance(v, (np.ndarray | np.generic))
@@ -178,10 +210,20 @@ class BasePolicy:
) )
v = torch.from_numpy(v) v = torch.from_numpy(v)
is_numpy = False is_numpy = False
moved_for_transform = False
if target_device is not None and target_device.type != "cpu":
if v.device != target_device:
v = v.to(target_device, non_blocking=True)
moved_for_transform = True
if k.startswith('pixels') or k.startswith('goal'): if k.startswith('pixels') or k.startswith('goal'):
# Vectorized image transform on the full batch. # Vectorized image transform on the full batch.
v = v.permute(0, 3, 1, 2) v = v.permute(0, 3, 1, 2)
v = self.transform[k](v) try:
v = self.transform[k](v)
except (NotImplementedError, RuntimeError):
if not moved_for_transform:
raise
v = self.transform[k](v.cpu())
if shape is not None: if shape is not None:
v = v.reshape(*shape[:2], *v.shape[1:]) v = v.reshape(*shape[:2], *v.shape[1:])
@@ -189,7 +231,12 @@ class BasePolicy:
if is_numpy and v.dtype.kind not in 'USO': if is_numpy and v.dtype.kind not in 'USO':
v = torch.from_numpy(v) v = torch.from_numpy(v)
if torch.is_tensor(v) and v.device.type == "cpu" and not v.is_pinned(): if (
torch.cuda.is_available()
and torch.is_tensor(v)
and v.device.type == "cpu"
and not v.is_pinned()
):
v = v.pin_memory() v = v.pin_memory()
info_dict[k] = v info_dict[k] = v
@@ -313,7 +360,9 @@ class FeedForwardPolicy(BasePolicy):
assert 'goal' in info_dict, "'goal' must be provided in info_dict" assert 'goal' in info_dict, "'goal' must be provided in info_dict"
# Prepare the info dict (transforms and normalizes inputs) # Prepare the info dict (transforms and normalizes inputs)
info_dict = self._prepare_info(info_dict) info_dict = self._prepare_info_for_device(
info_dict, device=next(self.model.parameters()).device
)
# Add goal_pixels key for GCBC model # Add goal_pixels key for GCBC model
if 'goal' in info_dict: if 'goal' in info_dict:
@@ -422,7 +471,9 @@ class WorldModelPolicy(BasePolicy):
assert 'pixels' in info_dict, "'pixels' must be provided in info_dict" assert 'pixels' in info_dict, "'pixels' must be provided in info_dict"
assert 'goal' in info_dict, "'goal' must be provided in info_dict" assert 'goal' in info_dict, "'goal' must be provided in info_dict"
info_dict = self._prepare_info(info_dict) info_dict = self._prepare_info_for_device(
info_dict, device=self.solver.device
)
info_dict = self._move_info_to_device(info_dict, self.solver.device) info_dict = self._move_info_to_device(info_dict, self.solver.device)
outputs = self.solver( outputs = self.solver(

11
eval.py
View File

@@ -91,6 +91,10 @@ def get_compile_warmup_cfg(cfg):
"num_eval": 1, "num_eval": 1,
} }
cfg_warmup = cfg.get("compile_warmup") cfg_warmup = cfg.get("compile_warmup")
if cfg_warmup is None:
cfg_eval = cfg.get("eval")
if cfg_eval is not None:
cfg_warmup = cfg_eval.get("compile_warmup")
if cfg_warmup is not None: if cfg_warmup is not None:
warmup_cfg.update(OmegaConf.to_container(cfg_warmup, resolve=True)) warmup_cfg.update(OmegaConf.to_container(cfg_warmup, resolve=True))
return warmup_cfg return warmup_cfg
@@ -464,6 +468,7 @@ def run_eval_subset(
*, *,
device_override: str | None = None, device_override: str | None = None,
enable_profile: bool = True, enable_profile: bool = True,
enable_compile_warmup: bool = False,
before_evaluate=None, before_evaluate=None,
): ):
local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False)) local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
@@ -545,6 +550,8 @@ def run_eval_subset(
with get_eval_grad_context(solver): with get_eval_grad_context(solver):
with profiler_ctx as profiler: with profiler_ctx as profiler:
with inference_ctx: with inference_ctx:
if enable_compile_warmup:
maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx)
with torch.profiler.record_function("eval.world_evaluate_from_dataset"): with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
metrics = evaluate_subset(eval_episodes, eval_start_idx) metrics = evaluate_subset(eval_episodes, eval_start_idx)
if str(device).startswith("cuda") and torch.cuda.is_available(): if str(device).startswith("cuda") and torch.cuda.is_available():
@@ -596,8 +603,8 @@ def maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx):
with tempfile.TemporaryDirectory(prefix="lewm_compile_warmup_") as tmpdir: with tempfile.TemporaryDirectory(prefix="lewm_compile_warmup_") as tmpdir:
run_eval_subset( run_eval_subset(
warmup_eval_cfg, warmup_eval_cfg,
eval_episodes[:warmup_count].tolist(), list(eval_episodes[:warmup_count]),
eval_start_idx[:warmup_count].tolist(), list(eval_start_idx[:warmup_count]),
Path(tmpdir), Path(tmpdir),
device_override=device_override, device_override=device_override,
enable_profile=False, enable_profile=False,