"""Export video UNet backbone to ONNX, then convert to TensorRT engine. Usage: python scripts/export_trt.py \ --ckpt ckpts/unifolm_wma_dual.ckpt.prepared.pt \ --config configs/inference/world_model_interaction.yaml \ --out_dir trt_engines python scripts/export_trt.py \ --ckpt ckpts/unifolm_wma_dual.ckpt.prepared.pt \ --config configs/inference/world_model_interaction.yaml \ --engine_path trt_engines/video_backbone_multigpu.engine \ --onnx_path trt_engines/video_backbone_multigpu.onnx """ import os import sys import argparse import json import torch import tensorrt as trt from omegaconf import OmegaConf sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) from unifolm_wma.utils.utils import instantiate_from_config from unifolm_wma.trt_utils import export_backbone_onnx class TacticRecorder(trt.IAlgorithmSelector): """Pass 1: record all candidate tactics and the auto-selected winner.""" def __init__(self): super().__init__() self.records = {} # layer_name -> {candidates: [...], selected: ...} def select_algorithms(self, ctx, choices): name = ctx.name # Collect input/output shapes inputs = [] for j in range(ctx.num_inputs): try: inputs.append([int(d) for d in ctx.get_shape(j)]) except Exception: inputs.append(None) outputs = [] for j in range(ctx.num_outputs): try: outputs.append([int(d) for d in ctx.get_shape(ctx.num_inputs + j)]) except Exception: outputs.append(None) self.records[name] = { "input_shapes": inputs, "output_shapes": outputs, "candidates": [], "selected": None, } for i, c in enumerate(choices): v = c.algorithm_variant self.records[name]["candidates"].append({ "index": i, "implementation": v.implementation, "tactic": v.tactic, "timing_msec": c.timing_msec, "workspace_size": c.workspace_size, }) # return all indices -> let TRT auto-pick the fastest return list(range(len(choices))) def report_algorithms(self, ctx, choices): # Both ctx and choices are lists in report_algorithms for c, alg in zip(ctx, choices): name = c.name if name in self.records: v = alg.algorithm_variant self.records[name]["selected"] = { "implementation": v.implementation, "tactic": v.tactic, "timing_msec": alg.timing_msec, "workspace_size": alg.workspace_size, } def save(self, path): with open(path, "w") as f: json.dump(self.records, f, indent=2) print(f">>> Tactic info saved to {path} ({len(self.records)} layers)") class TacticForcer(trt.IAlgorithmSelector): """Pass 2: force user-specified tactics from a JSON file.""" def __init__(self, path): super().__init__() with open(path) as f: self.overrides = json.load(f) n = sum(1 for v in self.overrides.values() if v.get("force")) print(f">>> Loaded tactic overrides: {n} layers with 'force' set") def select_algorithms(self, ctx, choices): name = ctx.name override = self.overrides.get(name) if override and override.get("force"): target_impl = override["force"]["implementation"] target_tactic = override["force"]["tactic"] for i, c in enumerate(choices): v = c.algorithm_variant if v.implementation == target_impl and v.tactic == target_tactic: return [i] print(f" WARN: forced tactic not found for {name}, using auto") return list(range(len(choices))) def report_algorithms(self, ctx, choices): pass def load_model(config_path, ckpt_path, device): if ckpt_path.endswith('.prepared.pt'): model = torch.load(ckpt_path, map_location='cpu') else: config = OmegaConf.load(config_path) model = instantiate_from_config(config.model) state_dict = torch.load(ckpt_path, map_location='cpu') if 'state_dict' in state_dict: state_dict = state_dict['state_dict'] model.load_state_dict(state_dict, strict=False) model.eval().to(device) return model def main(): parser = argparse.ArgumentParser() parser.add_argument('--ckpt', required=True) parser.add_argument('--config', default='configs/inference/world_model_interaction.yaml') parser.add_argument('--out_dir', default='trt_engines') parser.add_argument('--gpu_id', type=int, default=0, help='CUDA device id used for ONNX export and TRT build.') parser.add_argument('--onnx_path', default=None, help='Optional explicit ONNX output path. Overrides --out_dir default name.') parser.add_argument('--engine_path', default=None, help='Optional explicit TensorRT engine output path. Overrides --out_dir default name.') parser.add_argument('--context_len', type=int, default=95) parser.add_argument('--fp16', action='store_true', default=True) parser.add_argument('--dump-tactics', default=None, help='Pass 1: dump tactic info to JSON') parser.add_argument('--load-tactics', default=None, help='Pass 2: force tactics from JSON') args = parser.parse_args() device = torch.device('cuda', args.gpu_id) torch.cuda.set_device(device) onnx_path = args.onnx_path or os.path.join(args.out_dir, 'video_backbone.onnx') engine_path = args.engine_path or os.path.join(args.out_dir, 'video_backbone.engine') os.makedirs(os.path.dirname(os.path.abspath(onnx_path)), exist_ok=True) os.makedirs(os.path.dirname(os.path.abspath(engine_path)), exist_ok=True) if os.path.exists(onnx_path): print(f">>> ONNX already exists at {onnx_path}, skipping export.") n_outputs = 10 else: print(">>> Loading model ...") model = load_model(args.config, args.ckpt, device) print(">>> Exporting ONNX ...") with torch.no_grad(): n_outputs = export_backbone_onnx(model, onnx_path, context_len=args.context_len) del model torch.cuda.empty_cache() print(">>> Converting ONNX -> TensorRT engine ...") logger = trt.Logger(trt.Logger.WARNING) builder = trt.Builder(logger) network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) parser = trt.OnnxParser(network, logger) if not parser.parse_from_file(os.path.abspath(onnx_path)): for i in range(parser.num_errors): print(f" ONNX parse error: {parser.get_error(i)}") raise RuntimeError("ONNX parsing failed") config = builder.create_builder_config() config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 16 << 30) if args.fp16: config.set_flag(trt.BuilderFlag.FP16) # Tactic selection recorder = None if args.dump_tactics: recorder = TacticRecorder() config.algorithm_selector = recorder elif args.load_tactics: config.algorithm_selector = TacticForcer(args.load_tactics) engine_bytes = builder.build_serialized_network(network, config) if recorder and args.dump_tactics: recorder.save(args.dump_tactics) with open(engine_path, 'wb') as f: f.write(engine_bytes) print(f"\n>>> Done! Engine saved to {engine_path}") print(f" Outputs: 1 y + {n_outputs - 1} hs_a tensors") if __name__ == '__main__': main()