diff --git a/configs/inference/world_model_interaction.yaml b/configs/inference/world_model_interaction.yaml index 2d69e09..a1e115a 100644 --- a/configs/inference/world_model_interaction.yaml +++ b/configs/inference/world_model_interaction.yaml @@ -222,7 +222,7 @@ data: test: target: unifolm_wma.data.wma_data.WMAData params: - data_dir: '/home/dyz/unifolm-world-model-action/examples/world_model_interaction_prompts' + data_dir: '/home/qhy/unifolm-world-model-action/examples/world_model_interaction_prompts' video_length: ${model.params.wma_config.params.temporal_length} frame_stride: 2 load_raw_resolution: True diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py index 27fcd2d..71ed2fc 100644 --- a/scripts/evaluation/world_model_interaction.py +++ b/scripts/evaluation/world_model_interaction.py @@ -13,7 +13,7 @@ import time import json from contextlib import contextmanager, nullcontext from dataclasses import dataclass, field, asdict -from typing import Optional, Dict, List, Any, Mapping +from typing import Optional, Dict, List, Any, Mapping from pytorch_lightning import seed_everything from omegaconf import OmegaConf @@ -56,50 +56,50 @@ class TimingRecord: } -class ProfilerManager: - """Manages macro and micro-level profiling.""" - - def __init__( - self, - enabled: bool = False, - output_dir: str = "./profile_output", - profile_detail: str = "light", - ): - self.enabled = enabled - self.output_dir = output_dir - self.profile_detail = profile_detail - self.macro_timings: Dict[str, List[float]] = {} - self.cuda_events: Dict[str, List[tuple]] = {} - self.memory_snapshots: List[Dict] = [] - self.pytorch_profiler = None - self.current_iteration = 0 - self.operator_stats: Dict[str, Dict] = {} - self.profiler_config = self._build_profiler_config(profile_detail) - - if enabled: - os.makedirs(output_dir, exist_ok=True) - - def _build_profiler_config(self, profile_detail: str) -> Dict[str, Any]: - """Return profiler settings based on the requested detail level.""" - if profile_detail not in ("light", "full"): - raise ValueError(f"Unsupported profile_detail: {profile_detail}") - if profile_detail == "full": - return { - "record_shapes": True, - "profile_memory": True, - "with_stack": True, - "with_flops": True, - "with_modules": True, - "group_by_input_shape": True, - } - return { - "record_shapes": False, - "profile_memory": False, - "with_stack": False, - "with_flops": False, - "with_modules": False, - "group_by_input_shape": False, - } +class ProfilerManager: + """Manages macro and micro-level profiling.""" + + def __init__( + self, + enabled: bool = False, + output_dir: str = "./profile_output", + profile_detail: str = "light", + ): + self.enabled = enabled + self.output_dir = output_dir + self.profile_detail = profile_detail + self.macro_timings: Dict[str, List[float]] = {} + self.cuda_events: Dict[str, List[tuple]] = {} + self.memory_snapshots: List[Dict] = [] + self.pytorch_profiler = None + self.current_iteration = 0 + self.operator_stats: Dict[str, Dict] = {} + self.profiler_config = self._build_profiler_config(profile_detail) + + if enabled: + os.makedirs(output_dir, exist_ok=True) + + def _build_profiler_config(self, profile_detail: str) -> Dict[str, Any]: + """Return profiler settings based on the requested detail level.""" + if profile_detail not in ("light", "full"): + raise ValueError(f"Unsupported profile_detail: {profile_detail}") + if profile_detail == "full": + return { + "record_shapes": True, + "profile_memory": True, + "with_stack": True, + "with_flops": True, + "with_modules": True, + "group_by_input_shape": True, + } + return { + "record_shapes": False, + "profile_memory": False, + "with_stack": False, + "with_flops": False, + "with_modules": False, + "group_by_input_shape": False, + } @contextmanager def profile_section(self, name: str, sync_cuda: bool = True): @@ -162,22 +162,22 @@ class ProfilerManager: if not self.enabled: return nullcontext() - self.pytorch_profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - schedule=torch.profiler.schedule( - wait=wait, warmup=warmup, active=active, repeat=1 - ), - on_trace_ready=self._trace_handler, - record_shapes=self.profiler_config["record_shapes"], - profile_memory=self.profiler_config["profile_memory"], - with_stack=self.profiler_config["with_stack"], - with_flops=self.profiler_config["with_flops"], - with_modules=self.profiler_config["with_modules"], - ) - return self.pytorch_profiler + self.pytorch_profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule( + wait=wait, warmup=warmup, active=active, repeat=1 + ), + on_trace_ready=self._trace_handler, + record_shapes=self.profiler_config["record_shapes"], + profile_memory=self.profiler_config["profile_memory"], + with_stack=self.profiler_config["with_stack"], + with_flops=self.profiler_config["with_flops"], + with_modules=self.profiler_config["with_modules"], + ) + return self.pytorch_profiler def _trace_handler(self, prof): """Handle profiler trace output.""" @@ -187,10 +187,10 @@ class ProfilerManager: ) prof.export_chrome_trace(trace_path) - # Extract operator statistics - key_averages = prof.key_averages( - group_by_input_shape=self.profiler_config["group_by_input_shape"] - ) + # Extract operator statistics + key_averages = prof.key_averages( + group_by_input_shape=self.profiler_config["group_by_input_shape"] + ) for evt in key_averages: op_name = evt.key if op_name not in self.operator_stats: @@ -375,22 +375,22 @@ class ProfilerManager: # Global profiler instance _profiler: Optional[ProfilerManager] = None -def get_profiler() -> ProfilerManager: - """Get the global profiler instance.""" - global _profiler - if _profiler is None: - _profiler = ProfilerManager(enabled=False) - return _profiler - -def init_profiler(enabled: bool, output_dir: str, profile_detail: str) -> ProfilerManager: - """Initialize the global profiler.""" - global _profiler - _profiler = ProfilerManager( - enabled=enabled, - output_dir=output_dir, - profile_detail=profile_detail, - ) - return _profiler +def get_profiler() -> ProfilerManager: + """Get the global profiler instance.""" + global _profiler + if _profiler is None: + _profiler = ProfilerManager(enabled=False) + return _profiler + +def init_profiler(enabled: bool, output_dir: str, profile_detail: str) -> ProfilerManager: + """Initialize the global profiler.""" + global _profiler + _profiler = ProfilerManager( + enabled=enabled, + output_dir=output_dir, + profile_detail=profile_detail, + ) + return _profiler # ========== Original Functions ========== @@ -441,160 +441,162 @@ def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]: return file_list -def _load_state_dict(model: nn.Module, - state_dict: Mapping[str, torch.Tensor], - strict: bool = True, - assign: bool = False) -> None: - if assign: - try: - model.load_state_dict(state_dict, strict=strict, assign=True) - return - except TypeError: - warnings.warn( - "load_state_dict(assign=True) not supported; " - "falling back to copy load.") - model.load_state_dict(state_dict, strict=strict) - - -def load_model_checkpoint(model: nn.Module, - ckpt: str, - assign: bool | None = None) -> nn.Module: - """Load model weights from checkpoint file. - - Args: - model (nn.Module): Model instance. - ckpt (str): Path to the checkpoint file. - assign (bool | None): Whether to preserve checkpoint tensor dtypes - via load_state_dict(assign=True). If None, auto-enable when a - casted checkpoint metadata is detected. - - Returns: - nn.Module: Model with loaded weights. - """ - ckpt_data = torch.load(ckpt, map_location="cpu") - use_assign = False - if assign is not None: - use_assign = assign - elif isinstance(ckpt_data, Mapping) and "precision_metadata" in ckpt_data: - use_assign = True - if isinstance(ckpt_data, Mapping) and "state_dict" in ckpt_data: - state_dict = ckpt_data["state_dict"] - try: - _load_state_dict(model, state_dict, strict=True, assign=use_assign) - except Exception: - new_pl_sd = OrderedDict() - for k, v in state_dict.items(): - new_pl_sd[k] = v - - for k in list(new_pl_sd.keys()): - if "framestride_embed" in k: - new_key = k.replace("framestride_embed", "fps_embedding") - new_pl_sd[new_key] = new_pl_sd[k] - del new_pl_sd[k] - _load_state_dict(model, - new_pl_sd, - strict=True, - assign=use_assign) - elif isinstance(ckpt_data, Mapping) and "module" in ckpt_data: - new_pl_sd = OrderedDict() - for key in ckpt_data['module'].keys(): - new_pl_sd[key[16:]] = ckpt_data['module'][key] - _load_state_dict(model, new_pl_sd, strict=True, assign=use_assign) - else: - _load_state_dict(model, - ckpt_data, - strict=True, - assign=use_assign) - print('>>> model checkpoint loaded.') - return model - - -def maybe_cast_module(module: nn.Module | None, - dtype: torch.dtype, - label: str, - profiler: Optional[ProfilerManager] = None, - profile_name: Optional[str] = None) -> None: - if module is None: - return - try: - param = next(module.parameters()) - except StopIteration: - print(f">>> {label} has no parameters; skip cast") - return - if param.dtype == dtype: - print(f">>> {label} already {dtype}; skip cast") - return - ctx = nullcontext() - if profiler is not None and profile_name: - ctx = profiler.profile_section(profile_name) - with ctx: - module.to(dtype=dtype) - print(f">>> {label} cast to {dtype}") - - -def save_casted_checkpoint(model: nn.Module, - save_path: str, - metadata: Optional[Dict[str, Any]] = None) -> None: - if not save_path: - return - save_dir = os.path.dirname(save_path) - if save_dir: - os.makedirs(save_dir, exist_ok=True) - cpu_state = {} - for key, value in model.state_dict().items(): - if isinstance(value, torch.Tensor): - cpu_state[key] = value.detach().to("cpu") - else: - cpu_state[key] = value - payload: Dict[str, Any] = {"state_dict": cpu_state} - if metadata: - payload["precision_metadata"] = metadata - torch.save(payload, save_path) - print(f">>> Saved casted checkpoint to {save_path}") - - -def _module_param_dtype(module: nn.Module | None) -> str: - if module is None: - return "None" - dtype_counts: Dict[str, int] = {} - for param in module.parameters(): - dtype_key = str(param.dtype) - dtype_counts[dtype_key] = dtype_counts.get(dtype_key, 0) + param.numel() - if not dtype_counts: - return "no_params" - if len(dtype_counts) == 1: - return next(iter(dtype_counts)) - total = sum(dtype_counts.values()) - parts = [] - for dtype_key in sorted(dtype_counts.keys()): - ratio = dtype_counts[dtype_key] / total - parts.append(f"{dtype_key}={ratio:.1%}") - return f"mixed({', '.join(parts)})" - - -def log_inference_precision(model: nn.Module) -> None: - device = "unknown" - for param in model.parameters(): - device = str(param.device) - break - model_dtype = _module_param_dtype(model) - - print(f">>> inference precision: model={model_dtype}, device={device}") - for attr in [ - "model", "first_stage_model", "cond_stage_model", "embedder", - "image_proj_model" - ]: - if hasattr(model, attr): - submodule = getattr(model, attr) - print(f">>> {attr} param dtype: {_module_param_dtype(submodule)}") - - print( - ">>> autocast gpu dtype default: " - f"{torch.get_autocast_gpu_dtype()} " - f"(enabled={torch.is_autocast_enabled()})") - - -def is_inferenced(save_dir: str, filename: str) -> bool: +def _load_state_dict(model: nn.Module, + state_dict: Mapping[str, torch.Tensor], + strict: bool = True, + assign: bool = False) -> None: + if assign: + try: + model.load_state_dict(state_dict, strict=strict, assign=True) + return + except TypeError: + warnings.warn( + "load_state_dict(assign=True) not supported; " + "falling back to copy load.") + model.load_state_dict(state_dict, strict=strict) + + +def load_model_checkpoint(model: nn.Module, + ckpt: str, + assign: bool | None = None, + device: str | torch.device = "cpu") -> nn.Module: + """Load model weights from checkpoint file. + + Args: + model (nn.Module): Model instance. + ckpt (str): Path to the checkpoint file. + assign (bool | None): Whether to preserve checkpoint tensor dtypes + via load_state_dict(assign=True). If None, auto-enable when a + casted checkpoint metadata is detected. + device (str | torch.device): Target device for loaded tensors. + + Returns: + nn.Module: Model with loaded weights. + """ + ckpt_data = torch.load(ckpt, map_location=device, mmap=True) + use_assign = False + if assign is not None: + use_assign = assign + elif isinstance(ckpt_data, Mapping) and "precision_metadata" in ckpt_data: + use_assign = True + if isinstance(ckpt_data, Mapping) and "state_dict" in ckpt_data: + state_dict = ckpt_data["state_dict"] + try: + _load_state_dict(model, state_dict, strict=True, assign=use_assign) + except Exception: + new_pl_sd = OrderedDict() + for k, v in state_dict.items(): + new_pl_sd[k] = v + + for k in list(new_pl_sd.keys()): + if "framestride_embed" in k: + new_key = k.replace("framestride_embed", "fps_embedding") + new_pl_sd[new_key] = new_pl_sd[k] + del new_pl_sd[k] + _load_state_dict(model, + new_pl_sd, + strict=True, + assign=use_assign) + elif isinstance(ckpt_data, Mapping) and "module" in ckpt_data: + new_pl_sd = OrderedDict() + for key in ckpt_data['module'].keys(): + new_pl_sd[key[16:]] = ckpt_data['module'][key] + _load_state_dict(model, new_pl_sd, strict=True, assign=use_assign) + else: + _load_state_dict(model, + ckpt_data, + strict=True, + assign=use_assign) + print('>>> model checkpoint loaded.') + return model + + +def maybe_cast_module(module: nn.Module | None, + dtype: torch.dtype, + label: str, + profiler: Optional[ProfilerManager] = None, + profile_name: Optional[str] = None) -> None: + if module is None: + return + try: + param = next(module.parameters()) + except StopIteration: + print(f">>> {label} has no parameters; skip cast") + return + if param.dtype == dtype: + print(f">>> {label} already {dtype}; skip cast") + return + ctx = nullcontext() + if profiler is not None and profile_name: + ctx = profiler.profile_section(profile_name) + with ctx: + module.to(dtype=dtype) + print(f">>> {label} cast to {dtype}") + + +def save_casted_checkpoint(model: nn.Module, + save_path: str, + metadata: Optional[Dict[str, Any]] = None) -> None: + if not save_path: + return + save_dir = os.path.dirname(save_path) + if save_dir: + os.makedirs(save_dir, exist_ok=True) + cpu_state = {} + for key, value in model.state_dict().items(): + if isinstance(value, torch.Tensor): + cpu_state[key] = value.detach().to("cpu") + else: + cpu_state[key] = value + payload: Dict[str, Any] = {"state_dict": cpu_state} + if metadata: + payload["precision_metadata"] = metadata + torch.save(payload, save_path) + print(f">>> Saved casted checkpoint to {save_path}") + + +def _module_param_dtype(module: nn.Module | None) -> str: + if module is None: + return "None" + dtype_counts: Dict[str, int] = {} + for param in module.parameters(): + dtype_key = str(param.dtype) + dtype_counts[dtype_key] = dtype_counts.get(dtype_key, 0) + param.numel() + if not dtype_counts: + return "no_params" + if len(dtype_counts) == 1: + return next(iter(dtype_counts)) + total = sum(dtype_counts.values()) + parts = [] + for dtype_key in sorted(dtype_counts.keys()): + ratio = dtype_counts[dtype_key] / total + parts.append(f"{dtype_key}={ratio:.1%}") + return f"mixed({', '.join(parts)})" + + +def log_inference_precision(model: nn.Module) -> None: + device = "unknown" + for param in model.parameters(): + device = str(param.device) + break + model_dtype = _module_param_dtype(model) + + print(f">>> inference precision: model={model_dtype}, device={device}") + for attr in [ + "model", "first_stage_model", "cond_stage_model", "embedder", + "image_proj_model" + ]: + if hasattr(model, attr): + submodule = getattr(model, attr) + print(f">>> {attr} param dtype: {_module_param_dtype(submodule)}") + + print( + ">>> autocast gpu dtype default: " + f"{torch.get_autocast_gpu_dtype()} " + f"(enabled={torch.is_autocast_enabled()})") + + +def is_inferenced(save_dir: str, filename: str) -> bool: """Check if a given filename has already been processed and saved. Args: @@ -735,7 +737,7 @@ def prepare_init_input(start_idx: int, return data, ori_state_dim, ori_action_dim -def get_latent_z(model, videos: Tensor) -> Tensor: +def get_latent_z(model, videos: Tensor) -> Tensor: """ Extracts latent features from a video batch using the model's first-stage encoder. @@ -747,20 +749,20 @@ def get_latent_z(model, videos: Tensor) -> Tensor: Tensor: Latent video tensor of shape [B, C, T, H, W]. """ profiler = get_profiler() - with profiler.profile_section("get_latent_z/encode"): - b, c, t, h, w = videos.shape - x = rearrange(videos, 'b c t h w -> (b t) c h w') - vae_ctx = nullcontext() - if getattr(model, "vae_bf16", False) and model.device.type == "cuda": - vae_ctx = torch.autocast("cuda", dtype=torch.bfloat16) - with vae_ctx: - z = model.encode_first_stage(x) - z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) + with profiler.profile_section("get_latent_z/encode"): + b, c, t, h, w = videos.shape + x = rearrange(videos, 'b c t h w -> (b t) c h w') + vae_ctx = nullcontext() + if getattr(model, "vae_bf16", False) and model.device.type == "cuda": + vae_ctx = torch.autocast("cuda", dtype=torch.bfloat16) + with vae_ctx: + z = model.encode_first_stage(x) + z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) return z -def preprocess_observation( - model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]: +def preprocess_observation( + model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]: """Convert environment observation to LeRobot format observation. Args: observation: Dictionary of observation batches from a Gym vector environment. @@ -801,37 +803,37 @@ def preprocess_observation( return_observations['observation.state'].to(model.device) })['observation.state'] - return return_observations - - -def _move_to_device(batch: Mapping[str, Any], - device: torch.device) -> dict[str, Any]: - moved = {} - for key, value in batch.items(): - if isinstance(value, torch.Tensor) and value.device != device: - moved[key] = value.to(device, non_blocking=True) - else: - moved[key] = value - return moved + return return_observations -def image_guided_synthesis_sim_mode( - model: torch.nn.Module, - prompts: list[str], - observation: dict, - noise_shape: tuple[int, int, int, int, int], - action_cond_step: int = 16, - n_samples: int = 1, - ddim_steps: int = 50, - ddim_eta: float = 1.0, - unconditional_guidance_scale: float = 1.0, - fs: int | None = None, - text_input: bool = True, - timestep_spacing: str = 'uniform', - guidance_rescale: float = 0.0, - sim_mode: bool = True, - diffusion_autocast_dtype: Optional[torch.dtype] = None, - **kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def _move_to_device(batch: Mapping[str, Any], + device: torch.device) -> dict[str, Any]: + moved = {} + for key, value in batch.items(): + if isinstance(value, torch.Tensor) and value.device != device: + moved[key] = value.to(device, non_blocking=True) + else: + moved[key] = value + return moved + + +def image_guided_synthesis_sim_mode( + model: torch.nn.Module, + prompts: list[str], + observation: dict, + noise_shape: tuple[int, int, int, int, int], + action_cond_step: int = 16, + n_samples: int = 1, + ddim_steps: int = 50, + ddim_eta: float = 1.0, + unconditional_guidance_scale: float = 1.0, + fs: int | None = None, + text_input: bool = True, + timestep_spacing: str = 'uniform', + guidance_rescale: float = 0.0, + sim_mode: bool = True, + diffusion_autocast_dtype: Optional[torch.dtype] = None, + **kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text). @@ -852,10 +854,10 @@ def image_guided_synthesis_sim_mode( fs (int | None): Frame index to condition on, broadcasted across the batch if specified. Default is None. text_input (bool): Whether to use text prompt as conditioning. If False, uses empty strings. Default is True. timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace". - guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance. - sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model. - diffusion_autocast_dtype (Optional[torch.dtype]): Autocast dtype for diffusion sampling (e.g., torch.bfloat16). - **kwargs: Additional arguments passed to the DDIM sampler. + guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance. + sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model. + diffusion_autocast_dtype (Optional[torch.dtype]): Autocast dtype for diffusion sampling (e.g., torch.bfloat16). + **kwargs: Additional arguments passed to the DDIM sampler. Returns: batch_variants (torch.Tensor): Predicted pixel-space video frames [B, C, T, H, W]. @@ -865,77 +867,77 @@ def image_guided_synthesis_sim_mode( profiler = get_profiler() b, _, t, _, _ = noise_shape - ddim_sampler = getattr(model, "_ddim_sampler", None) - if ddim_sampler is None: - ddim_sampler = DDIMSampler(model) - model._ddim_sampler = ddim_sampler - batch_size = noise_shape[0] + ddim_sampler = getattr(model, "_ddim_sampler", None) + if ddim_sampler is None: + ddim_sampler = DDIMSampler(model) + model._ddim_sampler = ddim_sampler + batch_size = noise_shape[0] fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device) with profiler.profile_section("synthesis/conditioning_prep"): - img = observation['observation.images.top'].permute(0, 2, 1, 3, 4) - cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:] - if getattr(model, "encoder_bf16", False) and model.device.type == "cuda": - if getattr(model, "encoder_mode", "autocast") == "autocast": - preprocess_ctx = torch.autocast("cuda", enabled=False) - with preprocess_ctx: - cond_img_fp32 = cond_img.float() - if hasattr(model.embedder, "preprocess"): - preprocessed = model.embedder.preprocess(cond_img_fp32) - else: - preprocessed = cond_img_fp32 - - if hasattr(model.embedder, - "encode_with_vision_transformer") and hasattr( - model.embedder, "preprocess"): - original_preprocess = model.embedder.preprocess - try: - model.embedder.preprocess = lambda x: x - with torch.autocast("cuda", dtype=torch.bfloat16): - cond_img_emb = model.embedder.encode_with_vision_transformer( - preprocessed) - finally: - model.embedder.preprocess = original_preprocess - else: - with torch.autocast("cuda", dtype=torch.bfloat16): - cond_img_emb = model.embedder(preprocessed) - else: - with torch.autocast("cuda", dtype=torch.bfloat16): - cond_img_emb = model.embedder(cond_img) - else: - cond_img_emb = model.embedder(cond_img) - - if model.model.conditioning_key == 'hybrid': - z = get_latent_z(model, img.permute(0, 2, 1, 3, 4)) - img_cat_cond = z[:, :, -1:, :, :] - img_cat_cond = repeat(img_cat_cond, + img = observation['observation.images.top'].permute(0, 2, 1, 3, 4) + cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:] + if getattr(model, "encoder_bf16", False) and model.device.type == "cuda": + if getattr(model, "encoder_mode", "autocast") == "autocast": + preprocess_ctx = torch.autocast("cuda", enabled=False) + with preprocess_ctx: + cond_img_fp32 = cond_img.float() + if hasattr(model.embedder, "preprocess"): + preprocessed = model.embedder.preprocess(cond_img_fp32) + else: + preprocessed = cond_img_fp32 + + if hasattr(model.embedder, + "encode_with_vision_transformer") and hasattr( + model.embedder, "preprocess"): + original_preprocess = model.embedder.preprocess + try: + model.embedder.preprocess = lambda x: x + with torch.autocast("cuda", dtype=torch.bfloat16): + cond_img_emb = model.embedder.encode_with_vision_transformer( + preprocessed) + finally: + model.embedder.preprocess = original_preprocess + else: + with torch.autocast("cuda", dtype=torch.bfloat16): + cond_img_emb = model.embedder(preprocessed) + else: + with torch.autocast("cuda", dtype=torch.bfloat16): + cond_img_emb = model.embedder(cond_img) + else: + cond_img_emb = model.embedder(cond_img) + + if model.model.conditioning_key == 'hybrid': + z = get_latent_z(model, img.permute(0, 2, 1, 3, 4)) + img_cat_cond = z[:, :, -1:, :, :] + img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=noise_shape[2]) cond = {"c_concat": [img_cat_cond]} - - if not text_input: - prompts = [""] * batch_size - encoder_ctx = nullcontext() - if getattr(model, "encoder_bf16", False) and model.device.type == "cuda": - encoder_ctx = torch.autocast("cuda", dtype=torch.bfloat16) - with encoder_ctx: - cond_ins_emb = model.get_learned_conditioning(prompts) - target_dtype = cond_ins_emb.dtype - - cond_img_emb = model._projector_forward(model.image_proj_model, - cond_img_emb, target_dtype) - - cond_state_emb = model._projector_forward( - model.state_projector, observation['observation.state'], - target_dtype) - cond_state_emb = cond_state_emb + model.agent_state_pos_emb.to( - dtype=target_dtype) - - cond_action_emb = model._projector_forward( - model.action_projector, observation['action'], target_dtype) - cond_action_emb = cond_action_emb + model.agent_action_pos_emb.to( - dtype=target_dtype) + + if not text_input: + prompts = [""] * batch_size + encoder_ctx = nullcontext() + if getattr(model, "encoder_bf16", False) and model.device.type == "cuda": + encoder_ctx = torch.autocast("cuda", dtype=torch.bfloat16) + with encoder_ctx: + cond_ins_emb = model.get_learned_conditioning(prompts) + target_dtype = cond_ins_emb.dtype + + cond_img_emb = model._projector_forward(model.image_proj_model, + cond_img_emb, target_dtype) + + cond_state_emb = model._projector_forward( + model.state_projector, observation['observation.state'], + target_dtype) + cond_state_emb = cond_state_emb + model.agent_state_pos_emb.to( + dtype=target_dtype) + + cond_action_emb = model._projector_forward( + model.action_projector, observation['action'], target_dtype) + cond_action_emb = cond_action_emb + model.agent_action_pos_emb.to( + dtype=target_dtype) if not sim_mode: cond_action_emb = torch.zeros_like(cond_action_emb) @@ -956,51 +958,51 @@ def image_guided_synthesis_sim_mode( uc = None kwargs.update({"unconditional_conditioning_img_nonetext": None}) cond_mask = None - cond_z0 = None - - if ddim_sampler is not None: - with profiler.profile_section("synthesis/ddim_sampling"): - autocast_ctx = nullcontext() - if diffusion_autocast_dtype is not None and model.device.type == "cuda": - autocast_ctx = torch.autocast("cuda", dtype=diffusion_autocast_dtype) - with autocast_ctx: - samples, actions, states, intermedia = ddim_sampler.sample( - S=ddim_steps, - conditioning=cond, - batch_size=batch_size, - shape=noise_shape[1:], - verbose=False, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=uc, - eta=ddim_eta, - cfg_img=None, - mask=cond_mask, - x0=cond_z0, - fs=fs, - timestep_spacing=timestep_spacing, - guidance_rescale=guidance_rescale, - **kwargs) - - # Reconstruct from latent to pixel space - with profiler.profile_section("synthesis/decode_first_stage"): - if getattr(model, "vae_bf16", False): - if samples.dtype != torch.bfloat16: - samples = samples.to(dtype=torch.bfloat16) - vae_ctx = nullcontext() - if model.device.type == "cuda": - vae_ctx = torch.autocast("cuda", dtype=torch.bfloat16) - with vae_ctx: - batch_images = model.decode_first_stage(samples) - else: - if samples.dtype != torch.float32: - samples = samples.float() - batch_images = model.decode_first_stage(samples) - batch_variants = batch_images + cond_z0 = None + + if ddim_sampler is not None: + with profiler.profile_section("synthesis/ddim_sampling"): + autocast_ctx = nullcontext() + if diffusion_autocast_dtype is not None and model.device.type == "cuda": + autocast_ctx = torch.autocast("cuda", dtype=diffusion_autocast_dtype) + with autocast_ctx: + samples, actions, states, intermedia = ddim_sampler.sample( + S=ddim_steps, + conditioning=cond, + batch_size=batch_size, + shape=noise_shape[1:], + verbose=False, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + eta=ddim_eta, + cfg_img=None, + mask=cond_mask, + x0=cond_z0, + fs=fs, + timestep_spacing=timestep_spacing, + guidance_rescale=guidance_rescale, + **kwargs) + + # Reconstruct from latent to pixel space + with profiler.profile_section("synthesis/decode_first_stage"): + if getattr(model, "vae_bf16", False): + if samples.dtype != torch.bfloat16: + samples = samples.to(dtype=torch.bfloat16) + vae_ctx = nullcontext() + if model.device.type == "cuda": + vae_ctx = torch.autocast("cuda", dtype=torch.bfloat16) + with vae_ctx: + batch_images = model.decode_first_stage(samples) + else: + if samples.dtype != torch.float32: + samples = samples.float() + batch_images = model.decode_first_stage(samples) + batch_variants = batch_images return batch_variants, actions, states -def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: +def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: """ Run inference pipeline on prompts and image inputs. @@ -1012,7 +1014,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: Returns: None """ - profiler = get_profiler() + profiler = get_profiler() # Create inference and tensorboard dirs os.makedirs(args.savedir + '/inference', exist_ok=True) @@ -1035,8 +1037,10 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!" with profiler.profile_section("model_loading/checkpoint"): - model = load_model_checkpoint(model, args.ckpt_path) + model = load_model_checkpoint(model, args.ckpt_path, + device=f"cuda:{gpu_no}") model.eval() + model = model.cuda(gpu_no) # move residual buffers not in state_dict print(f'>>> Load pre-trained model ...') # Build unnomalizer @@ -1045,110 +1049,133 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: data = instantiate_from_config(config.data) data.setup() print(">>> Dataset is successfully loaded ...") + device = get_device_from_parameters(model) - with profiler.profile_section("model_to_cuda"): - model = model.cuda(gpu_no) - device = get_device_from_parameters(model) - - diffusion_autocast_dtype = None - if args.diffusion_dtype == "bf16": - maybe_cast_module( - model.model, - torch.bfloat16, - "diffusion backbone", - profiler=profiler, - profile_name="model_loading/diffusion_bf16", - ) - diffusion_autocast_dtype = torch.bfloat16 - print(">>> diffusion backbone set to bfloat16") - - if hasattr(model, "first_stage_model") and model.first_stage_model is not None: - vae_weight_dtype = torch.bfloat16 if args.vae_dtype == "bf16" else torch.float32 - maybe_cast_module( - model.first_stage_model, - vae_weight_dtype, - "VAE", - profiler=profiler, - profile_name="model_loading/vae_cast", - ) - model.vae_bf16 = args.vae_dtype == "bf16" - print(f">>> VAE dtype set to {args.vae_dtype}") - - encoder_mode = args.encoder_mode - encoder_bf16 = encoder_mode in ("autocast", "bf16_full") - encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32 - if hasattr(model, "cond_stage_model") and model.cond_stage_model is not None: - maybe_cast_module( - model.cond_stage_model, - encoder_weight_dtype, - "cond_stage_model", - profiler=profiler, - profile_name="model_loading/encoder_cond_cast", - ) - if hasattr(model, "embedder") and model.embedder is not None: - maybe_cast_module( - model.embedder, - encoder_weight_dtype, - "embedder", - profiler=profiler, - profile_name="model_loading/encoder_embedder_cast", - ) - model.encoder_bf16 = encoder_bf16 - model.encoder_mode = encoder_mode - print( - f">>> encoder mode set to {encoder_mode} (weights={encoder_weight_dtype})" - ) - - projector_mode = args.projector_mode - projector_bf16 = projector_mode in ("autocast", "bf16_full") - projector_weight_dtype = torch.bfloat16 if projector_mode == "bf16_full" else torch.float32 - if hasattr(model, "image_proj_model") and model.image_proj_model is not None: - maybe_cast_module( - model.image_proj_model, - projector_weight_dtype, - "image_proj_model", - profiler=profiler, - profile_name="model_loading/projector_image_cast", - ) - if hasattr(model, "state_projector") and model.state_projector is not None: - maybe_cast_module( - model.state_projector, - projector_weight_dtype, - "state_projector", - profiler=profiler, - profile_name="model_loading/projector_state_cast", - ) - if hasattr(model, "action_projector") and model.action_projector is not None: - maybe_cast_module( - model.action_projector, - projector_weight_dtype, - "action_projector", - profiler=profiler, - profile_name="model_loading/projector_action_cast", - ) - if hasattr(model, "projector_bf16"): - model.projector_bf16 = projector_bf16 - model.projector_mode = projector_mode - print( - f">>> projector mode set to {projector_mode} (weights={projector_weight_dtype})" - ) - - log_inference_precision(model) - - if args.export_casted_ckpt: - metadata = { - "diffusion_dtype": args.diffusion_dtype, - "vae_dtype": args.vae_dtype, - "encoder_mode": args.encoder_mode, - "projector_mode": args.projector_mode, - "perframe_ae": args.perframe_ae, - } - save_casted_checkpoint(model, args.export_casted_ckpt, metadata) - if args.export_only: - print(">>> export_only set; skipping inference.") - return - - profiler.record_memory("after_model_load") + diffusion_autocast_dtype = None + if args.diffusion_dtype == "bf16": + maybe_cast_module( + model.model, + torch.bfloat16, + "diffusion backbone", + profiler=profiler, + profile_name="model_loading/diffusion_bf16", + ) + diffusion_autocast_dtype = torch.bfloat16 + print(">>> diffusion backbone set to bfloat16") + + if hasattr(model, "first_stage_model") and model.first_stage_model is not None: + vae_weight_dtype = torch.bfloat16 if args.vae_dtype == "bf16" else torch.float32 + maybe_cast_module( + model.first_stage_model, + vae_weight_dtype, + "VAE", + profiler=profiler, + profile_name="model_loading/vae_cast", + ) + model.vae_bf16 = args.vae_dtype == "bf16" + print(f">>> VAE dtype set to {args.vae_dtype}") + + # --- VAE performance optimizations --- + if hasattr(model, "first_stage_model") and model.first_stage_model is not None: + vae = model.first_stage_model + + # Channels-last memory format: cuDNN uses faster NHWC kernels + if args.vae_channels_last: + vae = vae.to(memory_format=torch.channels_last) + vae._channels_last = True + model.first_stage_model = vae + print(">>> VAE converted to channels_last (NHWC) memory format") + + # torch.compile: fuses GroupNorm+SiLU, conv chains, etc. + if args.vae_compile: + vae.decoder = torch.compile(vae.decoder, mode="reduce-overhead") + vae.encoder = torch.compile(vae.encoder, mode="reduce-overhead") + print(">>> VAE encoder/decoder compiled with torch.compile (reduce-overhead)") + + # Batch decode size + vae_decode_bs = args.vae_decode_bs if args.vae_decode_bs > 0 else 9999 + model.vae_decode_bs = vae_decode_bs + model.vae_encode_bs = vae_decode_bs + if args.vae_decode_bs > 0: + print(f">>> VAE encode/decode batch size set to {args.vae_decode_bs}") + else: + print(">>> VAE encode/decode batch size: all frames at once") + + encoder_mode = args.encoder_mode + encoder_bf16 = encoder_mode in ("autocast", "bf16_full") + encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32 + if hasattr(model, "cond_stage_model") and model.cond_stage_model is not None: + maybe_cast_module( + model.cond_stage_model, + encoder_weight_dtype, + "cond_stage_model", + profiler=profiler, + profile_name="model_loading/encoder_cond_cast", + ) + if hasattr(model, "embedder") and model.embedder is not None: + maybe_cast_module( + model.embedder, + encoder_weight_dtype, + "embedder", + profiler=profiler, + profile_name="model_loading/encoder_embedder_cast", + ) + model.encoder_bf16 = encoder_bf16 + model.encoder_mode = encoder_mode + print( + f">>> encoder mode set to {encoder_mode} (weights={encoder_weight_dtype})" + ) + + projector_mode = args.projector_mode + projector_bf16 = projector_mode in ("autocast", "bf16_full") + projector_weight_dtype = torch.bfloat16 if projector_mode == "bf16_full" else torch.float32 + if hasattr(model, "image_proj_model") and model.image_proj_model is not None: + maybe_cast_module( + model.image_proj_model, + projector_weight_dtype, + "image_proj_model", + profiler=profiler, + profile_name="model_loading/projector_image_cast", + ) + if hasattr(model, "state_projector") and model.state_projector is not None: + maybe_cast_module( + model.state_projector, + projector_weight_dtype, + "state_projector", + profiler=profiler, + profile_name="model_loading/projector_state_cast", + ) + if hasattr(model, "action_projector") and model.action_projector is not None: + maybe_cast_module( + model.action_projector, + projector_weight_dtype, + "action_projector", + profiler=profiler, + profile_name="model_loading/projector_action_cast", + ) + if hasattr(model, "projector_bf16"): + model.projector_bf16 = projector_bf16 + model.projector_mode = projector_mode + print( + f">>> projector mode set to {projector_mode} (weights={projector_weight_dtype})" + ) + + log_inference_precision(model) + + if args.export_casted_ckpt: + metadata = { + "diffusion_dtype": args.diffusion_dtype, + "vae_dtype": args.vae_dtype, + "encoder_mode": args.encoder_mode, + "projector_mode": args.projector_mode, + "perframe_ae": args.perframe_ae, + } + save_casted_checkpoint(model, args.export_casted_ckpt, metadata) + if args.export_only: + print(">>> export_only set; skipping inference.") + return + + profiler.record_memory("after_model_load") # Run over data assert (args.height % 16 == 0) and ( @@ -1229,7 +1256,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: 'action': torch.zeros_like(batch['action'][-1]).unsqueeze(0) } - observation = _move_to_device(observation, device) + observation = _move_to_device(observation, device) # Update observation queues cond_obs_queues = populate_queues(cond_obs_queues, observation) @@ -1242,9 +1269,9 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: # Multi-round interaction with the world-model with pytorch_prof_ctx: - for itr in tqdm(range(args.n_iter)): - log_every = max(1, args.step_log_every) - log_step = (itr % log_every == 0) + for itr in tqdm(range(args.n_iter)): + log_every = max(1, args.step_log_every) + log_step = (itr % log_every == 0) profiler.current_iteration = itr profiler.record_memory(f"iter_{itr}_start") @@ -1262,27 +1289,27 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: 'action': torch.stack(list(cond_obs_queues['action']), dim=1), } - observation = _move_to_device(observation, device) + observation = _move_to_device(observation, device) # Use world-model in policy to generate action - if log_step: - print(f'>>> Step {itr}: generating actions ...') + if log_step: + print(f'>>> Step {itr}: generating actions ...') with profiler.profile_section("action_generation"): - pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode( - model, - sample['instruction'], - observation, - noise_shape, + pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode( + model, + sample['instruction'], + observation, + noise_shape, action_cond_step=args.exe_steps, ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta, - unconditional_guidance_scale=args. - unconditional_guidance_scale, - fs=model_input_fs, - timestep_spacing=args.timestep_spacing, - guidance_rescale=args.guidance_rescale, - sim_mode=False, - diffusion_autocast_dtype=diffusion_autocast_dtype) + unconditional_guidance_scale=args. + unconditional_guidance_scale, + fs=model_input_fs, + timestep_spacing=args.timestep_spacing, + guidance_rescale=args.guidance_rescale, + sim_mode=False, + diffusion_autocast_dtype=diffusion_autocast_dtype) # Update future actions in the observation queues with profiler.profile_section("update_action_queues"): @@ -1305,27 +1332,27 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: 'action': torch.stack(list(cond_obs_queues['action']), dim=1), } - observation = _move_to_device(observation, device) + observation = _move_to_device(observation, device) # Interaction with the world-model - if log_step: - print(f'>>> Step {itr}: interacting with world model ...') + if log_step: + print(f'>>> Step {itr}: interacting with world model ...') with profiler.profile_section("world_model_interaction"): - pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode( - model, - "", - observation, - noise_shape, + pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode( + model, + "", + observation, + noise_shape, action_cond_step=args.exe_steps, ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta, unconditional_guidance_scale=args. - unconditional_guidance_scale, - fs=model_input_fs, - text_input=False, - timestep_spacing=args.timestep_spacing, - guidance_rescale=args.guidance_rescale, - diffusion_autocast_dtype=diffusion_autocast_dtype) + unconditional_guidance_scale, + fs=model_input_fs, + text_input=False, + timestep_spacing=args.timestep_spacing, + guidance_rescale=args.guidance_rescale, + diffusion_autocast_dtype=diffusion_autocast_dtype) with profiler.profile_section("update_state_queues"): for step_idx in range(args.exe_steps): @@ -1470,66 +1497,84 @@ def get_parser(): help= "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)." ) - parser.add_argument( - "--perframe_ae", - action='store_true', - default=False, - help= - "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024." - ) - parser.add_argument( - "--diffusion_dtype", - type=str, - choices=["fp32", "bf16"], - default="fp32", - help="Dtype for diffusion backbone weights and sampling autocast." - ) - parser.add_argument( - "--projector_mode", - type=str, - choices=["fp32", "autocast", "bf16_full"], - default="fp32", - help= - "Projector precision mode for image/state/action projectors: " - "fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, " - "bf16_full=bf16 weights + bf16 forward." - ) - parser.add_argument( - "--encoder_mode", - type=str, - choices=["fp32", "autocast", "bf16_full"], - default="fp32", - help= - "Encoder precision mode for cond_stage_model/embedder: " - "fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, " - "bf16_full=bf16 weights + bf16 forward." - ) - parser.add_argument( - "--vae_dtype", - type=str, - choices=["fp32", "bf16"], - default="fp32", - help="Dtype for VAE/first_stage_model weights and forward autocast." - ) - parser.add_argument( - "--export_casted_ckpt", - type=str, - default=None, - help= - "Save a checkpoint after applying precision settings (mixed dtypes preserved)." - ) - parser.add_argument( - "--export_only", - action='store_true', - default=False, - help="Exit after exporting the casted checkpoint." - ) - parser.add_argument( - "--step_log_every", - type=int, - default=1, - help="Print per-iteration step logs every N iterations." - ) + parser.add_argument( + "--perframe_ae", + action='store_true', + default=False, + help= + "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024." + ) + parser.add_argument( + "--diffusion_dtype", + type=str, + choices=["fp32", "bf16"], + default="fp32", + help="Dtype for diffusion backbone weights and sampling autocast." + ) + parser.add_argument( + "--projector_mode", + type=str, + choices=["fp32", "autocast", "bf16_full"], + default="fp32", + help= + "Projector precision mode for image/state/action projectors: " + "fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, " + "bf16_full=bf16 weights + bf16 forward." + ) + parser.add_argument( + "--encoder_mode", + type=str, + choices=["fp32", "autocast", "bf16_full"], + default="fp32", + help= + "Encoder precision mode for cond_stage_model/embedder: " + "fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, " + "bf16_full=bf16 weights + bf16 forward." + ) + parser.add_argument( + "--vae_dtype", + type=str, + choices=["fp32", "bf16"], + default="fp32", + help="Dtype for VAE/first_stage_model weights and forward autocast." + ) + parser.add_argument( + "--vae_compile", + action='store_true', + default=False, + help="Apply torch.compile to VAE decoder for kernel fusion." + ) + parser.add_argument( + "--vae_channels_last", + action='store_true', + default=False, + help="Convert VAE to channels-last (NHWC) memory format for faster cuDNN convolutions." + ) + parser.add_argument( + "--vae_decode_bs", + type=int, + default=0, + help="VAE decode batch size (0=all frames at once). Reduces kernel launch overhead." + ) + parser.add_argument( + "--export_casted_ckpt", + type=str, + default=None, + help= + "Save a checkpoint after applying precision settings (mixed dtypes preserved)." + ) + parser.add_argument( + "--export_only", + action='store_true', + default=False, + help="Exit after exporting the casted checkpoint." + ) + parser.add_argument( + "--step_log_every", + type=int, + default=1, + help="Print per-iteration step logs every N iterations." + ) parser.add_argument( "--n_action_steps", type=int, @@ -1569,20 +1614,20 @@ def get_parser(): default=None, help="Directory to save profiling results. Defaults to {savedir}/profile_output." ) - parser.add_argument( - "--profile_iterations", - type=int, - default=3, - help="Number of iterations to run PyTorch profiler's active phase for operator-level analysis." - ) - parser.add_argument( - "--profile_detail", - type=str, - choices=["light", "full"], - default="light", - help="Profiling detail level. Use 'full' for shapes/stacks/memory/flops." - ) - return parser + parser.add_argument( + "--profile_iterations", + type=int, + default=3, + help="Number of iterations to run PyTorch profiler's active phase for operator-level analysis." + ) + parser.add_argument( + "--profile_detail", + type=str, + choices=["light", "full"], + default="light", + help="Profiling detail level. Use 'full' for shapes/stacks/memory/flops." + ) + return parser if __name__ == '__main__': @@ -1597,11 +1642,11 @@ if __name__ == '__main__': profile_output_dir = args.profile_output_dir if profile_output_dir is None: profile_output_dir = os.path.join(args.savedir, "profile_output") - init_profiler( - enabled=args.profile, - output_dir=profile_output_dir, - profile_detail=args.profile_detail, - ) + init_profiler( + enabled=args.profile, + output_dir=profile_output_dir, + profile_detail=args.profile_detail, + ) rank, gpu_num = 0, 1 run_inference(args, gpu_num, rank) diff --git a/src/unifolm_wma/models/autoencoder.py b/src/unifolm_wma/models/autoencoder.py index 2a3b521..1a79699 100644 --- a/src/unifolm_wma/models/autoencoder.py +++ b/src/unifolm_wma/models/autoencoder.py @@ -99,13 +99,16 @@ class AutoencoderKL(pl.LightningModule): print(f"Restored from {path}") def encode(self, x, **kwargs): - + if getattr(self, '_channels_last', False): + x = x.to(memory_format=torch.channels_last) h = self.encoder(x) moments = self.quant_conv(h) posterior = DiagonalGaussianDistribution(moments) return posterior def decode(self, z, **kwargs): + if getattr(self, '_channels_last', False): + z = z.to(memory_format=torch.channels_last) z = self.post_quant_conv(z) dec = self.decoder(z) return dec diff --git a/src/unifolm_wma/models/ddpms.py b/src/unifolm_wma/models/ddpms.py index 2f6c3ca..a558d62 100644 --- a/src/unifolm_wma/models/ddpms.py +++ b/src/unifolm_wma/models/ddpms.py @@ -1073,15 +1073,19 @@ class LatentDiffusion(DDPM): if not self.perframe_ae: encoder_posterior = self.first_stage_model.encode(x) results = self.get_first_stage_encoding(encoder_posterior).detach() - else: ## Consume less GPU memory but slower - results = [] - for index in range(x.shape[0]): - frame_batch = self.first_stage_model.encode(x[index:index + - 1, :, :, :]) - frame_result = self.get_first_stage_encoding( - frame_batch).detach() - results.append(frame_result) - results = torch.cat(results, dim=0) + else: ## Batch encode with configurable batch size + bs = getattr(self, 'vae_encode_bs', 1) + if bs >= x.shape[0]: + encoder_posterior = self.first_stage_model.encode(x) + results = self.get_first_stage_encoding(encoder_posterior).detach() + else: + results = [] + for i in range(0, x.shape[0], bs): + frame_batch = self.first_stage_model.encode(x[i:i + bs]) + frame_result = self.get_first_stage_encoding( + frame_batch).detach() + results.append(frame_result) + results = torch.cat(results, dim=0) if reshape_back: results = rearrange(results, '(b t) c h w -> b c t h w', b=b, t=t) @@ -1105,16 +1109,21 @@ class LatentDiffusion(DDPM): else: reshape_back = False + z = 1. / self.scale_factor * z + if not self.perframe_ae: - z = 1. / self.scale_factor * z results = self.first_stage_model.decode(z, **kwargs) else: - results = [] - for index in range(z.shape[0]): - frame_z = 1. / self.scale_factor * z[index:index + 1, :, :, :] - frame_result = self.first_stage_model.decode(frame_z, **kwargs) - results.append(frame_result) - results = torch.cat(results, dim=0) + bs = getattr(self, 'vae_decode_bs', 1) + if bs >= z.shape[0]: + # all frames in one batch + results = self.first_stage_model.decode(z, **kwargs) + else: + results = [] + for i in range(0, z.shape[0], bs): + results.append( + self.first_stage_model.decode(z[i:i + bs], **kwargs)) + results = torch.cat(results, dim=0) if reshape_back: results = rearrange(results, '(b t) c h w -> b c t h w', b=b, t=t) diff --git a/src/unifolm_wma/modules/networks/ae_modules.py b/src/unifolm_wma/modules/networks/ae_modules.py index 2ec124d..026acdf 100644 --- a/src/unifolm_wma/modules/networks/ae_modules.py +++ b/src/unifolm_wma/modules/networks/ae_modules.py @@ -10,8 +10,8 @@ from unifolm_wma.utils.utils import instantiate_from_config def nonlinearity(x): - # swish - return x * torch.sigmoid(x) + # swish / SiLU — single fused CUDA kernel instead of x * sigmoid(x) + return torch.nn.functional.silu(x) def Normalize(in_channels, num_groups=32):