diff --git a/scripts/evaluation/base_model_inference.py b/scripts/evaluation/base_model_inference.py index 42945a7..a8f2bd9 100644 --- a/scripts/evaluation/base_model_inference.py +++ b/scripts/evaluation/base_model_inference.py @@ -289,13 +289,15 @@ def image_guided_synthesis(model: torch.nn.Module, if not text_input: prompts = [""] * batch_size - b, c, t, h, w = videos.shape - img = videos[:, :, 0] - img_emb = model.embedder(img) - img_emb = model.image_proj_model(img_emb) - img_emb = rearrange(img_emb, 'b (t l) c -> (b t) l c', t=t) - cond_emb = model.get_learned_conditioning(prompts) - cond_emb = cond_emb.repeat_interleave(repeats=t, dim=0) + b, c, t, h, w = videos.shape + img = videos[:, :, 0] + img_emb = model.embedder(img) + cond_emb = model.get_learned_conditioning(prompts) + target_dtype = cond_emb.dtype + img_emb = model._projector_forward(model.image_proj_model, img_emb, + target_dtype) + img_emb = rearrange(img_emb, 'b (t l) c -> (b t) l c', t=t) + cond_emb = cond_emb.repeat_interleave(repeats=t, dim=0) cond = {"c_crossattn": [torch.cat([cond_emb, img_emb], dim=1)]} if model.model.conditioning_key == 'hybrid': diff --git a/scripts/evaluation/real_eval_server.py b/scripts/evaluation/real_eval_server.py index d780b5d..695d93f 100644 --- a/scripts/evaluation/real_eval_server.py +++ b/scripts/evaluation/real_eval_server.py @@ -168,26 +168,33 @@ def image_guided_synthesis( batch_size = noise_shape[0] fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device) - img = observation['observation.images.top'] - cond_img = img[:, -1, ...] - cond_img_emb = model.embedder(cond_img) - cond_img_emb = model.image_proj_model(cond_img_emb) - - if model.model.conditioning_key == 'hybrid': - z = get_latent_z(model, img.permute(0, 2, 1, 3, 4)) - img_cat_cond = z[:, :, -1:, :, :] + img = observation['observation.images.top'] + cond_img = img[:, -1, ...] + cond_img_emb = model.embedder(cond_img) + + if model.model.conditioning_key == 'hybrid': + z = get_latent_z(model, img.permute(0, 2, 1, 3, 4)) + img_cat_cond = z[:, :, -1:, :, :] img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=noise_shape[2]) - cond = {"c_concat": [img_cat_cond]} - - cond_ins_emb = model.get_learned_conditioning(prompts) - cond_state = model.state_projector(observation['observation.state']) - cond_state_emb = model.agent_state_pos_emb + cond_state - - cond_action = model.action_projector(observation['action']) - cond_action_emb = model.agent_action_pos_emb + cond_action - cond_action_emb = torch.zeros_like(cond_action_emb) + cond = {"c_concat": [img_cat_cond]} + + cond_ins_emb = model.get_learned_conditioning(prompts) + target_dtype = cond_ins_emb.dtype + cond_img_emb = model._projector_forward(model.image_proj_model, + cond_img_emb, target_dtype) + cond_state = model._projector_forward(model.state_projector, + observation['observation.state'], + target_dtype) + cond_state_emb = model.agent_state_pos_emb.to(dtype=target_dtype) + cond_state + + cond_action = model._projector_forward(model.action_projector, + observation['action'], + target_dtype) + cond_action_emb = model.agent_action_pos_emb.to( + dtype=target_dtype) + cond_action + cond_action_emb = torch.zeros_like(cond_action_emb) cond["c_crossattn"] = [ torch.cat([cond_state_emb, cond_ins_emb, cond_img_emb], dim=1) diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py index 5bac2b2..2f2d690 100644 --- a/scripts/evaluation/world_model_interaction.py +++ b/scripts/evaluation/world_model_interaction.py @@ -770,28 +770,36 @@ def image_guided_synthesis_sim_mode( fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device) with profiler.profile_section("synthesis/conditioning_prep"): - img = observation['observation.images.top'].permute(0, 2, 1, 3, 4) - cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:] - cond_img_emb = model.embedder(cond_img) - cond_img_emb = model.image_proj_model(cond_img_emb) - - if model.model.conditioning_key == 'hybrid': - z = get_latent_z(model, img.permute(0, 2, 1, 3, 4)) - img_cat_cond = z[:, :, -1:, :, :] - img_cat_cond = repeat(img_cat_cond, + img = observation['observation.images.top'].permute(0, 2, 1, 3, 4) + cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:] + cond_img_emb = model.embedder(cond_img) + + if model.model.conditioning_key == 'hybrid': + z = get_latent_z(model, img.permute(0, 2, 1, 3, 4)) + img_cat_cond = z[:, :, -1:, :, :] + img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=noise_shape[2]) cond = {"c_concat": [img_cat_cond]} - - if not text_input: - prompts = [""] * batch_size - cond_ins_emb = model.get_learned_conditioning(prompts) - - cond_state_emb = model.state_projector(observation['observation.state']) - cond_state_emb = cond_state_emb + model.agent_state_pos_emb - - cond_action_emb = model.action_projector(observation['action']) - cond_action_emb = cond_action_emb + model.agent_action_pos_emb + + if not text_input: + prompts = [""] * batch_size + cond_ins_emb = model.get_learned_conditioning(prompts) + target_dtype = cond_ins_emb.dtype + + cond_img_emb = model._projector_forward(model.image_proj_model, + cond_img_emb, target_dtype) + + cond_state_emb = model._projector_forward( + model.state_projector, observation['observation.state'], + target_dtype) + cond_state_emb = cond_state_emb + model.agent_state_pos_emb.to( + dtype=target_dtype) + + cond_action_emb = model._projector_forward( + model.action_projector, observation['action'], target_dtype) + cond_action_emb = cond_action_emb + model.agent_action_pos_emb.to( + dtype=target_dtype) if not sim_mode: cond_action_emb = torch.zeros_like(cond_action_emb) diff --git a/src/unifolm_wma/models/ddpms.py b/src/unifolm_wma/models/ddpms.py index fbf2042..ceca543 100644 --- a/src/unifolm_wma/models/ddpms.py +++ b/src/unifolm_wma/models/ddpms.py @@ -1882,6 +1882,7 @@ class LatentVisualDiffusion(LatentDiffusion): dp_use_ema: bool = False, pretrained_checkpoint: str | None = None, decision_making_only: bool = True, + projector_bf16: bool = True, *args, **kwargs): """ @@ -1907,6 +1908,7 @@ class LatentVisualDiffusion(LatentDiffusion): dp_use_ema: If True, maintain EMA for action UNet head. pretrained_checkpoint: Optional path to a pretrained checkpoint. decision_making_only: If True, use decision-only augmentation path. + projector_bf16: If True, run image/state/action projectors under BF16 autocast. """ super().__init__(*args, **kwargs) @@ -1917,6 +1919,7 @@ class LatentVisualDiffusion(LatentDiffusion): self.n_obs_steps_imagen = n_obs_steps_imagen self.n_obs_steps_acting = n_obs_steps_acting self.decision_making_only = decision_making_only + self.projector_bf16 = projector_bf16 self._init_embedder(img_cond_stage_config, freeze_embedder) self._init_img_ctx_projector(image_proj_stage_config, @@ -2025,6 +2028,28 @@ class LatentVisualDiffusion(LatentDiffusion): self.agent_state_pos_emb = nn.Parameter( torch.randn(1, self.n_obs_steps_imagen, self.global_emb_dim)) + def _projector_forward(self, projector: nn.Module, x: Tensor, + target_dtype: torch.dtype | None) -> Tensor: + use_bf16 = (self.projector_bf16 and x.device.type == "cuda" + and torch.cuda.is_bf16_supported()) + if use_bf16: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + out = projector(x) + else: + out = projector(x) + if not hasattr(self, "_printed_projector_bf16"): + print( + ">>> projector bf16 autocast: " + f"enabled={self.projector_bf16} " + f"use_bf16={use_bf16} " + f"input={x.dtype} " + f"output={out.dtype} " + f"target={target_dtype}") + self._printed_projector_bf16 = True + if target_dtype is not None and out.dtype != target_dtype: + out = out.to(dtype=target_dtype) + return out + def _get_augmented_batch( self, z: Tensor, @@ -2166,6 +2191,7 @@ class LatentVisualDiffusion(LatentDiffusion): null_prompt = self.get_learned_conditioning([""]) cond_ins_emb = torch.where(prompt_mask, null_prompt, cond_ins_emb.detach()) + target_dtype = cond_ins_emb.dtype # Get conditioning frames cond_frame_index = 0 @@ -2176,7 +2202,8 @@ class LatentVisualDiffusion(LatentDiffusion): cond_img = input_mask * img cond_img_emb = self.embedder(cond_img) - cond_img_emb = self.image_proj_model(cond_img_emb) + cond_img_emb = self._projector_forward(self.image_proj_model, + cond_img_emb, target_dtype) if self.model.conditioning_key == 'hybrid': if self.interp_mode: @@ -2191,11 +2218,15 @@ class LatentVisualDiffusion(LatentDiffusion): repeat=z.shape[2]) cond["c_concat"] = [img_cat_cond] - cond_action = self.action_projector(action) - cond_action_emb = self.agent_action_pos_emb + cond_action + cond_action = self._projector_forward(self.action_projector, action, + target_dtype) + cond_action_emb = self.agent_action_pos_emb.to( + dtype=target_dtype) + cond_action # Get conditioning states - cond_state = self.state_projector(obs_state) - cond_state_emb = self.agent_state_pos_emb + cond_state + cond_state = self._projector_forward(self.state_projector, obs_state, + target_dtype) + cond_state_emb = self.agent_state_pos_emb.to( + dtype=target_dtype) + cond_state if self.decision_making_only: is_sim_mode = False diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768731024.node-0.369734.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768731024.node-0.369734.0 new file mode 100644 index 0000000..1765774 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768731024.node-0.369734.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768731044.node-0.370115.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768731044.node-0.370115.0 new file mode 100644 index 0000000..00ded53 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768731044.node-0.370115.0 differ