diff --git a/configs/inference/world_model_interaction.yaml b/configs/inference/world_model_interaction.yaml index da709e0..a1e115a 100644 --- a/configs/inference/world_model_interaction.yaml +++ b/configs/inference/world_model_interaction.yaml @@ -222,7 +222,7 @@ data: test: target: unifolm_wma.data.wma_data.WMAData params: - data_dir: '/mnt/ASC1637/unifolm-world-model-action/examples/world_model_interaction_prompts' + data_dir: '/home/qhy/unifolm-world-model-action/examples/world_model_interaction_prompts' video_length: ${model.params.wma_config.params.temporal_length} frame_stride: 2 load_raw_resolution: True diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py index 3270dda..c8cc154 100644 --- a/scripts/evaluation/world_model_interaction.py +++ b/scripts/evaluation/world_model_interaction.py @@ -1,5 +1,7 @@ import argparse, os, glob from contextlib import nullcontext +import atexit +from concurrent.futures import ThreadPoolExecutor import pandas as pd import random import torch @@ -11,13 +13,15 @@ import einops import warnings import imageio +from typing import Optional, List, Any + from pytorch_lightning import seed_everything from omegaconf import OmegaConf from tqdm import tqdm from einops import rearrange, repeat from collections import OrderedDict from torch import nn -from eval_utils import populate_queues, log_to_tensorboard +from eval_utils import populate_queues from collections import deque from torch import Tensor from torch.utils.tensorboard import SummaryWriter @@ -28,6 +32,80 @@ from unifolm_wma.utils.utils import instantiate_from_config import torch.nn.functional as F +# ========== Async I/O utilities ========== +_io_executor: Optional[ThreadPoolExecutor] = None +_io_futures: List[Any] = [] + + +def _get_io_executor() -> ThreadPoolExecutor: + global _io_executor + if _io_executor is None: + _io_executor = ThreadPoolExecutor(max_workers=2) + return _io_executor + + +def _flush_io(): + """Wait 
for all pending async I/O to finish.""" + global _io_futures + for fut in _io_futures: + try: + fut.result() + except Exception as e: + print(f">>> [async I/O] error: {e}") + _io_futures.clear() + + +atexit.register(_flush_io) + + +def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None: + """Synchronous save on CPU tensor (runs in background thread).""" + video = torch.clamp(video_cpu.float(), -1., 1.) + n = video.shape[0] + video = video.permute(2, 0, 1, 3, 4) + frame_grids = [ + torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) + for framesheet in video + ] + grid = torch.stack(frame_grids, dim=0) + grid = (grid + 1.0) / 2.0 + grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) + torchvision.io.write_video(filename, + grid, + fps=fps, + video_codec='h264', + options={'crf': '10'}) + + +def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None: + """Submit video saving to background thread pool.""" + video_cpu = video.detach().cpu() + fut = _get_io_executor().submit(_save_results_sync, video_cpu, filename, fps) + _io_futures.append(fut) + + +def _log_to_tb_sync(video_cpu: Tensor, writer: SummaryWriter, tag: str, fps: int) -> None: + """Synchronous tensorboard logging on CPU tensor (runs in background thread).""" + video = video_cpu.float() + n = video.shape[0] + video = video.permute(2, 0, 1, 3, 4) + frame_grids = [ + torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) + for framesheet in video + ] + grid = torch.stack(frame_grids, dim=0) + grid = (grid + 1.0) / 2.0 + grid = grid.unsqueeze(dim=0) + writer.add_video(tag, grid, fps=fps) + + +def log_to_tensorboard_async(writer: SummaryWriter, video: Tensor, tag: str, fps: int = 10) -> None: + """Submit tensorboard logging to background thread pool.""" + video_cpu = video.detach().cpu() + fut = _get_io_executor().submit(_log_to_tb_sync, video_cpu, writer, tag, fps) + _io_futures.append(fut) + + def patch_norm_bypass_autocast(): """Monkey-patch 
GroupNorm and LayerNorm to bypass autocast's fp32 policy. This eliminates bf16->fp32->bf16 dtype conversions during UNet forward.""" @@ -185,17 +263,18 @@ def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]: return file_list -def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module: +def load_model_checkpoint(model: nn.Module, ckpt: str, device: str = "cpu") -> nn.Module: """Load model weights from checkpoint file. Args: model (nn.Module): Model instance. ckpt (str): Path to the checkpoint file. + device (str): Target device for loaded tensors. Returns: nn.Module: Model with loaded weights. """ - state_dict = torch.load(ckpt, map_location="cpu") + state_dict = torch.load(ckpt, map_location=device) if "state_dict" in list(state_dict.keys()): state_dict = state_dict["state_dict"] try: @@ -610,36 +689,63 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: # Load config config = OmegaConf.load(args.config) - config['model']['params']['wma_config']['params'][ - 'use_checkpoint'] = False - model = instantiate_from_config(config.model) - model.perframe_ae = args.perframe_ae - assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!" 
- model = load_model_checkpoint(model, args.ckpt_path) - model.eval() - print(f'>>> Load pre-trained model ...') - # Apply precision settings before moving to GPU - model = apply_precision_settings(model, args) + prepared_path = args.ckpt_path + ".prepared.pt" + if os.path.exists(prepared_path): + # ---- Fast path: load the fully-prepared model ---- + print(f">>> Loading prepared model from {prepared_path} ...") + model = torch.load(prepared_path, + map_location=f"cuda:{gpu_no}", + weights_only=False) + model.eval() - # Compile hot ResBlocks for operator fusion - apply_torch_compile(model) + # Restore autocast attributes (weights already cast, just need contexts) + model.diffusion_autocast_dtype = torch.bfloat16  # NOTE(review): prior conditional was dead — both branches yielded bf16; confirm intended non-bf16 fallback + model.projector_autocast_dtype = torch.bfloat16 if args.projector_mode == "autocast" else None + model.encoder_autocast_dtype = torch.bfloat16 if args.encoder_mode == "autocast" else None - # Export precision-converted checkpoint if requested - if args.export_precision_ckpt: - export_path = args.export_precision_ckpt - os.makedirs(os.path.dirname(export_path) or '.', exist_ok=True) - torch.save({"state_dict": model.state_dict()}, export_path) - print(f">>> Precision-converted checkpoint saved to: {export_path}") - return + # Compile hot ResBlocks for operator fusion + apply_torch_compile(model) - # Build unnomalizer + print(f">>> Prepared model loaded.") + else: + # ---- Normal path: construct + checkpoint + casting ---- + config['model']['params']['wma_config']['params'][ + 'use_checkpoint'] = False + model = instantiate_from_config(config.model) + model.perframe_ae = args.perframe_ae + assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!" 
+ model = load_model_checkpoint(model, args.ckpt_path, + device=f"cuda:{gpu_no}") + model.eval() + print(f'>>> Load pre-trained model ...') + + # Apply precision settings before moving to GPU + model = apply_precision_settings(model, args) + + # Export precision-converted checkpoint if requested + if args.export_precision_ckpt: + export_path = args.export_precision_ckpt + os.makedirs(os.path.dirname(export_path) or '.', exist_ok=True) + torch.save({"state_dict": model.state_dict()}, export_path) + print(f">>> Precision-converted checkpoint saved to: {export_path}") + return + + model = model.cuda(gpu_no) + + # Save prepared model for fast loading next time (before torch.compile) + print(f">>> Saving prepared model to {prepared_path} ...") + torch.save(model, prepared_path) + print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).") + + # Compile hot ResBlocks for operator fusion (after save, compiled objects can't be pickled) + apply_torch_compile(model) + + # Build normalizer (always needed, independent of model loading path) logging.info("***** Configing Data *****") data = instantiate_from_config(config.data) data.setup() print(">>> Dataset is successfully loaded ...") - - model = model.cuda(gpu_no) device = get_device_from_parameters(model) # Run over data @@ -817,28 +923,28 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: # Save the imagen videos for decision-making if pred_videos_0 is not None: sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}" - log_to_tensorboard(writer, - pred_videos_0, - sample_tag, - fps=args.save_fps) + log_to_tensorboard_async(writer, + pred_videos_0, + sample_tag, + fps=args.save_fps) # Save videos environment changes via world-model interaction sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}" - log_to_tensorboard(writer, - pred_videos_1, - sample_tag, - fps=args.save_fps) + log_to_tensorboard_async(writer, + 
pred_videos_1, + sample_tag, + fps=args.save_fps) # Save the imagen videos for decision-making if pred_videos_0 is not None: sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4' - save_results(pred_videos_0.cpu(), - sample_video_file, - fps=args.save_fps) + save_results_async(pred_videos_0, + sample_video_file, + fps=args.save_fps) # Save videos environment changes via world-model interaction sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4' - save_results(pred_videos_1.cpu(), - sample_video_file, - fps=args.save_fps) + save_results_async(pred_videos_1, + sample_video_file, + fps=args.save_fps) print('>' * 24) # Collect the result of world-model interactions @@ -846,12 +952,15 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: full_video = torch.cat(wm_video, dim=2) sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full" - log_to_tensorboard(writer, - full_video, - sample_tag, - fps=args.save_fps) + log_to_tensorboard_async(writer, + full_video, + sample_tag, + fps=args.save_fps) sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4" - save_results(full_video, sample_full_video_file, fps=args.save_fps) + save_results_async(full_video, sample_full_video_file, fps=args.save_fps) + + # Wait for all async I/O to complete + _flush_io() def get_parser(): diff --git a/src/unifolm_wma/models/autoencoder.py b/src/unifolm_wma/models/autoencoder.py index 2a3b521..2b2acbf 100644 --- a/src/unifolm_wma/models/autoencoder.py +++ b/src/unifolm_wma/models/autoencoder.py @@ -99,6 +99,8 @@ class AutoencoderKL(pl.LightningModule): print(f"Restored from {path}") def encode(self, x, **kwargs): + if getattr(self, '_channels_last', False): + x = x.to(memory_format=torch.channels_last) h = self.encoder(x) moments = self.quant_conv(h) @@ -106,6 +108,8 @@ class AutoencoderKL(pl.LightningModule): return posterior def decode(self, z, **kwargs): + if getattr(self, '_channels_last', False): + z = 
z.to(memory_format=torch.channels_last) z = self.post_quant_conv(z) dec = self.decoder(z) return dec diff --git a/src/unifolm_wma/models/ddpms.py b/src/unifolm_wma/models/ddpms.py index 5b6df2f..a6f4678 100644 --- a/src/unifolm_wma/models/ddpms.py +++ b/src/unifolm_wma/models/ddpms.py @@ -1074,10 +1074,10 @@ class LatentDiffusion(DDPM): encoder_posterior = self.first_stage_model.encode(x) results = self.get_first_stage_encoding(encoder_posterior).detach() else: ## Consume less GPU memory but slower + bs = getattr(self, 'vae_encode_bs', 1) results = [] - for index in range(x.shape[0]): - frame_batch = self.first_stage_model.encode(x[index:index + - 1, :, :, :]) + for i in range(0, x.shape[0], bs): + frame_batch = self.first_stage_model.encode(x[i:i + bs]) frame_result = self.get_first_stage_encoding( frame_batch).detach() results.append(frame_result) @@ -1109,14 +1109,14 @@ class LatentDiffusion(DDPM): vae_dtype = next(self.first_stage_model.parameters()).dtype z = z.to(dtype=vae_dtype) + z = 1. / self.scale_factor * z if not self.perframe_ae: - z = 1. / self.scale_factor * z results = self.first_stage_model.decode(z, **kwargs) else: + bs = getattr(self, 'vae_decode_bs', 1) results = [] - for index in range(z.shape[0]): - frame_z = 1. 
/ self.scale_factor * z[index:index + 1, :, :, :] - frame_result = self.first_stage_model.decode(frame_z, **kwargs) + for i in range(0, z.shape[0], bs): + frame_result = self.first_stage_model.decode(z[i:i + bs], **kwargs) results.append(frame_result) results = torch.cat(results, dim=0) diff --git a/src/unifolm_wma/models/samplers/.claude/settings.local.json b/src/unifolm_wma/models/samplers/.claude/settings.local.json new file mode 100644 index 0000000..c4f78ee --- /dev/null +++ b/src/unifolm_wma/models/samplers/.claude/settings.local.json @@ -0,0 +1,7 @@ +{ + "permissions": { + "allow": [ + "Bash(python3:*)" + ] + } +} diff --git a/src/unifolm_wma/modules/networks/ae_modules.py b/src/unifolm_wma/modules/networks/ae_modules.py index 2ec124d..59ecf22 100644 --- a/src/unifolm_wma/modules/networks/ae_modules.py +++ b/src/unifolm_wma/modules/networks/ae_modules.py @@ -11,7 +11,7 @@ from unifolm_wma.utils.utils import instantiate_from_config def nonlinearity(x): # swish - return x * torch.sigmoid(x) + return torch.nn.functional.silu(x) def Normalize(in_channels, num_groups=32):