r"""Export the video UNet backbone to ONNX, then convert it to a TensorRT engine.

Usage:
    python scripts/export_trt.py \
        --ckpt ckpts/unifolm_wma_dual.ckpt.prepared.pt \
        --config configs/inference/world_model_interaction.yaml \
        --out_dir trt_engines
"""
import argparse
import inspect
import os
import sys

import torch
import tensorrt as trt
from omegaconf import OmegaConf

# Make the repository's src/ directory importable when run as a script.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from unifolm_wma.utils.utils import instantiate_from_config  # noqa: E402
from unifolm_wma.trt_utils import export_backbone_onnx  # noqa: E402


def _torch_load_cpu(path):
    """Load a checkpoint file onto the CPU with ``torch.load``.

    ``.prepared.pt`` files hold a whole pickled model object, which newer
    PyTorch releases refuse to unpickle under their ``weights_only=True``
    default. ``weights_only=False`` is therefore requested explicitly whenever
    the installed ``torch.load`` accepts that keyword.

    SECURITY: this unpickles arbitrary Python objects. Only load trusted
    checkpoints.

    Args:
        path: Path to the checkpoint file.

    Returns:
        Whatever object was stored in the file.
    """
    kwargs = {'map_location': 'cpu'}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        kwargs['weights_only'] = False
    return torch.load(path, **kwargs)


def load_model(config_path, ckpt_path):
    """Load the model, switch it to eval mode and move it to the GPU.

    Args:
        config_path: OmegaConf YAML file whose ``model`` node is passed to
            ``instantiate_from_config``. Ignored for ``.prepared.pt`` files.
        ckpt_path: Either a ``.prepared.pt`` file containing a pickled model
            object, or a checkpoint containing a state dict (optionally nested
            under a ``'state_dict'`` key).

    Returns:
        The model after ``.eval()`` and ``.cuda()``.
    """
    if ckpt_path.endswith('.prepared.pt'):
        model = _torch_load_cpu(ckpt_path)
    else:
        config = OmegaConf.load(config_path)
        model = instantiate_from_config(config.model)
        state_dict = _torch_load_cpu(ckpt_path)
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        # strict=False tolerates key mismatches. Report them so that a wrong
        # checkpoint does not silently export untrained weights.
        result = model.load_state_dict(state_dict, strict=False)
        missing = getattr(result, 'missing_keys', [])
        unexpected = getattr(result, 'unexpected_keys', [])
        if missing or unexpected:
            print(f">>> WARNING: load_state_dict(strict=False): {len(missing)} missing, "
                  f"{len(unexpected)} unexpected keys")
    model.eval().cuda()
    return model


def _parse_args(argv=None):
    """Parse command-line options (``argv=None`` reads ``sys.argv``)."""
    parser = argparse.ArgumentParser(
        description='Export the video UNet backbone to ONNX and build a TensorRT engine.')
    parser.add_argument('--ckpt', required=True)
    parser.add_argument('--config', default='configs/inference/world_model_interaction.yaml')
    parser.add_argument('--out_dir', default='trt_engines')
    parser.add_argument('--context_len', type=int, default=95,
                        help='Preprocessed context seq length (95 for standard config)')
    # FP16 stays on by default. --no-fp16 is the only way to build an FP32
    # engine, because a store_true flag with default=True cannot be disabled.
    parser.add_argument('--fp16', dest='fp16', action='store_true', default=True,
                        help='Enable the TensorRT FP16 builder flag (default: on)')
    parser.add_argument('--no-fp16', dest='fp16', action='store_false',
                        help='Build without the FP16 builder flag')
    parser.add_argument('--workspace_gb', type=int, default=16,
                        help='TensorRT builder workspace memory pool limit in GiB')
    parser.add_argument('--force_export', action='store_true',
                        help='Re-export ONNX even if the file already exists '
                             '(needed after changing --context_len or the checkpoint)')
    return parser.parse_args(argv)


def _export_onnx(args, onnx_path):
    """Export the backbone to ``onnx_path``.

    Export is skipped when the file already exists and ``--force_export`` is
    not set.
    """
    if os.path.exists(onnx_path) and not args.force_export:
        print(f">>> ONNX already exists at {onnx_path}, skipping export.")
        return
    print(">>> Loading model ...")
    model = load_model(args.config, args.ckpt)
    print(">>> Exporting ONNX ...")
    with torch.no_grad():
        export_backbone_onnx(model, onnx_path, context_len=args.context_len)
    # Release GPU memory held by the PyTorch model before the TensorRT build.
    del model
    torch.cuda.empty_cache()


def _build_engine(onnx_path, engine_path, *, fp16, workspace_gb):
    """Parse ``onnx_path``, build a serialized TensorRT engine and write it out.

    Args:
        onnx_path: Input ONNX file.
        engine_path: Destination for the serialized engine bytes.
        fp16: Whether to set ``trt.BuilderFlag.FP16``.
        workspace_gb: Workspace memory pool limit in GiB.

    Returns:
        The number of outputs of the parsed network.

    Raises:
        RuntimeError: If ONNX parsing or the engine build fails.
    """
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    onnx_parser = trt.OnnxParser(network, logger)
    # Absolute path: presumably so that external-data weight files stored
    # next to the .onnx file resolve correctly.
    if not onnx_parser.parse_from_file(os.path.abspath(onnx_path)):
        for i in range(onnx_parser.num_errors):
            print(f" ONNX parse error: {onnx_parser.get_error(i)}")
        raise RuntimeError("ONNX parsing failed")

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_gb << 30)
    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    engine_bytes = builder.build_serialized_network(network, config)
    # build_serialized_network returns None on failure instead of raising.
    if engine_bytes is None:
        raise RuntimeError("TensorRT engine build failed; see the TensorRT log above")
    with open(engine_path, 'wb') as f:
        f.write(engine_bytes)
    return network.num_outputs


def main():
    """CLI entry point: export ONNX (if needed), then build the TensorRT engine."""
    args = _parse_args()
    os.makedirs(args.out_dir, exist_ok=True)
    onnx_path = os.path.join(args.out_dir, 'video_backbone.onnx')
    engine_path = os.path.join(args.out_dir, 'video_backbone.engine')

    # Step 1 & 2: export ONNX (skipped if it already exists, unless forced).
    _export_onnx(args, onnx_path)

    # Step 3: convert to TensorRT via the Python API.
    print(">>> Converting ONNX -> TensorRT engine ...")
    # Output count comes from the parsed network, so the summary is accurate
    # even when ONNX export was skipped.
    n_outputs = _build_engine(onnx_path, engine_path,
                              fp16=args.fp16, workspace_gb=args.workspace_gb)

    print(f"\n>>> Done! Engine saved to {engine_path}")
    print(f" Outputs: 1 y + {n_outputs - 1} hs_a tensors")
    print(f"\n To use: model.model.diffusion_model.load_trt_backbone('{engine_path}')")


if __name__ == '__main__':
    main()