没啥用

This commit is contained in:
qihuanye
2026-05-18 02:09:19 +08:00
parent 28f2fba0e8
commit a639fdefca
2 changed files with 65 additions and 7 deletions

View File

@@ -1,5 +1,6 @@
from collections import deque
from dataclasses import dataclass
import inspect
from pathlib import Path
from typing import Any, Protocol
from collections.abc import Callable
@@ -130,7 +131,37 @@ class BasePolicy:
info_dict[k] = v
return info_dict
def _prepare_info(self, info_dict: dict) -> dict[str, torch.Tensor]:
def _prepare_info_for_device(
self,
info_dict: dict,
device: torch.device | str | None = None,
) -> dict[str, torch.Tensor]:
if device is None:
return self._prepare_info(info_dict)
prepare_info = self._prepare_info
try:
signature = inspect.signature(prepare_info)
except (TypeError, ValueError):
return prepare_info(info_dict)
accepts_device = any(
param.kind == inspect.Parameter.VAR_KEYWORD
or (param.name == "device" and param.kind in {
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
})
for param in signature.parameters.values()
)
if accepts_device:
return prepare_info(info_dict, device=device)
return prepare_info(info_dict)
def _prepare_info(
self,
info_dict: dict,
device: torch.device | str | None = None,
) -> dict[str, torch.Tensor]:
"""Pre-process and transform observations.
Applies preprocessing (via `self.process`) and transformations (via `self.transform`)
@@ -145,6 +176,7 @@ class BasePolicy:
Raises:
ValueError: If an expected numpy array is missing for processing.
"""
target_device = torch.device(device) if device is not None else None
for k, v in info_dict.items():
is_numpy = isinstance(v, (np.ndarray | np.generic))
@@ -178,10 +210,20 @@ class BasePolicy:
)
v = torch.from_numpy(v)
is_numpy = False
moved_for_transform = False
if target_device is not None and target_device.type != "cpu":
if v.device != target_device:
v = v.to(target_device, non_blocking=True)
moved_for_transform = True
if k.startswith('pixels') or k.startswith('goal'):
# Vectorized image transform on the full batch.
v = v.permute(0, 3, 1, 2)
v = self.transform[k](v)
try:
v = self.transform[k](v)
except (NotImplementedError, RuntimeError):
if not moved_for_transform:
raise
v = self.transform[k](v.cpu())
if shape is not None:
v = v.reshape(*shape[:2], *v.shape[1:])
@@ -189,7 +231,12 @@ class BasePolicy:
if is_numpy and v.dtype.kind not in 'USO':
v = torch.from_numpy(v)
if torch.is_tensor(v) and v.device.type == "cpu" and not v.is_pinned():
if (
torch.cuda.is_available()
and torch.is_tensor(v)
and v.device.type == "cpu"
and not v.is_pinned()
):
v = v.pin_memory()
info_dict[k] = v
@@ -313,7 +360,9 @@ class FeedForwardPolicy(BasePolicy):
assert 'goal' in info_dict, "'goal' must be provided in info_dict"
# Prepare the info dict (transforms and normalizes inputs)
info_dict = self._prepare_info(info_dict)
info_dict = self._prepare_info_for_device(
info_dict, device=next(self.model.parameters()).device
)
# Add goal_pixels key for GCBC model
if 'goal' in info_dict:
@@ -422,7 +471,9 @@ class WorldModelPolicy(BasePolicy):
assert 'pixels' in info_dict, "'pixels' must be provided in info_dict"
assert 'goal' in info_dict, "'goal' must be provided in info_dict"
info_dict = self._prepare_info(info_dict)
info_dict = self._prepare_info_for_device(
info_dict, device=self.solver.device
)
info_dict = self._move_info_to_device(info_dict, self.solver.device)
outputs = self.solver(

11
eval.py
View File

@@ -91,6 +91,10 @@ def get_compile_warmup_cfg(cfg):
"num_eval": 1,
}
cfg_warmup = cfg.get("compile_warmup")
if cfg_warmup is None:
cfg_eval = cfg.get("eval")
if cfg_eval is not None:
cfg_warmup = cfg_eval.get("compile_warmup")
if cfg_warmup is not None:
warmup_cfg.update(OmegaConf.to_container(cfg_warmup, resolve=True))
return warmup_cfg
@@ -464,6 +468,7 @@ def run_eval_subset(
*,
device_override: str | None = None,
enable_profile: bool = True,
enable_compile_warmup: bool = False,
before_evaluate=None,
):
local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
@@ -545,6 +550,8 @@ def run_eval_subset(
with get_eval_grad_context(solver):
with profiler_ctx as profiler:
with inference_ctx:
if enable_compile_warmup:
maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx)
with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
metrics = evaluate_subset(eval_episodes, eval_start_idx)
if str(device).startswith("cuda") and torch.cuda.is_available():
@@ -596,8 +603,8 @@ def maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx):
with tempfile.TemporaryDirectory(prefix="lewm_compile_warmup_") as tmpdir:
run_eval_subset(
warmup_eval_cfg,
eval_episodes[:warmup_count].tolist(),
eval_start_idx[:warmup_count].tolist(),
list(eval_episodes[:warmup_count]),
list(eval_start_idx[:warmup_count]),
Path(tmpdir),
device_override=device_override,
enable_profile=False,