继续做了通用性能优化,重点从 jepa.py 热路径转到实际的 stable_worldmodel
solver/policy 边界:去掉 CEM 每轮 cpu().tolist() 和结果过早回 CPU,把 plan/warm-start 保持在 GPU,只在 env.step 前最后一步转成 numpy,同时补 了输入张量的 contiguous 处理;
This commit is contained in:
@@ -122,8 +122,12 @@ class BasePolicy:
|
||||
) -> dict[str, Any]:
|
||||
target = torch.device(device)
|
||||
for k, v in info_dict.items():
|
||||
if torch.is_tensor(v) and v.device != target:
|
||||
info_dict[k] = v.to(target, non_blocking=True)
|
||||
if torch.is_tensor(v):
|
||||
if v.device != target:
|
||||
v = v.to(target, non_blocking=True)
|
||||
if not v.is_contiguous():
|
||||
v = v.contiguous()
|
||||
info_dict[k] = v
|
||||
return info_dict
|
||||
|
||||
def _prepare_info(self, info_dict: dict) -> dict[str, torch.Tensor]:
|
||||
@@ -415,18 +419,21 @@ class WorldModelPolicy(BasePolicy):
|
||||
keep_horizon = self.cfg.receding_horizon
|
||||
plan = actions[:, :keep_horizon]
|
||||
rest = actions[:, keep_horizon:]
|
||||
self._next_init = rest if self.cfg.warm_start else None
|
||||
self._next_init = rest.contiguous() if self.cfg.warm_start else None
|
||||
|
||||
# frameskip back to timestep
|
||||
plan = plan.reshape(
|
||||
self.env.num_envs, self.flatten_receding_horizon, -1
|
||||
)
|
||||
).contiguous()
|
||||
|
||||
self._action_buffer.extend(plan.transpose(0, 1))
|
||||
self._action_buffer.extend(plan.transpose(0, 1).unbind(0))
|
||||
|
||||
action = self._action_buffer.popleft()
|
||||
action = action.reshape(*self.env.action_space.shape)
|
||||
action = action.numpy()
|
||||
if torch.is_tensor(action):
|
||||
action = action.detach().cpu().numpy()
|
||||
else:
|
||||
action = np.asarray(action)
|
||||
|
||||
# post-process action
|
||||
if 'action' in self.process:
|
||||
|
||||
@@ -80,14 +80,24 @@ class CEMSolver:
|
||||
self, actions: torch.Tensor | None = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Initialize the action distribution parameters (mean and variance)."""
|
||||
var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim])
|
||||
mean = torch.zeros([self.n_envs, 0, self.action_dim]) if actions is None else actions
|
||||
device = torch.device(self.device)
|
||||
var = self.var_scale * torch.ones(
|
||||
[self.n_envs, self.horizon, self.action_dim],
|
||||
device=device,
|
||||
)
|
||||
mean = (
|
||||
torch.zeros([self.n_envs, 0, self.action_dim], device=device)
|
||||
if actions is None
|
||||
else actions
|
||||
)
|
||||
|
||||
remaining = self.horizon - mean.shape[1]
|
||||
if remaining > 0:
|
||||
device = mean.device
|
||||
new_mean = torch.zeros([self.n_envs, remaining, self.action_dim])
|
||||
mean = torch.cat([mean, new_mean], dim=1).to(device)
|
||||
new_mean = torch.zeros(
|
||||
[self.n_envs, remaining, self.action_dim],
|
||||
device=mean.device,
|
||||
)
|
||||
mean = torch.cat([mean, new_mean], dim=1)
|
||||
|
||||
return mean, var
|
||||
|
||||
@@ -105,8 +115,10 @@ class CEMSolver:
|
||||
|
||||
# -- initialize the action distribution globally
|
||||
mean, var = self.init_action_distrib(init_action)
|
||||
mean = mean.to(self.device)
|
||||
var = var.to(self.device)
|
||||
if mean.device != torch.device(self.device):
|
||||
mean = mean.to(self.device, non_blocking=True)
|
||||
if var.device != torch.device(self.device):
|
||||
var = var.to(self.device, non_blocking=True)
|
||||
|
||||
total_envs = self.n_envs
|
||||
|
||||
@@ -138,6 +150,7 @@ class CEMSolver:
|
||||
|
||||
# Optimization Loop
|
||||
final_batch_cost = None
|
||||
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1)
|
||||
|
||||
for step in range(self.n_steps):
|
||||
# Sample action sequences: (Batch, Num_Samples, Horizon, Dim)
|
||||
@@ -172,8 +185,6 @@ class CEMSolver:
|
||||
|
||||
# Gather Top-K Candidates
|
||||
# We need to select the specific candidates corresponding to topk_inds
|
||||
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk)
|
||||
|
||||
# Indexing: candidates[batch_idx, sample_idx]
|
||||
# Result shape: (Batch, K, Horizon, Dim)
|
||||
topk_candidates = candidates[batch_indices, topk_inds]
|
||||
@@ -184,18 +195,19 @@ class CEMSolver:
|
||||
|
||||
# Update final cost for logging
|
||||
# We average the cost of the top elites
|
||||
final_batch_cost = topk_vals.mean(dim=1).cpu().tolist()
|
||||
final_batch_cost = topk_vals.mean(dim=1).detach()
|
||||
|
||||
# Write results back to global storage
|
||||
mean[start_idx:end_idx] = batch_mean
|
||||
var[start_idx:end_idx] = batch_var
|
||||
|
||||
# Store history/metadata
|
||||
outputs["costs"].extend(final_batch_cost)
|
||||
outputs["costs"].append(final_batch_cost)
|
||||
|
||||
outputs["actions"] = mean.detach().cpu()
|
||||
outputs["mean"] = [mean.detach().cpu()]
|
||||
outputs["var"] = [var.detach().cpu()]
|
||||
outputs["costs"] = torch.cat(outputs["costs"]).cpu().tolist()
|
||||
outputs["actions"] = mean.detach()
|
||||
outputs["mean"] = [mean.detach()]
|
||||
outputs["var"] = [var.detach()]
|
||||
|
||||
print(f"CEM solve time: {time.time() - start_time:.4f} seconds")
|
||||
return outputs
|
||||
|
||||
42
jepa.py
42
jepa.py
@@ -49,18 +49,33 @@ class JEPA(nn.Module):
|
||||
str(tensor.device),
|
||||
tensor.dtype,
|
||||
tuple(tensor.shape),
|
||||
tuple(tensor.stride()),
|
||||
tensor.storage_offset(),
|
||||
tensor.data_ptr(),
|
||||
version,
|
||||
)
|
||||
|
||||
def _get_cached_device_tensor(self, key: str, tensor: torch.Tensor, device: torch.device):
|
||||
def _get_cached_device_tensor(
|
||||
self,
|
||||
key: str,
|
||||
tensor: torch.Tensor,
|
||||
device: torch.device,
|
||||
*,
|
||||
ensure_contiguous: bool = False,
|
||||
):
|
||||
self._ensure_runtime_caches()
|
||||
signature = (self._tensor_signature(tensor), str(device))
|
||||
if tensor.device == device and (not ensure_contiguous or tensor.is_contiguous()):
|
||||
return tensor
|
||||
|
||||
signature = (self._tensor_signature(tensor), str(device), ensure_contiguous)
|
||||
cached = self._cached_device_tensors.get(key)
|
||||
if cached is None or cached[0] != signature:
|
||||
prepared = tensor.to(device, non_blocking=True)
|
||||
if ensure_contiguous and not prepared.is_contiguous():
|
||||
prepared = prepared.contiguous()
|
||||
self._cached_device_tensors[key] = (
|
||||
signature,
|
||||
tensor.to(device, non_blocking=True),
|
||||
prepared,
|
||||
)
|
||||
return self._cached_device_tensors[key][1]
|
||||
|
||||
@@ -68,8 +83,13 @@ class JEPA(nn.Module):
|
||||
for key, value in list(info_dict.items()):
|
||||
if key.startswith("_lewm_"):
|
||||
continue
|
||||
if torch.is_tensor(value) and value.device != device:
|
||||
info_dict[key] = self._get_cached_device_tensor(key, value, device)
|
||||
if torch.is_tensor(value):
|
||||
info_dict[key] = self._get_cached_device_tensor(
|
||||
key,
|
||||
value,
|
||||
device,
|
||||
ensure_contiguous=True,
|
||||
)
|
||||
return info_dict
|
||||
|
||||
def _get_cached_init_emb(self, info_dict: dict):
|
||||
@@ -152,9 +172,9 @@ class JEPA(nn.Module):
|
||||
|
||||
init_hist = init_emb[:, -hist_len:]
|
||||
init_hist = init_hist.unsqueeze(1).expand(-1, S, -1, -1)
|
||||
init_hist = init_hist.reshape(B * S, hist_len, init_hist.size(-1))
|
||||
init_hist = init_hist.reshape(B * S, hist_len, init_hist.size(-1)).contiguous()
|
||||
|
||||
flat_actions = action_sequence.reshape(B * S, T, -1)
|
||||
flat_actions = action_sequence.contiguous().view(B * S, T, -1)
|
||||
action_emb = self.action_encoder(flat_actions)
|
||||
act_hist = action_emb[:, H - hist_len : H]
|
||||
act_future = action_emb[:, H:]
|
||||
@@ -245,8 +265,12 @@ class JEPA(nn.Module):
|
||||
self._ensure_runtime_caches()
|
||||
device = next(self.parameters()).device
|
||||
info_dict = self._ensure_info_device(info_dict, device)
|
||||
if action_candidates.device != device:
|
||||
action_candidates = action_candidates.to(device, non_blocking=True)
|
||||
action_candidates = self._get_cached_device_tensor(
|
||||
"_lewm_action_candidates",
|
||||
action_candidates,
|
||||
device,
|
||||
ensure_contiguous=True,
|
||||
)
|
||||
|
||||
info_dict["goal_emb"] = self._get_cached_goal_emb(info_dict)
|
||||
info_dict = self.rollout(info_dict, action_candidates)
|
||||
|
||||
@@ -1768,3 +1768,363 @@ evaluation_time: 43.71034002304077 seconds
|
||||
inference_precision: fp16
|
||||
inference_compile_target: predictor
|
||||
inference_compile_mode: reduce-overhead
|
||||
|
||||
==== CONFIG ====
|
||||
cache_dir: null
|
||||
solver:
|
||||
_target_: stable_worldmodel.solver.CEMSolver
|
||||
model: ???
|
||||
batch_size: 1
|
||||
num_samples: 300
|
||||
var_scale: 1.0
|
||||
n_steps: 30
|
||||
topk: 30
|
||||
device: cuda
|
||||
seed: ${seed}
|
||||
world:
|
||||
env_name: swm/TwoRoom-v1
|
||||
num_envs: ${eval.num_eval}
|
||||
max_episode_steps: 100
|
||||
history_size: 1
|
||||
frame_skip: 1
|
||||
seed: 42
|
||||
policy: two-room/tworoom/lejepa
|
||||
inference_precision: fp16
|
||||
dataset:
|
||||
stats: ${eval.dataset_name}
|
||||
keys_to_cache:
|
||||
- action
|
||||
- proprio
|
||||
plan_config:
|
||||
horizon: 5
|
||||
receding_horizon: 5
|
||||
action_block: 5
|
||||
eval:
|
||||
num_eval: 50
|
||||
goal_offset_steps: 25
|
||||
eval_budget: 50
|
||||
img_size: 224
|
||||
dataset_name: tworoom
|
||||
callables:
|
||||
- method: _set_state
|
||||
args:
|
||||
state:
|
||||
value: proprio
|
||||
- method: _set_goal_state
|
||||
args:
|
||||
goal_state:
|
||||
value: goal_proprio
|
||||
output:
|
||||
filename: tworoom_results.txt
|
||||
|
||||
==== RESULTS ====
|
||||
metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
|
||||
True, True, True, True, True, True, True, True, True,
|
||||
True, True, True, False, True, True, True, True, True,
|
||||
True, True, True, True, False, True, True, True, True,
|
||||
True, True, False, True, True, True, True, True, True,
|
||||
True, True, True, True, True]), 'seeds': None}
|
||||
evaluation_time: 47.23623466491699 seconds
|
||||
inference_precision: fp16
|
||||
inference_compile_target: predictor
|
||||
inference_compile_mode: reduce-overhead
|
||||
|
||||
==== CONFIG ====
|
||||
cache_dir: null
|
||||
solver:
|
||||
_target_: stable_worldmodel.solver.CEMSolver
|
||||
model: ???
|
||||
batch_size: 1
|
||||
num_samples: 300
|
||||
var_scale: 1.0
|
||||
n_steps: 30
|
||||
topk: 30
|
||||
device: cuda
|
||||
seed: ${seed}
|
||||
world:
|
||||
env_name: swm/TwoRoom-v1
|
||||
num_envs: ${eval.num_eval}
|
||||
max_episode_steps: 100
|
||||
history_size: 1
|
||||
frame_skip: 1
|
||||
seed: 42
|
||||
policy: two-room/tworoom/lejepa
|
||||
inference_precision: fp16
|
||||
dataset:
|
||||
stats: ${eval.dataset_name}
|
||||
keys_to_cache:
|
||||
- action
|
||||
- proprio
|
||||
plan_config:
|
||||
horizon: 5
|
||||
receding_horizon: 5
|
||||
action_block: 5
|
||||
eval:
|
||||
num_eval: 50
|
||||
goal_offset_steps: 25
|
||||
eval_budget: 50
|
||||
img_size: 224
|
||||
dataset_name: tworoom
|
||||
callables:
|
||||
- method: _set_state
|
||||
args:
|
||||
state:
|
||||
value: proprio
|
||||
- method: _set_goal_state
|
||||
args:
|
||||
goal_state:
|
||||
value: goal_proprio
|
||||
output:
|
||||
filename: tworoom_results.txt
|
||||
|
||||
==== RESULTS ====
|
||||
metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
|
||||
True, True, True, True, True, True, True, True, True,
|
||||
True, True, True, False, True, True, True, True, True,
|
||||
True, True, True, True, False, True, True, True, True,
|
||||
True, True, False, True, True, True, True, True, True,
|
||||
True, True, True, True, True]), 'seeds': None}
|
||||
evaluation_time: 57.10417580604553 seconds
|
||||
inference_precision: fp16
|
||||
inference_compile_target: predictor
|
||||
inference_compile_mode: reduce-overhead
|
||||
|
||||
==== CONFIG ====
|
||||
cache_dir: null
|
||||
solver:
|
||||
_target_: stable_worldmodel.solver.CEMSolver
|
||||
model: ???
|
||||
batch_size: 1
|
||||
num_samples: 300
|
||||
var_scale: 1.0
|
||||
n_steps: 30
|
||||
topk: 30
|
||||
device: cuda
|
||||
seed: ${seed}
|
||||
world:
|
||||
env_name: swm/TwoRoom-v1
|
||||
num_envs: ${eval.num_eval}
|
||||
max_episode_steps: 100
|
||||
history_size: 1
|
||||
frame_skip: 1
|
||||
seed: 42
|
||||
policy: two-room/tworoom/lejepa
|
||||
inference_precision: fp16
|
||||
dataset:
|
||||
stats: ${eval.dataset_name}
|
||||
keys_to_cache:
|
||||
- action
|
||||
- proprio
|
||||
plan_config:
|
||||
horizon: 5
|
||||
receding_horizon: 5
|
||||
action_block: 5
|
||||
eval:
|
||||
num_eval: 50
|
||||
goal_offset_steps: 25
|
||||
eval_budget: 50
|
||||
img_size: 224
|
||||
dataset_name: tworoom
|
||||
callables:
|
||||
- method: _set_state
|
||||
args:
|
||||
state:
|
||||
value: proprio
|
||||
- method: _set_goal_state
|
||||
args:
|
||||
goal_state:
|
||||
value: goal_proprio
|
||||
output:
|
||||
filename: tworoom_results.txt
|
||||
|
||||
==== RESULTS ====
|
||||
metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
|
||||
True, True, True, True, True, True, True, True, True,
|
||||
True, True, True, False, True, True, True, True, True,
|
||||
True, True, True, True, False, True, True, True, True,
|
||||
True, True, False, True, True, True, True, True, True,
|
||||
True, True, True, True, True]), 'seeds': None}
|
||||
evaluation_time: 51.94328594207764 seconds
|
||||
inference_precision: fp16
|
||||
inference_compile_target: predictor
|
||||
inference_compile_mode: reduce-overhead
|
||||
|
||||
==== CONFIG ====
|
||||
cache_dir: null
|
||||
solver:
|
||||
_target_: stable_worldmodel.solver.CEMSolver
|
||||
model: ???
|
||||
batch_size: 1
|
||||
num_samples: 300
|
||||
var_scale: 1.0
|
||||
n_steps: 30
|
||||
topk: 30
|
||||
device: cuda
|
||||
seed: ${seed}
|
||||
world:
|
||||
env_name: swm/TwoRoom-v1
|
||||
num_envs: ${eval.num_eval}
|
||||
max_episode_steps: 100
|
||||
history_size: 1
|
||||
frame_skip: 1
|
||||
seed: 42
|
||||
policy: two-room/tworoom/lejepa
|
||||
inference_precision: fp16
|
||||
dataset:
|
||||
stats: ${eval.dataset_name}
|
||||
keys_to_cache:
|
||||
- action
|
||||
- proprio
|
||||
plan_config:
|
||||
horizon: 5
|
||||
receding_horizon: 5
|
||||
action_block: 5
|
||||
eval:
|
||||
num_eval: 50
|
||||
goal_offset_steps: 25
|
||||
eval_budget: 50
|
||||
img_size: 224
|
||||
dataset_name: tworoom
|
||||
callables:
|
||||
- method: _set_state
|
||||
args:
|
||||
state:
|
||||
value: proprio
|
||||
- method: _set_goal_state
|
||||
args:
|
||||
goal_state:
|
||||
value: goal_proprio
|
||||
output:
|
||||
filename: tworoom_results.txt
|
||||
|
||||
==== RESULTS ====
|
||||
metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
|
||||
True, True, True, True, True, True, True, True, True,
|
||||
True, True, True, False, True, True, True, True, True,
|
||||
True, True, True, True, False, True, True, True, True,
|
||||
True, True, False, True, True, True, True, True, True,
|
||||
True, True, True, True, True]), 'seeds': None}
|
||||
evaluation_time: 46.037922620773315 seconds
|
||||
inference_precision: fp16
|
||||
inference_compile_target: predictor
|
||||
inference_compile_mode: reduce-overhead
|
||||
|
||||
==== CONFIG ====
|
||||
cache_dir: null
|
||||
solver:
|
||||
_target_: stable_worldmodel.solver.CEMSolver
|
||||
model: ???
|
||||
batch_size: 1
|
||||
num_samples: 300
|
||||
var_scale: 1.0
|
||||
n_steps: 30
|
||||
topk: 30
|
||||
device: cuda
|
||||
seed: ${seed}
|
||||
world:
|
||||
env_name: swm/TwoRoom-v1
|
||||
num_envs: ${eval.num_eval}
|
||||
max_episode_steps: 100
|
||||
history_size: 1
|
||||
frame_skip: 1
|
||||
seed: 42
|
||||
policy: two-room/tworoom/lejepa
|
||||
inference_precision: fp16
|
||||
dataset:
|
||||
stats: ${eval.dataset_name}
|
||||
keys_to_cache:
|
||||
- action
|
||||
- proprio
|
||||
plan_config:
|
||||
horizon: 5
|
||||
receding_horizon: 5
|
||||
action_block: 5
|
||||
eval:
|
||||
num_eval: 50
|
||||
goal_offset_steps: 25
|
||||
eval_budget: 50
|
||||
img_size: 224
|
||||
dataset_name: tworoom
|
||||
callables:
|
||||
- method: _set_state
|
||||
args:
|
||||
state:
|
||||
value: proprio
|
||||
- method: _set_goal_state
|
||||
args:
|
||||
goal_state:
|
||||
value: goal_proprio
|
||||
output:
|
||||
filename: tworoom_results.txt
|
||||
|
||||
==== RESULTS ====
|
||||
metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
|
||||
True, True, True, True, True, True, True, True, True,
|
||||
True, True, True, False, True, True, True, True, True,
|
||||
True, True, True, True, False, True, True, True, True,
|
||||
True, True, False, True, True, True, True, True, True,
|
||||
True, True, True, True, True]), 'seeds': None}
|
||||
evaluation_time: 40.61683630943298 seconds
|
||||
inference_precision: fp16
|
||||
inference_compile_target: predictor
|
||||
inference_compile_mode: reduce-overhead
|
||||
|
||||
==== CONFIG ====
|
||||
cache_dir: null
|
||||
solver:
|
||||
_target_: stable_worldmodel.solver.CEMSolver
|
||||
model: ???
|
||||
batch_size: 1
|
||||
num_samples: 300
|
||||
var_scale: 1.0
|
||||
n_steps: 30
|
||||
topk: 30
|
||||
device: cuda
|
||||
seed: ${seed}
|
||||
world:
|
||||
env_name: swm/TwoRoom-v1
|
||||
num_envs: ${eval.num_eval}
|
||||
max_episode_steps: 100
|
||||
history_size: 1
|
||||
frame_skip: 1
|
||||
seed: 42
|
||||
policy: two-room/tworoom/lejepa
|
||||
inference_precision: fp16
|
||||
dataset:
|
||||
stats: ${eval.dataset_name}
|
||||
keys_to_cache:
|
||||
- action
|
||||
- proprio
|
||||
plan_config:
|
||||
horizon: 5
|
||||
receding_horizon: 5
|
||||
action_block: 5
|
||||
eval:
|
||||
num_eval: 50
|
||||
goal_offset_steps: 25
|
||||
eval_budget: 50
|
||||
img_size: 224
|
||||
dataset_name: tworoom
|
||||
callables:
|
||||
- method: _set_state
|
||||
args:
|
||||
state:
|
||||
value: proprio
|
||||
- method: _set_goal_state
|
||||
args:
|
||||
goal_state:
|
||||
value: goal_proprio
|
||||
output:
|
||||
filename: tworoom_results.txt
|
||||
|
||||
==== RESULTS ====
|
||||
metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
|
||||
True, True, True, True, True, True, True, True, True,
|
||||
True, True, True, False, True, True, True, True, True,
|
||||
True, True, True, True, False, True, True, True, True,
|
||||
True, True, False, True, True, True, True, True, True,
|
||||
True, True, True, True, True]), 'seeds': None}
|
||||
evaluation_time: 41.09517192840576 seconds
|
||||
inference_precision: fp16
|
||||
inference_compile_target: predictor
|
||||
inference_compile_mode: reduce-overhead
|
||||
|
||||
Reference in New Issue
Block a user