Optimize JEPA eval outputs and inference hot path
This commit is contained in:
20
eval.py
20
eval.py
@@ -129,14 +129,13 @@ def dump_profiler_results(profiler, profile_dir, profile_cfg):
|
|||||||
summary_path = profile_dir / "key_averages.txt"
|
summary_path = profile_dir / "key_averages.txt"
|
||||||
summary_path.write_text(table)
|
summary_path.write_text(table)
|
||||||
|
|
||||||
if profile_cfg["export_chrome_trace"]:
|
|
||||||
profiler.export_chrome_trace(str(profile_dir / "trace.json"))
|
|
||||||
|
|
||||||
if profile_cfg["export_tensorboard"]:
|
if profile_cfg["export_tensorboard"]:
|
||||||
trace_handler = torch.profiler.tensorboard_trace_handler(
|
trace_handler = torch.profiler.tensorboard_trace_handler(
|
||||||
str(profile_dir), worker_name=profile_cfg["worker_name"]
|
str(profile_dir), worker_name=profile_cfg["worker_name"]
|
||||||
)
|
)
|
||||||
trace_handler(profiler)
|
trace_handler(profiler)
|
||||||
|
elif profile_cfg["export_chrome_trace"]:
|
||||||
|
profiler.export_chrome_trace(str(profile_dir / "trace.json"))
|
||||||
|
|
||||||
return summary_path
|
return summary_path
|
||||||
|
|
||||||
@@ -198,12 +197,11 @@ def run(cfg: DictConfig):
|
|||||||
inference_ctx = nullcontext()
|
inference_ctx = nullcontext()
|
||||||
inference_precision = "fp32"
|
inference_precision = "fp32"
|
||||||
|
|
||||||
results_path = (
|
# Hydra switches the working directory to the per-run outputs folder.
|
||||||
Path(swm.data.utils.get_cache_dir(), cfg.policy).parent
|
# Keep all generated artifacts with that run instead of scattering them
|
||||||
if cfg.policy != "random"
|
# next to the cache or source tree.
|
||||||
else Path(__file__).parent
|
output_dir = Path.cwd().resolve()
|
||||||
)
|
profiler_ctx, profile_dir, profile_cfg = make_profiler(cfg, output_dir)
|
||||||
profiler_ctx, profile_dir, profile_cfg = make_profiler(cfg, results_path)
|
|
||||||
|
|
||||||
# sample the episodes and the starting indices
|
# sample the episodes and the starting indices
|
||||||
episode_len = get_episodes_length(dataset, ep_indices)
|
episode_len = get_episodes_length(dataset, ep_indices)
|
||||||
@@ -251,7 +249,7 @@ def run(cfg: DictConfig):
|
|||||||
eval_budget=cfg.eval.eval_budget,
|
eval_budget=cfg.eval.eval_budget,
|
||||||
episodes_idx=eval_episodes.tolist(),
|
episodes_idx=eval_episodes.tolist(),
|
||||||
callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True),
|
callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True),
|
||||||
video_path=results_path,
|
video_path=output_dir,
|
||||||
)
|
)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
@@ -260,7 +258,7 @@ def run(cfg: DictConfig):
|
|||||||
|
|
||||||
print(metrics)
|
print(metrics)
|
||||||
|
|
||||||
results_path = results_path / cfg.output.filename
|
results_path = output_dir / cfg.output.filename
|
||||||
results_path.parent.mkdir(parents=True, exist_ok=True)
|
results_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
with results_path.open("a") as f:
|
with results_path.open("a") as f:
|
||||||
|
|||||||
147
jepa.py
147
jepa.py
@@ -5,9 +5,6 @@ import torch.nn.functional as F
|
|||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
def detach_clone(v):
    """Return a detached copy of ``v`` when it is a tensor, else ``v`` itself.

    Non-tensor values are handed back unchanged (same object), so the helper
    can be mapped over dictionaries that mix tensors with other entries.
    """
    if not torch.is_tensor(v):
        return v
    # detach() drops the autograd link; clone() gives the copy its own storage.
    return v.detach().clone()
|
|
||||||
|
|
||||||
class JEPA(nn.Module):
|
class JEPA(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -25,6 +22,76 @@ class JEPA(nn.Module):
|
|||||||
self.action_encoder = action_encoder
|
self.action_encoder = action_encoder
|
||||||
self.projector = projector or nn.Identity()
|
self.projector = projector or nn.Identity()
|
||||||
self.pred_proj = pred_proj or nn.Identity()
|
self.pred_proj = pred_proj or nn.Identity()
|
||||||
|
self._cached_device_tensors = {}
|
||||||
|
self._cached_init_signature = None
|
||||||
|
self._cached_init_emb = None
|
||||||
|
self._cached_goal_signature = None
|
||||||
|
self._cached_goal_emb = None
|
||||||
|
|
||||||
|
def _ensure_runtime_caches(self):
|
||||||
|
if not hasattr(self, "_cached_device_tensors"):
|
||||||
|
self._cached_device_tensors = {}
|
||||||
|
if not hasattr(self, "_cached_init_signature"):
|
||||||
|
self._cached_init_signature = None
|
||||||
|
if not hasattr(self, "_cached_init_emb"):
|
||||||
|
self._cached_init_emb = None
|
||||||
|
if not hasattr(self, "_cached_goal_signature"):
|
||||||
|
self._cached_goal_signature = None
|
||||||
|
if not hasattr(self, "_cached_goal_emb"):
|
||||||
|
self._cached_goal_emb = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _tensor_signature(tensor: torch.Tensor):
|
||||||
|
try:
|
||||||
|
version = tensor._version
|
||||||
|
except RuntimeError:
|
||||||
|
version = None
|
||||||
|
return (
|
||||||
|
str(tensor.device),
|
||||||
|
tensor.dtype,
|
||||||
|
tuple(tensor.shape),
|
||||||
|
tensor.data_ptr(),
|
||||||
|
version,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_cached_device_tensor(self, key: str, tensor: torch.Tensor, device: torch.device):
|
||||||
|
self._ensure_runtime_caches()
|
||||||
|
signature = (self._tensor_signature(tensor), str(device))
|
||||||
|
cached = self._cached_device_tensors.get(key)
|
||||||
|
if cached is None or cached[0] != signature:
|
||||||
|
self._cached_device_tensors[key] = (
|
||||||
|
signature,
|
||||||
|
tensor.to(device, non_blocking=True),
|
||||||
|
)
|
||||||
|
return self._cached_device_tensors[key][1]
|
||||||
|
|
||||||
|
def _ensure_info_device(self, info_dict: dict, device: torch.device):
|
||||||
|
for key, value in list(info_dict.items()):
|
||||||
|
if key.startswith("_lewm_"):
|
||||||
|
continue
|
||||||
|
if torch.is_tensor(value) and value.device != device:
|
||||||
|
info_dict[key] = self._get_cached_device_tensor(key, value, device)
|
||||||
|
return info_dict
|
||||||
|
|
||||||
|
def _get_cached_init_emb(self, info_dict: dict):
|
||||||
|
self._ensure_runtime_caches()
|
||||||
|
pixels = info_dict["pixels"]
|
||||||
|
signature = self._tensor_signature(pixels)
|
||||||
|
if self._cached_init_signature != signature:
|
||||||
|
init_info = {"pixels": pixels[:, 0]}
|
||||||
|
self._cached_init_emb = self.encode(init_info)["emb"].detach()
|
||||||
|
self._cached_init_signature = signature
|
||||||
|
return self._cached_init_emb
|
||||||
|
|
||||||
|
def _get_cached_goal_emb(self, info_dict: dict):
|
||||||
|
self._ensure_runtime_caches()
|
||||||
|
goal = info_dict["goal"]
|
||||||
|
signature = self._tensor_signature(goal)
|
||||||
|
if self._cached_goal_signature != signature:
|
||||||
|
goal_info = {"pixels": goal[:, 0]}
|
||||||
|
self._cached_goal_emb = self.encode(goal_info)["emb"][:, -1:, :].detach()
|
||||||
|
self._cached_goal_signature = signature
|
||||||
|
return self._cached_goal_emb
|
||||||
|
|
||||||
def encode(self, info):
|
def encode(self, info):
|
||||||
"""Encode observations and actions into embeddings.
|
"""Encode observations and actions into embeddings.
|
||||||
@@ -71,42 +138,33 @@ class JEPA(nn.Module):
|
|||||||
H = info["pixels"].size(2)
|
H = info["pixels"].size(2)
|
||||||
B, S, T = action_sequence.shape[:3]
|
B, S, T = action_sequence.shape[:3]
|
||||||
act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2)
|
act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2)
|
||||||
info["action"] = act_0
|
|
||||||
n_steps = T - H
|
|
||||||
|
|
||||||
# copy and encode initial info dict
|
# Cache the encoded initial state across solver iterations.
|
||||||
_init = {k: v[:, 0] for k, v in info.items() if torch.is_tensor(v)}
|
init_emb = self._get_cached_init_emb(info)
|
||||||
_init = self.encode(_init)
|
HS = history_size
|
||||||
emb = info["emb"] = _init["emb"].unsqueeze(1).expand(B, S, -1, -1)
|
emb_hist = init_emb.unsqueeze(1).expand(B, S, -1, -1)
|
||||||
_init = {k: detach_clone(v) for k, v in _init.items()}
|
emb_hist = rearrange(emb_hist[..., -HS:, :], "b s ... -> (b s) ...")
|
||||||
|
|
||||||
# flatten batch and sample dimensions for rollout
|
act_hist = rearrange(act_0[..., -HS:, :], "b s ... -> (b s) ...")
|
||||||
emb = rearrange(emb, "b s ... -> (b s) ...").clone()
|
act_emb_hist = self.action_encoder(act_hist)
|
||||||
act = rearrange(act_0, "b s ... -> (b s) ...")
|
|
||||||
act_future = rearrange(act_future, "b s ... -> (b s) ...")
|
act_future = rearrange(act_future, "b s ... -> (b s) ...")
|
||||||
|
|
||||||
# rollout predictor autoregressively for n_steps
|
for t in range(act_future.size(1)):
|
||||||
HS = history_size
|
pred_emb = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:]
|
||||||
for t in range(n_steps):
|
if HS > 1:
|
||||||
act_emb = self.action_encoder(act)
|
emb_hist = torch.cat([emb_hist[:, -HS + 1 :], pred_emb], dim=1)
|
||||||
emb_trunc = emb[:, -HS:] # (BS, HS, D)
|
else:
|
||||||
act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb)
|
emb_hist = pred_emb
|
||||||
pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D)
|
|
||||||
emb = torch.cat([emb, pred_emb], dim=1) # (BS, T+1, D)
|
|
||||||
|
|
||||||
next_act = act_future[:, t : t + 1, :] # (BS, 1, action_dim)
|
next_act = act_future[:, t : t + 1, :]
|
||||||
act = torch.cat([act, next_act], dim=1) # (BS, T+1, action_dim)
|
next_act_emb = self.action_encoder(next_act)
|
||||||
|
if HS > 1:
|
||||||
|
act_emb_hist = torch.cat([act_emb_hist[:, -HS + 1 :], next_act_emb], dim=1)
|
||||||
|
else:
|
||||||
|
act_emb_hist = next_act_emb
|
||||||
|
|
||||||
# predict the last state
|
pred_rollout = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:]
|
||||||
act_emb = self.action_encoder(act) # (BS, T, A_emb)
|
info["predicted_emb"] = rearrange(pred_rollout, "(b s) ... -> b s ...", b=B, s=S)
|
||||||
emb_trunc = emb[:, -HS:] # (BS, HS, D)
|
|
||||||
act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb)
|
|
||||||
pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D)
|
|
||||||
emb = torch.cat([emb, pred_emb], dim=1)
|
|
||||||
|
|
||||||
# unflatten batch and sample dimensions
|
|
||||||
pred_rollout = rearrange(emb, "(b s) ... -> b s ...", b=B, s=S)
|
|
||||||
info["predicted_emb"] = pred_rollout
|
|
||||||
|
|
||||||
return info
|
return info
|
||||||
|
|
||||||
@@ -115,8 +173,8 @@ class JEPA(nn.Module):
|
|||||||
with torch.profiler.record_function("lewm.criterion"):
|
with torch.profiler.record_function("lewm.criterion"):
|
||||||
pred_emb = info_dict["predicted_emb"] # (B,S, T-1, dim)
|
pred_emb = info_dict["predicted_emb"] # (B,S, T-1, dim)
|
||||||
goal_emb = info_dict["goal_emb"] # (B, S, T, dim)
|
goal_emb = info_dict["goal_emb"] # (B, S, T, dim)
|
||||||
|
if goal_emb.ndim == pred_emb.ndim - 1:
|
||||||
goal_emb = goal_emb[..., -1:, :].expand_as(pred_emb)
|
goal_emb = goal_emb.unsqueeze(1)
|
||||||
|
|
||||||
# return last-step cost per action candidate
|
# return last-step cost per action candidate
|
||||||
cost = F.mse_loss(
|
cost = F.mse_loss(
|
||||||
@@ -132,22 +190,13 @@ class JEPA(nn.Module):
|
|||||||
with torch.profiler.record_function("lewm.get_cost"):
|
with torch.profiler.record_function("lewm.get_cost"):
|
||||||
assert "goal" in info_dict, "goal not in info_dict"
|
assert "goal" in info_dict, "goal not in info_dict"
|
||||||
|
|
||||||
|
self._ensure_runtime_caches()
|
||||||
device = next(self.parameters()).device
|
device = next(self.parameters()).device
|
||||||
for k in list(info_dict.keys()):
|
info_dict = self._ensure_info_device(info_dict, device)
|
||||||
if torch.is_tensor(info_dict[k]):
|
if action_candidates.device != device:
|
||||||
info_dict[k] = info_dict[k].to(device)
|
action_candidates = action_candidates.to(device, non_blocking=True)
|
||||||
|
|
||||||
goal = {k: v[:, 0] for k, v in info_dict.items() if torch.is_tensor(v)}
|
info_dict["goal_emb"] = self._get_cached_goal_emb(info_dict)
|
||||||
goal["pixels"] = goal["goal"]
|
|
||||||
|
|
||||||
for k in info_dict:
|
|
||||||
if k.startswith("goal_"):
|
|
||||||
goal[k[len("goal_") :]] = goal.pop(k)
|
|
||||||
|
|
||||||
goal.pop("action")
|
|
||||||
goal = self.encode(goal)
|
|
||||||
|
|
||||||
info_dict["goal_emb"] = goal["emb"]
|
|
||||||
info_dict = self.rollout(info_dict, action_candidates)
|
info_dict = self.rollout(info_dict, action_candidates)
|
||||||
|
|
||||||
cost = self.criterion(info_dict)
|
cost = self.criterion(info_dict)
|
||||||
|
|||||||
57
tworoom_results.txt
Normal file
57
tworoom_results.txt
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
|
||||||
|
==== CONFIG ====
|
||||||
|
cache_dir: null
|
||||||
|
solver:
|
||||||
|
_target_: stable_worldmodel.solver.CEMSolver
|
||||||
|
model: ???
|
||||||
|
batch_size: 1
|
||||||
|
num_samples: 300
|
||||||
|
var_scale: 1.0
|
||||||
|
n_steps: 30
|
||||||
|
topk: 30
|
||||||
|
device: cuda
|
||||||
|
seed: ${seed}
|
||||||
|
world:
|
||||||
|
env_name: swm/TwoRoom-v1
|
||||||
|
num_envs: ${eval.num_eval}
|
||||||
|
max_episode_steps: 100
|
||||||
|
history_size: 1
|
||||||
|
frame_skip: 1
|
||||||
|
seed: 42
|
||||||
|
policy: two-room/tworoom/lejepa
|
||||||
|
dataset:
|
||||||
|
stats: ${eval.dataset_name}
|
||||||
|
keys_to_cache:
|
||||||
|
- action
|
||||||
|
- proprio
|
||||||
|
plan_config:
|
||||||
|
horizon: 5
|
||||||
|
receding_horizon: 5
|
||||||
|
action_block: 5
|
||||||
|
eval:
|
||||||
|
num_eval: 50
|
||||||
|
goal_offset_steps: 25
|
||||||
|
eval_budget: 50
|
||||||
|
img_size: 224
|
||||||
|
dataset_name: tworoom
|
||||||
|
callables:
|
||||||
|
- method: _set_state
|
||||||
|
args:
|
||||||
|
state:
|
||||||
|
value: proprio
|
||||||
|
- method: _set_goal_state
|
||||||
|
args:
|
||||||
|
goal_state:
|
||||||
|
value: goal_proprio
|
||||||
|
output:
|
||||||
|
filename: tworoom_results.txt
|
||||||
|
|
||||||
|
==== RESULTS ====
|
||||||
|
metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
|
||||||
|
True, True, True, True, True, True, True, True, True,
|
||||||
|
True, True, True, False, True, True, True, True, True,
|
||||||
|
True, True, True, True, False, True, True, True, True,
|
||||||
|
True, True, False, True, True, True, True, True, True,
|
||||||
|
True, True, True, True, True]), 'seeds': None}
|
||||||
|
evaluation_time: 133.1857841014862 seconds
|
||||||
|
inference_precision: fp32
|
||||||
Reference in New Issue
Block a user