From 25e4ddb628bcc6343b3fecbfaba18c21cecc6130 Mon Sep 17 00:00:00 2001 From: qihuanye Date: Thu, 9 Apr 2026 12:33:50 +0000 Subject: [PATCH] =?UTF-8?q?=E7=BB=A7=E7=BB=AD=E5=81=9A=E4=BA=86=E9=80=9A?= =?UTF-8?q?=E7=94=A8=E6=80=A7=E8=83=BD=E4=BC=98=E5=8C=96=EF=BC=8C=E9=87=8D?= =?UTF-8?q?=E7=82=B9=E4=BB=8E=20jepa.py=20=E7=83=AD=E8=B7=AF=E5=BE=84?= =?UTF-8?q?=E8=BD=AC=E5=88=B0=E5=AE=9E=E9=99=85=E7=9A=84=20stable=5Fworldm?= =?UTF-8?q?odel=20=20=20solver/policy=20=E8=BE=B9=E7=95=8C=EF=BC=9A?= =?UTF-8?q?=E5=8E=BB=E6=8E=89=20CEM=20=E6=AF=8F=E8=BD=AE=20cpu().tolist()?= =?UTF-8?q?=20=E5=92=8C=E7=BB=93=E6=9E=9C=E8=BF=87=E6=97=A9=E5=9B=9E=20CPU?= =?UTF-8?q?=EF=BC=8C=E6=8A=8A=20=20=20plan/warm-start=20=E4=BF=9D=E6=8C=81?= =?UTF-8?q?=E5=9C=A8=20GPU=EF=BC=8C=E5=8F=AA=E5=9C=A8=20env.step=20?= =?UTF-8?q?=E5=89=8D=E6=9C=80=E5=90=8E=E4=B8=80=E6=AD=A5=E8=BD=AC=E6=88=90?= =?UTF-8?q?=20numpy=EF=BC=8C=E5=90=8C=E6=97=B6=E8=A1=A5=20=20=20=E4=BA=86?= =?UTF-8?q?=E8=BE=93=E5=85=A5=E5=BC=A0=E9=87=8F=E7=9A=84=20contiguous=20?= =?UTF-8?q?=E5=A4=84=E7=90=86=EF=BC=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../site-packages/stable_worldmodel/policy.py | 19 +- .../stable_worldmodel/solver/cem.py | 40 +- jepa.py | 42 +- tworoom_results.txt | 360 ++++++++++++++++++ 4 files changed, 432 insertions(+), 29 deletions(-) diff --git a/.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py b/.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py index af672c4..1d8c85e 100644 --- a/.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py +++ b/.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py @@ -122,8 +122,12 @@ class BasePolicy: ) -> dict[str, Any]: target = torch.device(device) for k, v in info_dict.items(): - if torch.is_tensor(v) and v.device != target: - info_dict[k] = v.to(target, non_blocking=True) + if torch.is_tensor(v): + if v.device != target: + v = v.to(target, non_blocking=True) + if not v.is_contiguous(): + v = v.contiguous() + info_dict[k] = v return info_dict def _prepare_info(self, info_dict: dict) -> dict[str, torch.Tensor]: @@ -415,18 +419,21 @@ class WorldModelPolicy(BasePolicy): keep_horizon = self.cfg.receding_horizon plan = actions[:, :keep_horizon] rest = actions[:, keep_horizon:] - self._next_init = rest if self.cfg.warm_start else None + self._next_init = rest.contiguous() if self.cfg.warm_start else None # frameskip back to timestep plan = plan.reshape( self.env.num_envs, self.flatten_receding_horizon, -1 - ) + ).contiguous() - self._action_buffer.extend(plan.transpose(0, 1)) + self._action_buffer.extend(plan.transpose(0, 1).unbind(0)) action = self._action_buffer.popleft() action = action.reshape(*self.env.action_space.shape) - action = action.numpy() + if torch.is_tensor(action): + action = action.detach().cpu().numpy() + else: + action = np.asarray(action) # post-process action if 'action' in self.process: diff --git a/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py b/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py index 77fa14a..328d141 100644 --- a/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py +++ b/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py @@ -80,14 +80,24 @@ class CEMSolver: self, actions: torch.Tensor | None = None ) -> tuple[torch.Tensor, torch.Tensor]: """Initialize the action distribution parameters (mean and variance).""" - var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim]) - mean = torch.zeros([self.n_envs, 0, self.action_dim]) if actions is None else actions + device = torch.device(self.device) + var = self.var_scale * torch.ones( + [self.n_envs, self.horizon, self.action_dim], + device=device, + ) + mean = ( + torch.zeros([self.n_envs, 0, self.action_dim], device=device) + if actions is None + else actions + ) remaining = self.horizon - mean.shape[1] if remaining > 0: - device = mean.device - new_mean = torch.zeros([self.n_envs, remaining, self.action_dim]) - mean = torch.cat([mean, new_mean], dim=1).to(device) + new_mean = torch.zeros( + [self.n_envs, remaining, self.action_dim], + device=mean.device, + ) + mean = torch.cat([mean, new_mean], dim=1) return mean, var @@ -105,8 +115,10 @@ class CEMSolver: # -- initialize the action distribution globally mean, var = self.init_action_distrib(init_action) - mean = mean.to(self.device) - var = var.to(self.device) + if mean.device != torch.device(self.device): + mean = mean.to(self.device, non_blocking=True) + if var.device != torch.device(self.device): + var = var.to(self.device, non_blocking=True) total_envs = self.n_envs @@ -138,6 +150,7 @@ class CEMSolver: # Optimization Loop final_batch_cost = None + batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1) for step in range(self.n_steps): # Sample action sequences: (Batch, Num_Samples, Horizon, Dim) @@ -172,8 +185,6 @@ class CEMSolver: # Gather Top-K Candidates # We need to select the specific candidates corresponding to topk_inds - batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk) - # Indexing: candidates[batch_idx, sample_idx] # Result shape: (Batch, K, Horizon, Dim) topk_candidates = candidates[batch_indices, topk_inds] @@ -184,18 +195,19 @@ class CEMSolver: # Update final cost for logging # We average the cost of the top elites - final_batch_cost = topk_vals.mean(dim=1).cpu().tolist() + final_batch_cost = topk_vals.mean(dim=1).detach() # Write results back to global storage mean[start_idx:end_idx] = batch_mean var[start_idx:end_idx] = batch_var # Store history/metadata - outputs["costs"].extend(final_batch_cost) + outputs["costs"].append(final_batch_cost) - outputs["actions"] = mean.detach().cpu() - outputs["mean"] = [mean.detach().cpu()] - outputs["var"] = [var.detach().cpu()] + outputs["costs"] = torch.cat(outputs["costs"]).cpu().tolist() + outputs["actions"] = mean.detach() + outputs["mean"] = [mean.detach()] + outputs["var"] = [var.detach()] print(f"CEM solve time: {time.time() - start_time:.4f} seconds") return outputs diff --git a/jepa.py b/jepa.py index 4ecc2b7..368856e 100644 --- a/jepa.py +++ b/jepa.py @@ -49,18 +49,33 @@ class JEPA(nn.Module): str(tensor.device), tensor.dtype, tuple(tensor.shape), + tuple(tensor.stride()), + tensor.storage_offset(), tensor.data_ptr(), version, ) - def _get_cached_device_tensor(self, key: str, tensor: torch.Tensor, device: torch.device): + def _get_cached_device_tensor( + self, + key: str, + tensor: torch.Tensor, + device: torch.device, + *, + ensure_contiguous: bool = False, + ): self._ensure_runtime_caches() - signature = (self._tensor_signature(tensor), str(device)) + if tensor.device == device and (not ensure_contiguous or tensor.is_contiguous()): + return tensor + + signature = (self._tensor_signature(tensor), str(device), ensure_contiguous) cached = self._cached_device_tensors.get(key) if cached is None or cached[0] != signature: + prepared = tensor.to(device, non_blocking=True) + if ensure_contiguous and not prepared.is_contiguous(): + prepared = prepared.contiguous() self._cached_device_tensors[key] = ( signature, - tensor.to(device, non_blocking=True), + prepared, ) return self._cached_device_tensors[key][1] @@ -68,8 +83,13 @@ class JEPA(nn.Module): for key, value in list(info_dict.items()): if key.startswith("_lewm_"): continue - if torch.is_tensor(value) and value.device != device: - info_dict[key] = self._get_cached_device_tensor(key, value, device) + if torch.is_tensor(value): + info_dict[key] = self._get_cached_device_tensor( + key, + value, + device, + ensure_contiguous=True, + ) return info_dict def _get_cached_init_emb(self, info_dict: dict): @@ -152,9 +172,9 @@ class JEPA(nn.Module): init_hist = init_emb[:, -hist_len:] init_hist = init_hist.unsqueeze(1).expand(-1, S, -1, -1) - init_hist = init_hist.reshape(B * S, hist_len, init_hist.size(-1)) + init_hist = init_hist.reshape(B * S, hist_len, init_hist.size(-1)).contiguous() - flat_actions = action_sequence.reshape(B * S, T, -1) + flat_actions = action_sequence.contiguous().view(B * S, T, -1) action_emb = self.action_encoder(flat_actions) act_hist = action_emb[:, H - hist_len : H] act_future = action_emb[:, H:] @@ -245,8 +265,12 @@ class JEPA(nn.Module): self._ensure_runtime_caches() device = next(self.parameters()).device info_dict = self._ensure_info_device(info_dict, device) - if action_candidates.device != device: - action_candidates = action_candidates.to(device, non_blocking=True) + action_candidates = self._get_cached_device_tensor( + "_lewm_action_candidates", + action_candidates, + device, + ensure_contiguous=True, + ) info_dict["goal_emb"] = self._get_cached_goal_emb(info_dict) info_dict = self.rollout(info_dict, action_candidates) diff --git a/tworoom_results.txt b/tworoom_results.txt index bc74d7c..19c40eb 100644 --- a/tworoom_results.txt +++ b/tworoom_results.txt @@ -1768,3 +1768,363 @@ evaluation_time: 43.71034002304077 seconds inference_precision: fp16 inference_compile_target: predictor inference_compile_mode: reduce-overhead + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.CEMSolver + model: ??? + batch_size: 1 + num_samples: 300 + var_scale: 1.0 + n_steps: 30 + topk: 30 + device: cuda + seed: ${seed} +world: + env_name: swm/TwoRoom-v1 + num_envs: ${eval.num_eval} + max_episode_steps: 100 + history_size: 1 + frame_skip: 1 +seed: 42 +policy: two-room/tworoom/lejepa +inference_precision: fp16 +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: tworoom + callables: + - method: _set_state + args: + state: + value: proprio + - method: _set_goal_state + args: + goal_state: + value: goal_proprio +output: + filename: tworoom_results.txt + +==== RESULTS ==== +metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False, + True, True, True, True, True, True, True, True, True, + True, True, True, False, True, True, True, True, True, + True, True, True, True, False, True, True, True, True, + True, True, False, True, True, True, True, True, True, + True, True, True, True, True]), 'seeds': None} +evaluation_time: 47.23623466491699 seconds +inference_precision: fp16 +inference_compile_target: predictor +inference_compile_mode: reduce-overhead + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.CEMSolver + model: ??? + batch_size: 1 + num_samples: 300 + var_scale: 1.0 + n_steps: 30 + topk: 30 + device: cuda + seed: ${seed} +world: + env_name: swm/TwoRoom-v1 + num_envs: ${eval.num_eval} + max_episode_steps: 100 + history_size: 1 + frame_skip: 1 +seed: 42 +policy: two-room/tworoom/lejepa +inference_precision: fp16 +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: tworoom + callables: + - method: _set_state + args: + state: + value: proprio + - method: _set_goal_state + args: + goal_state: + value: goal_proprio +output: + filename: tworoom_results.txt + +==== RESULTS ==== +metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False, + True, True, True, True, True, True, True, True, True, + True, True, True, False, True, True, True, True, True, + True, True, True, True, False, True, True, True, True, + True, True, False, True, True, True, True, True, True, + True, True, True, True, True]), 'seeds': None} +evaluation_time: 57.10417580604553 seconds +inference_precision: fp16 +inference_compile_target: predictor +inference_compile_mode: reduce-overhead + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.CEMSolver + model: ??? + batch_size: 1 + num_samples: 300 + var_scale: 1.0 + n_steps: 30 + topk: 30 + device: cuda + seed: ${seed} +world: + env_name: swm/TwoRoom-v1 + num_envs: ${eval.num_eval} + max_episode_steps: 100 + history_size: 1 + frame_skip: 1 +seed: 42 +policy: two-room/tworoom/lejepa +inference_precision: fp16 +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: tworoom + callables: + - method: _set_state + args: + state: + value: proprio + - method: _set_goal_state + args: + goal_state: + value: goal_proprio +output: + filename: tworoom_results.txt + +==== RESULTS ==== +metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False, + True, True, True, True, True, True, True, True, True, + True, True, True, False, True, True, True, True, True, + True, True, True, True, False, True, True, True, True, + True, True, False, True, True, True, True, True, True, + True, True, True, True, True]), 'seeds': None} +evaluation_time: 51.94328594207764 seconds +inference_precision: fp16 +inference_compile_target: predictor +inference_compile_mode: reduce-overhead + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.CEMSolver + model: ??? + batch_size: 1 + num_samples: 300 + var_scale: 1.0 + n_steps: 30 + topk: 30 + device: cuda + seed: ${seed} +world: + env_name: swm/TwoRoom-v1 + num_envs: ${eval.num_eval} + max_episode_steps: 100 + history_size: 1 + frame_skip: 1 +seed: 42 +policy: two-room/tworoom/lejepa +inference_precision: fp16 +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: tworoom + callables: + - method: _set_state + args: + state: + value: proprio + - method: _set_goal_state + args: + goal_state: + value: goal_proprio +output: + filename: tworoom_results.txt + +==== RESULTS ==== +metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False, + True, True, True, True, True, True, True, True, True, + True, True, True, False, True, True, True, True, True, + True, True, True, True, False, True, True, True, True, + True, True, False, True, True, True, True, True, True, + True, True, True, True, True]), 'seeds': None} +evaluation_time: 46.037922620773315 seconds +inference_precision: fp16 +inference_compile_target: predictor +inference_compile_mode: reduce-overhead + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.CEMSolver + model: ??? + batch_size: 1 + num_samples: 300 + var_scale: 1.0 + n_steps: 30 + topk: 30 + device: cuda + seed: ${seed} +world: + env_name: swm/TwoRoom-v1 + num_envs: ${eval.num_eval} + max_episode_steps: 100 + history_size: 1 + frame_skip: 1 +seed: 42 +policy: two-room/tworoom/lejepa +inference_precision: fp16 +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: tworoom + callables: + - method: _set_state + args: + state: + value: proprio + - method: _set_goal_state + args: + goal_state: + value: goal_proprio +output: + filename: tworoom_results.txt + +==== RESULTS ==== +metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False, + True, True, True, True, True, True, True, True, True, + True, True, True, False, True, True, True, True, True, + True, True, True, True, False, True, True, True, True, + True, True, False, True, True, True, True, True, True, + True, True, True, True, True]), 'seeds': None} +evaluation_time: 40.61683630943298 seconds +inference_precision: fp16 +inference_compile_target: predictor +inference_compile_mode: reduce-overhead + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.CEMSolver + model: ??? + batch_size: 1 + num_samples: 300 + var_scale: 1.0 + n_steps: 30 + topk: 30 + device: cuda + seed: ${seed} +world: + env_name: swm/TwoRoom-v1 + num_envs: ${eval.num_eval} + max_episode_steps: 100 + history_size: 1 + frame_skip: 1 +seed: 42 +policy: two-room/tworoom/lejepa +inference_precision: fp16 +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: tworoom + callables: + - method: _set_state + args: + state: + value: proprio + - method: _set_goal_state + args: + goal_state: + value: goal_proprio +output: + filename: tworoom_results.txt + +==== RESULTS ==== +metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False, + True, True, True, True, True, True, True, True, True, + True, True, True, False, True, True, True, True, True, + True, True, True, True, False, True, True, True, True, + True, True, False, True, True, True, True, True, True, + True, True, True, True, True]), 'seeds': None} +evaluation_time: 41.09517192840576 seconds +inference_precision: fp16 +inference_compile_target: predictor +inference_compile_mode: reduce-overhead