没啥用
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
from collections.abc import Callable
|
||||
@@ -130,7 +131,37 @@ class BasePolicy:
|
||||
info_dict[k] = v
|
||||
return info_dict
|
||||
|
||||
def _prepare_info(self, info_dict: dict) -> dict[str, torch.Tensor]:
|
||||
def _prepare_info_for_device(
|
||||
self,
|
||||
info_dict: dict,
|
||||
device: torch.device | str | None = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
if device is None:
|
||||
return self._prepare_info(info_dict)
|
||||
|
||||
prepare_info = self._prepare_info
|
||||
try:
|
||||
signature = inspect.signature(prepare_info)
|
||||
except (TypeError, ValueError):
|
||||
return prepare_info(info_dict)
|
||||
|
||||
accepts_device = any(
|
||||
param.kind == inspect.Parameter.VAR_KEYWORD
|
||||
or (param.name == "device" and param.kind in {
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
inspect.Parameter.KEYWORD_ONLY,
|
||||
})
|
||||
for param in signature.parameters.values()
|
||||
)
|
||||
if accepts_device:
|
||||
return prepare_info(info_dict, device=device)
|
||||
return prepare_info(info_dict)
|
||||
|
||||
def _prepare_info(
|
||||
self,
|
||||
info_dict: dict,
|
||||
device: torch.device | str | None = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Pre-process and transform observations.
|
||||
|
||||
Applies preprocessing (via `self.process`) and transformations (via `self.transform`)
|
||||
@@ -145,6 +176,7 @@ class BasePolicy:
|
||||
Raises:
|
||||
ValueError: If an expected numpy array is missing for processing.
|
||||
"""
|
||||
target_device = torch.device(device) if device is not None else None
|
||||
for k, v in info_dict.items():
|
||||
is_numpy = isinstance(v, (np.ndarray | np.generic))
|
||||
|
||||
@@ -178,10 +210,20 @@ class BasePolicy:
|
||||
)
|
||||
v = torch.from_numpy(v)
|
||||
is_numpy = False
|
||||
moved_for_transform = False
|
||||
if target_device is not None and target_device.type != "cpu":
|
||||
if v.device != target_device:
|
||||
v = v.to(target_device, non_blocking=True)
|
||||
moved_for_transform = True
|
||||
if k.startswith('pixels') or k.startswith('goal'):
|
||||
# Vectorized image transform on the full batch.
|
||||
v = v.permute(0, 3, 1, 2)
|
||||
v = self.transform[k](v)
|
||||
try:
|
||||
v = self.transform[k](v)
|
||||
except (NotImplementedError, RuntimeError):
|
||||
if not moved_for_transform:
|
||||
raise
|
||||
v = self.transform[k](v.cpu())
|
||||
|
||||
if shape is not None:
|
||||
v = v.reshape(*shape[:2], *v.shape[1:])
|
||||
@@ -189,7 +231,12 @@ class BasePolicy:
|
||||
if is_numpy and v.dtype.kind not in 'USO':
|
||||
v = torch.from_numpy(v)
|
||||
|
||||
if torch.is_tensor(v) and v.device.type == "cpu" and not v.is_pinned():
|
||||
if (
|
||||
torch.cuda.is_available()
|
||||
and torch.is_tensor(v)
|
||||
and v.device.type == "cpu"
|
||||
and not v.is_pinned()
|
||||
):
|
||||
v = v.pin_memory()
|
||||
|
||||
info_dict[k] = v
|
||||
@@ -313,7 +360,9 @@ class FeedForwardPolicy(BasePolicy):
|
||||
assert 'goal' in info_dict, "'goal' must be provided in info_dict"
|
||||
|
||||
# Prepare the info dict (transforms and normalizes inputs)
|
||||
info_dict = self._prepare_info(info_dict)
|
||||
info_dict = self._prepare_info_for_device(
|
||||
info_dict, device=next(self.model.parameters()).device
|
||||
)
|
||||
|
||||
# Add goal_pixels key for GCBC model
|
||||
if 'goal' in info_dict:
|
||||
@@ -422,7 +471,9 @@ class WorldModelPolicy(BasePolicy):
|
||||
assert 'pixels' in info_dict, "'pixels' must be provided in info_dict"
|
||||
assert 'goal' in info_dict, "'goal' must be provided in info_dict"
|
||||
|
||||
info_dict = self._prepare_info(info_dict)
|
||||
info_dict = self._prepare_info_for_device(
|
||||
info_dict, device=self.solver.device
|
||||
)
|
||||
info_dict = self._move_info_to_device(info_dict, self.solver.device)
|
||||
|
||||
outputs = self.solver(
|
||||
|
||||
11
eval.py
11
eval.py
@@ -91,6 +91,10 @@ def get_compile_warmup_cfg(cfg):
|
||||
"num_eval": 1,
|
||||
}
|
||||
cfg_warmup = cfg.get("compile_warmup")
|
||||
if cfg_warmup is None:
|
||||
cfg_eval = cfg.get("eval")
|
||||
if cfg_eval is not None:
|
||||
cfg_warmup = cfg_eval.get("compile_warmup")
|
||||
if cfg_warmup is not None:
|
||||
warmup_cfg.update(OmegaConf.to_container(cfg_warmup, resolve=True))
|
||||
return warmup_cfg
|
||||
@@ -464,6 +468,7 @@ def run_eval_subset(
|
||||
*,
|
||||
device_override: str | None = None,
|
||||
enable_profile: bool = True,
|
||||
enable_compile_warmup: bool = False,
|
||||
before_evaluate=None,
|
||||
):
|
||||
local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
|
||||
@@ -545,6 +550,8 @@ def run_eval_subset(
|
||||
with get_eval_grad_context(solver):
|
||||
with profiler_ctx as profiler:
|
||||
with inference_ctx:
|
||||
if enable_compile_warmup:
|
||||
maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx)
|
||||
with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
|
||||
metrics = evaluate_subset(eval_episodes, eval_start_idx)
|
||||
if str(device).startswith("cuda") and torch.cuda.is_available():
|
||||
@@ -596,8 +603,8 @@ def maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx):
|
||||
with tempfile.TemporaryDirectory(prefix="lewm_compile_warmup_") as tmpdir:
|
||||
run_eval_subset(
|
||||
warmup_eval_cfg,
|
||||
eval_episodes[:warmup_count].tolist(),
|
||||
eval_start_idx[:warmup_count].tolist(),
|
||||
list(eval_episodes[:warmup_count]),
|
||||
list(eval_start_idx[:warmup_count]),
|
||||
Path(tmpdir),
|
||||
device_override=device_override,
|
||||
enable_profile=False,
|
||||
|
||||
Reference in New Issue
Block a user