继续做通用性能优化,重点从 jepa.py 热路径转到实际的 stable_worldmodel solver/policy 边界:
去掉 CEM 每轮的 cpu().tolist() 以及结果过早回传 CPU 的操作,把 plan/warm-start 保持在 GPU 上,
只在 env.step 前的最后一步转成 numpy,同时补上输入张量的 contiguous 处理。
This commit is contained in:
qihuanye
2026-04-09 12:33:50 +00:00
parent 995cd8cfec
commit 25e4ddb628
4 changed files with 432 additions and 29 deletions

View File

@@ -122,8 +122,12 @@ class BasePolicy:
) -> dict[str, Any]:
target = torch.device(device)
for k, v in info_dict.items():
if torch.is_tensor(v) and v.device != target:
info_dict[k] = v.to(target, non_blocking=True)
if torch.is_tensor(v):
if v.device != target:
v = v.to(target, non_blocking=True)
if not v.is_contiguous():
v = v.contiguous()
info_dict[k] = v
return info_dict
def _prepare_info(self, info_dict: dict) -> dict[str, torch.Tensor]:
@@ -415,18 +419,21 @@ class WorldModelPolicy(BasePolicy):
keep_horizon = self.cfg.receding_horizon
plan = actions[:, :keep_horizon]
rest = actions[:, keep_horizon:]
self._next_init = rest if self.cfg.warm_start else None
self._next_init = rest.contiguous() if self.cfg.warm_start else None
# frameskip back to timestep
plan = plan.reshape(
self.env.num_envs, self.flatten_receding_horizon, -1
)
).contiguous()
self._action_buffer.extend(plan.transpose(0, 1))
self._action_buffer.extend(plan.transpose(0, 1).unbind(0))
action = self._action_buffer.popleft()
action = action.reshape(*self.env.action_space.shape)
action = action.numpy()
if torch.is_tensor(action):
action = action.detach().cpu().numpy()
else:
action = np.asarray(action)
# post-process action
if 'action' in self.process:

View File

@@ -80,14 +80,24 @@ class CEMSolver:
self, actions: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""Initialize the action distribution parameters (mean and variance)."""
var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim])
mean = torch.zeros([self.n_envs, 0, self.action_dim]) if actions is None else actions
device = torch.device(self.device)
var = self.var_scale * torch.ones(
[self.n_envs, self.horizon, self.action_dim],
device=device,
)
mean = (
torch.zeros([self.n_envs, 0, self.action_dim], device=device)
if actions is None
else actions
)
remaining = self.horizon - mean.shape[1]
if remaining > 0:
device = mean.device
new_mean = torch.zeros([self.n_envs, remaining, self.action_dim])
mean = torch.cat([mean, new_mean], dim=1).to(device)
new_mean = torch.zeros(
[self.n_envs, remaining, self.action_dim],
device=mean.device,
)
mean = torch.cat([mean, new_mean], dim=1)
return mean, var
@@ -105,8 +115,10 @@ class CEMSolver:
# -- initialize the action distribution globally
mean, var = self.init_action_distrib(init_action)
mean = mean.to(self.device)
var = var.to(self.device)
if mean.device != torch.device(self.device):
mean = mean.to(self.device, non_blocking=True)
if var.device != torch.device(self.device):
var = var.to(self.device, non_blocking=True)
total_envs = self.n_envs
@@ -138,6 +150,7 @@ class CEMSolver:
# Optimization Loop
final_batch_cost = None
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1)
for step in range(self.n_steps):
# Sample action sequences: (Batch, Num_Samples, Horizon, Dim)
@@ -172,8 +185,6 @@ class CEMSolver:
# Gather Top-K Candidates
# We need to select the specific candidates corresponding to topk_inds
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk)
# Indexing: candidates[batch_idx, sample_idx]
# Result shape: (Batch, K, Horizon, Dim)
topk_candidates = candidates[batch_indices, topk_inds]
@@ -184,18 +195,19 @@ class CEMSolver:
# Update final cost for logging
# We average the cost of the top elites
final_batch_cost = topk_vals.mean(dim=1).cpu().tolist()
final_batch_cost = topk_vals.mean(dim=1).detach()
# Write results back to global storage
mean[start_idx:end_idx] = batch_mean
var[start_idx:end_idx] = batch_var
# Store history/metadata
outputs["costs"].extend(final_batch_cost)
outputs["costs"].append(final_batch_cost)
outputs["actions"] = mean.detach().cpu()
outputs["mean"] = [mean.detach().cpu()]
outputs["var"] = [var.detach().cpu()]
outputs["costs"] = torch.cat(outputs["costs"]).cpu().tolist()
outputs["actions"] = mean.detach()
outputs["mean"] = [mean.detach()]
outputs["var"] = [var.detach()]
print(f"CEM solve time: {time.time() - start_time:.4f} seconds")
return outputs

42
jepa.py
View File

@@ -49,18 +49,33 @@ class JEPA(nn.Module):
str(tensor.device),
tensor.dtype,
tuple(tensor.shape),
tuple(tensor.stride()),
tensor.storage_offset(),
tensor.data_ptr(),
version,
)
def _get_cached_device_tensor(self, key: str, tensor: torch.Tensor, device: torch.device):
def _get_cached_device_tensor(
self,
key: str,
tensor: torch.Tensor,
device: torch.device,
*,
ensure_contiguous: bool = False,
):
self._ensure_runtime_caches()
signature = (self._tensor_signature(tensor), str(device))
if tensor.device == device and (not ensure_contiguous or tensor.is_contiguous()):
return tensor
signature = (self._tensor_signature(tensor), str(device), ensure_contiguous)
cached = self._cached_device_tensors.get(key)
if cached is None or cached[0] != signature:
prepared = tensor.to(device, non_blocking=True)
if ensure_contiguous and not prepared.is_contiguous():
prepared = prepared.contiguous()
self._cached_device_tensors[key] = (
signature,
tensor.to(device, non_blocking=True),
prepared,
)
return self._cached_device_tensors[key][1]
@@ -68,8 +83,13 @@ class JEPA(nn.Module):
for key, value in list(info_dict.items()):
if key.startswith("_lewm_"):
continue
if torch.is_tensor(value) and value.device != device:
info_dict[key] = self._get_cached_device_tensor(key, value, device)
if torch.is_tensor(value):
info_dict[key] = self._get_cached_device_tensor(
key,
value,
device,
ensure_contiguous=True,
)
return info_dict
def _get_cached_init_emb(self, info_dict: dict):
@@ -152,9 +172,9 @@ class JEPA(nn.Module):
init_hist = init_emb[:, -hist_len:]
init_hist = init_hist.unsqueeze(1).expand(-1, S, -1, -1)
init_hist = init_hist.reshape(B * S, hist_len, init_hist.size(-1))
init_hist = init_hist.reshape(B * S, hist_len, init_hist.size(-1)).contiguous()
flat_actions = action_sequence.reshape(B * S, T, -1)
flat_actions = action_sequence.contiguous().view(B * S, T, -1)
action_emb = self.action_encoder(flat_actions)
act_hist = action_emb[:, H - hist_len : H]
act_future = action_emb[:, H:]
@@ -245,8 +265,12 @@ class JEPA(nn.Module):
self._ensure_runtime_caches()
device = next(self.parameters()).device
info_dict = self._ensure_info_device(info_dict, device)
if action_candidates.device != device:
action_candidates = action_candidates.to(device, non_blocking=True)
action_candidates = self._get_cached_device_tensor(
"_lewm_action_candidates",
action_candidates,
device,
ensure_contiguous=True,
)
info_dict["goal_emb"] = self._get_cached_goal_emb(info_dict)
info_dict = self.rollout(info_dict, action_candidates)

View File

@@ -1768,3 +1768,363 @@ evaluation_time: 43.71034002304077 seconds
inference_precision: fp16
inference_compile_target: predictor
inference_compile_mode: reduce-overhead
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 1
num_samples: 300
var_scale: 1.0
n_steps: 30
topk: 30
device: cuda
seed: ${seed}
world:
env_name: swm/TwoRoom-v1
num_envs: ${eval.num_eval}
max_episode_steps: 100
history_size: 1
frame_skip: 1
seed: 42
policy: two-room/tworoom/lejepa
inference_precision: fp16
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: tworoom
callables:
- method: _set_state
args:
state:
value: proprio
- method: _set_goal_state
args:
goal_state:
value: goal_proprio
output:
filename: tworoom_results.txt
==== RESULTS ====
metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
True, True, True, True, True, True, True, True, True,
True, True, True, False, True, True, True, True, True,
True, True, True, True, False, True, True, True, True,
True, True, False, True, True, True, True, True, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 47.23623466491699 seconds
inference_precision: fp16
inference_compile_target: predictor
inference_compile_mode: reduce-overhead
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 1
num_samples: 300
var_scale: 1.0
n_steps: 30
topk: 30
device: cuda
seed: ${seed}
world:
env_name: swm/TwoRoom-v1
num_envs: ${eval.num_eval}
max_episode_steps: 100
history_size: 1
frame_skip: 1
seed: 42
policy: two-room/tworoom/lejepa
inference_precision: fp16
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: tworoom
callables:
- method: _set_state
args:
state:
value: proprio
- method: _set_goal_state
args:
goal_state:
value: goal_proprio
output:
filename: tworoom_results.txt
==== RESULTS ====
metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
True, True, True, True, True, True, True, True, True,
True, True, True, False, True, True, True, True, True,
True, True, True, True, False, True, True, True, True,
True, True, False, True, True, True, True, True, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 57.10417580604553 seconds
inference_precision: fp16
inference_compile_target: predictor
inference_compile_mode: reduce-overhead
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 1
num_samples: 300
var_scale: 1.0
n_steps: 30
topk: 30
device: cuda
seed: ${seed}
world:
env_name: swm/TwoRoom-v1
num_envs: ${eval.num_eval}
max_episode_steps: 100
history_size: 1
frame_skip: 1
seed: 42
policy: two-room/tworoom/lejepa
inference_precision: fp16
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: tworoom
callables:
- method: _set_state
args:
state:
value: proprio
- method: _set_goal_state
args:
goal_state:
value: goal_proprio
output:
filename: tworoom_results.txt
==== RESULTS ====
metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
True, True, True, True, True, True, True, True, True,
True, True, True, False, True, True, True, True, True,
True, True, True, True, False, True, True, True, True,
True, True, False, True, True, True, True, True, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 51.94328594207764 seconds
inference_precision: fp16
inference_compile_target: predictor
inference_compile_mode: reduce-overhead
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 1
num_samples: 300
var_scale: 1.0
n_steps: 30
topk: 30
device: cuda
seed: ${seed}
world:
env_name: swm/TwoRoom-v1
num_envs: ${eval.num_eval}
max_episode_steps: 100
history_size: 1
frame_skip: 1
seed: 42
policy: two-room/tworoom/lejepa
inference_precision: fp16
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: tworoom
callables:
- method: _set_state
args:
state:
value: proprio
- method: _set_goal_state
args:
goal_state:
value: goal_proprio
output:
filename: tworoom_results.txt
==== RESULTS ====
metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
True, True, True, True, True, True, True, True, True,
True, True, True, False, True, True, True, True, True,
True, True, True, True, False, True, True, True, True,
True, True, False, True, True, True, True, True, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 46.037922620773315 seconds
inference_precision: fp16
inference_compile_target: predictor
inference_compile_mode: reduce-overhead
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 1
num_samples: 300
var_scale: 1.0
n_steps: 30
topk: 30
device: cuda
seed: ${seed}
world:
env_name: swm/TwoRoom-v1
num_envs: ${eval.num_eval}
max_episode_steps: 100
history_size: 1
frame_skip: 1
seed: 42
policy: two-room/tworoom/lejepa
inference_precision: fp16
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: tworoom
callables:
- method: _set_state
args:
state:
value: proprio
- method: _set_goal_state
args:
goal_state:
value: goal_proprio
output:
filename: tworoom_results.txt
==== RESULTS ====
metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
True, True, True, True, True, True, True, True, True,
True, True, True, False, True, True, True, True, True,
True, True, True, True, False, True, True, True, True,
True, True, False, True, True, True, True, True, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 40.61683630943298 seconds
inference_precision: fp16
inference_compile_target: predictor
inference_compile_mode: reduce-overhead
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 1
num_samples: 300
var_scale: 1.0
n_steps: 30
topk: 30
device: cuda
seed: ${seed}
world:
env_name: swm/TwoRoom-v1
num_envs: ${eval.num_eval}
max_episode_steps: 100
history_size: 1
frame_skip: 1
seed: 42
policy: two-room/tworoom/lejepa
inference_precision: fp16
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: tworoom
callables:
- method: _set_state
args:
state:
value: proprio
- method: _set_goal_state
args:
goal_state:
value: goal_proprio
output:
filename: tworoom_results.txt
==== RESULTS ====
metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
True, True, True, True, True, True, True, True, True,
True, True, True, False, True, True, True, True, True,
True, True, True, True, False, True, True, True, True,
True, True, False, True, True, True, True, True, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 41.09517192840576 seconds
inference_precision: fp16
inference_compile_target: predictor
inference_compile_mode: reduce-overhead