没啥用
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
import inspect
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Protocol
|
from typing import Any, Protocol
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
@@ -130,7 +131,37 @@ class BasePolicy:
|
|||||||
info_dict[k] = v
|
info_dict[k] = v
|
||||||
return info_dict
|
return info_dict
|
||||||
|
|
||||||
def _prepare_info(self, info_dict: dict) -> dict[str, torch.Tensor]:
|
def _prepare_info_for_device(
|
||||||
|
self,
|
||||||
|
info_dict: dict,
|
||||||
|
device: torch.device | str | None = None,
|
||||||
|
) -> dict[str, torch.Tensor]:
|
||||||
|
if device is None:
|
||||||
|
return self._prepare_info(info_dict)
|
||||||
|
|
||||||
|
prepare_info = self._prepare_info
|
||||||
|
try:
|
||||||
|
signature = inspect.signature(prepare_info)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return prepare_info(info_dict)
|
||||||
|
|
||||||
|
accepts_device = any(
|
||||||
|
param.kind == inspect.Parameter.VAR_KEYWORD
|
||||||
|
or (param.name == "device" and param.kind in {
|
||||||
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||||
|
inspect.Parameter.KEYWORD_ONLY,
|
||||||
|
})
|
||||||
|
for param in signature.parameters.values()
|
||||||
|
)
|
||||||
|
if accepts_device:
|
||||||
|
return prepare_info(info_dict, device=device)
|
||||||
|
return prepare_info(info_dict)
|
||||||
|
|
||||||
|
def _prepare_info(
|
||||||
|
self,
|
||||||
|
info_dict: dict,
|
||||||
|
device: torch.device | str | None = None,
|
||||||
|
) -> dict[str, torch.Tensor]:
|
||||||
"""Pre-process and transform observations.
|
"""Pre-process and transform observations.
|
||||||
|
|
||||||
Applies preprocessing (via `self.process`) and transformations (via `self.transform`)
|
Applies preprocessing (via `self.process`) and transformations (via `self.transform`)
|
||||||
@@ -145,6 +176,7 @@ class BasePolicy:
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If an expected numpy array is missing for processing.
|
ValueError: If an expected numpy array is missing for processing.
|
||||||
"""
|
"""
|
||||||
|
target_device = torch.device(device) if device is not None else None
|
||||||
for k, v in info_dict.items():
|
for k, v in info_dict.items():
|
||||||
is_numpy = isinstance(v, (np.ndarray | np.generic))
|
is_numpy = isinstance(v, (np.ndarray | np.generic))
|
||||||
|
|
||||||
@@ -178,10 +210,20 @@ class BasePolicy:
|
|||||||
)
|
)
|
||||||
v = torch.from_numpy(v)
|
v = torch.from_numpy(v)
|
||||||
is_numpy = False
|
is_numpy = False
|
||||||
|
moved_for_transform = False
|
||||||
|
if target_device is not None and target_device.type != "cpu":
|
||||||
|
if v.device != target_device:
|
||||||
|
v = v.to(target_device, non_blocking=True)
|
||||||
|
moved_for_transform = True
|
||||||
if k.startswith('pixels') or k.startswith('goal'):
|
if k.startswith('pixels') or k.startswith('goal'):
|
||||||
# Vectorized image transform on the full batch.
|
# Vectorized image transform on the full batch.
|
||||||
v = v.permute(0, 3, 1, 2)
|
v = v.permute(0, 3, 1, 2)
|
||||||
v = self.transform[k](v)
|
try:
|
||||||
|
v = self.transform[k](v)
|
||||||
|
except (NotImplementedError, RuntimeError):
|
||||||
|
if not moved_for_transform:
|
||||||
|
raise
|
||||||
|
v = self.transform[k](v.cpu())
|
||||||
|
|
||||||
if shape is not None:
|
if shape is not None:
|
||||||
v = v.reshape(*shape[:2], *v.shape[1:])
|
v = v.reshape(*shape[:2], *v.shape[1:])
|
||||||
@@ -189,7 +231,12 @@ class BasePolicy:
|
|||||||
if is_numpy and v.dtype.kind not in 'USO':
|
if is_numpy and v.dtype.kind not in 'USO':
|
||||||
v = torch.from_numpy(v)
|
v = torch.from_numpy(v)
|
||||||
|
|
||||||
if torch.is_tensor(v) and v.device.type == "cpu" and not v.is_pinned():
|
if (
|
||||||
|
torch.cuda.is_available()
|
||||||
|
and torch.is_tensor(v)
|
||||||
|
and v.device.type == "cpu"
|
||||||
|
and not v.is_pinned()
|
||||||
|
):
|
||||||
v = v.pin_memory()
|
v = v.pin_memory()
|
||||||
|
|
||||||
info_dict[k] = v
|
info_dict[k] = v
|
||||||
@@ -313,7 +360,9 @@ class FeedForwardPolicy(BasePolicy):
|
|||||||
assert 'goal' in info_dict, "'goal' must be provided in info_dict"
|
assert 'goal' in info_dict, "'goal' must be provided in info_dict"
|
||||||
|
|
||||||
# Prepare the info dict (transforms and normalizes inputs)
|
# Prepare the info dict (transforms and normalizes inputs)
|
||||||
info_dict = self._prepare_info(info_dict)
|
info_dict = self._prepare_info_for_device(
|
||||||
|
info_dict, device=next(self.model.parameters()).device
|
||||||
|
)
|
||||||
|
|
||||||
# Add goal_pixels key for GCBC model
|
# Add goal_pixels key for GCBC model
|
||||||
if 'goal' in info_dict:
|
if 'goal' in info_dict:
|
||||||
@@ -422,7 +471,9 @@ class WorldModelPolicy(BasePolicy):
|
|||||||
assert 'pixels' in info_dict, "'pixels' must be provided in info_dict"
|
assert 'pixels' in info_dict, "'pixels' must be provided in info_dict"
|
||||||
assert 'goal' in info_dict, "'goal' must be provided in info_dict"
|
assert 'goal' in info_dict, "'goal' must be provided in info_dict"
|
||||||
|
|
||||||
info_dict = self._prepare_info(info_dict)
|
info_dict = self._prepare_info_for_device(
|
||||||
|
info_dict, device=self.solver.device
|
||||||
|
)
|
||||||
info_dict = self._move_info_to_device(info_dict, self.solver.device)
|
info_dict = self._move_info_to_device(info_dict, self.solver.device)
|
||||||
|
|
||||||
outputs = self.solver(
|
outputs = self.solver(
|
||||||
|
|||||||
11
eval.py
11
eval.py
@@ -91,6 +91,10 @@ def get_compile_warmup_cfg(cfg):
|
|||||||
"num_eval": 1,
|
"num_eval": 1,
|
||||||
}
|
}
|
||||||
cfg_warmup = cfg.get("compile_warmup")
|
cfg_warmup = cfg.get("compile_warmup")
|
||||||
|
if cfg_warmup is None:
|
||||||
|
cfg_eval = cfg.get("eval")
|
||||||
|
if cfg_eval is not None:
|
||||||
|
cfg_warmup = cfg_eval.get("compile_warmup")
|
||||||
if cfg_warmup is not None:
|
if cfg_warmup is not None:
|
||||||
warmup_cfg.update(OmegaConf.to_container(cfg_warmup, resolve=True))
|
warmup_cfg.update(OmegaConf.to_container(cfg_warmup, resolve=True))
|
||||||
return warmup_cfg
|
return warmup_cfg
|
||||||
@@ -464,6 +468,7 @@ def run_eval_subset(
|
|||||||
*,
|
*,
|
||||||
device_override: str | None = None,
|
device_override: str | None = None,
|
||||||
enable_profile: bool = True,
|
enable_profile: bool = True,
|
||||||
|
enable_compile_warmup: bool = False,
|
||||||
before_evaluate=None,
|
before_evaluate=None,
|
||||||
):
|
):
|
||||||
local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
|
local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
|
||||||
@@ -545,6 +550,8 @@ def run_eval_subset(
|
|||||||
with get_eval_grad_context(solver):
|
with get_eval_grad_context(solver):
|
||||||
with profiler_ctx as profiler:
|
with profiler_ctx as profiler:
|
||||||
with inference_ctx:
|
with inference_ctx:
|
||||||
|
if enable_compile_warmup:
|
||||||
|
maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx)
|
||||||
with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
|
with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
|
||||||
metrics = evaluate_subset(eval_episodes, eval_start_idx)
|
metrics = evaluate_subset(eval_episodes, eval_start_idx)
|
||||||
if str(device).startswith("cuda") and torch.cuda.is_available():
|
if str(device).startswith("cuda") and torch.cuda.is_available():
|
||||||
@@ -596,8 +603,8 @@ def maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx):
|
|||||||
with tempfile.TemporaryDirectory(prefix="lewm_compile_warmup_") as tmpdir:
|
with tempfile.TemporaryDirectory(prefix="lewm_compile_warmup_") as tmpdir:
|
||||||
run_eval_subset(
|
run_eval_subset(
|
||||||
warmup_eval_cfg,
|
warmup_eval_cfg,
|
||||||
eval_episodes[:warmup_count].tolist(),
|
list(eval_episodes[:warmup_count]),
|
||||||
eval_start_idx[:warmup_count].tolist(),
|
list(eval_start_idx[:warmup_count]),
|
||||||
Path(tmpdir),
|
Path(tmpdir),
|
||||||
device_override=device_override,
|
device_override=device_override,
|
||||||
enable_profile=False,
|
enable_profile=False,
|
||||||
|
|||||||
Reference in New Issue
Block a user