Files
unifolm-world-model-action/scripts/export_trt.py
2026-05-17 15:07:06 +08:00

210 lines
7.8 KiB
Python

"""Export video UNet backbone to ONNX, then convert to TensorRT engine.
Usage:
python scripts/export_trt.py \
--ckpt ckpts/unifolm_wma_dual.ckpt.prepared.pt \
--config configs/inference/world_model_interaction.yaml \
--out_dir trt_engines
python scripts/export_trt.py \
--ckpt ckpts/unifolm_wma_dual.ckpt.prepared.pt \
--config configs/inference/world_model_interaction.yaml \
--engine_path trt_engines/video_backbone_multigpu.engine \
--onnx_path trt_engines/video_backbone_multigpu.onnx
"""
import os
import sys
import argparse
import json
import torch
import tensorrt as trt
from omegaconf import OmegaConf
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from unifolm_wma.utils.utils import instantiate_from_config
from unifolm_wma.trt_utils import export_backbone_onnx
class TacticRecorder(trt.IAlgorithmSelector):
    """Pass 1: record all candidate tactics and the auto-selected winner.

    Install an instance as ``config.algorithm_selector`` before building.
    After the build, ``records`` maps each layer name to a dict with the keys
    ``input_shapes``, ``output_shapes``, ``candidates`` and ``selected``;
    ``save`` writes that mapping to a JSON file.
    """

    def __init__(self):
        super().__init__()
        # layer_name -> {input_shapes, output_shapes, candidates: [...], selected: ...}
        self.records = {}

    @staticmethod
    def _describe(algorithm):
        """Return a JSON-serializable summary of one algorithm choice."""
        variant = algorithm.algorithm_variant
        return {
            "implementation": variant.implementation,
            "tactic": variant.tactic,
            "timing_msec": algorithm.timing_msec,
            "workspace_size": algorithm.workspace_size,
        }

    @staticmethod
    def _shape_profile(ctx, index):
        """Return the shapes of tensor ``index`` as nested lists of ints.

        Per the TensorRT Python API reference, ``get_shape`` returns a list of
        ``Dims`` (minimum, optimum, maximum), or an empty list when no shape
        has been set. Each ``Dims`` is expanded to a list of ints so the
        result can be written as JSON. Returns ``None`` if the query fails.
        """
        try:
            return [[int(extent) for extent in dims]
                    for dims in ctx.get_shape(index)]
        except Exception:
            # Deliberate best-effort: shape info is diagnostic only and a
            # failure here must not escape into the builder callback.
            return None

    def select_algorithms(self, ctx, choices):
        """Record every candidate for layer ``ctx.name`` and allow all of them.

        Returns the indices of all ``choices`` so TensorRT picks the fastest.
        """
        name = ctx.name
        # Indices are shared between inputs and outputs: inputs come first,
        # outputs follow at ``num_inputs + j``.
        inputs = [self._shape_profile(ctx, j) for j in range(ctx.num_inputs)]
        outputs = [
            self._shape_profile(ctx, ctx.num_inputs + j)
            for j in range(ctx.num_outputs)
        ]
        self.records[name] = {
            "input_shapes": inputs,
            "output_shapes": outputs,
            "candidates": [
                {"index": i, **self._describe(choice)}
                for i, choice in enumerate(choices)
            ],
            "selected": None,
        }
        # return all indices -> let TRT auto-pick the fastest
        return list(range(len(choices)))

    def report_algorithms(self, ctx, choices):
        """Store the algorithm TensorRT finally chose for each recorded layer."""
        # Both ctx and choices are lists in report_algorithms
        for layer_ctx, algorithm in zip(ctx, choices):
            name = layer_ctx.name
            if name in self.records:
                self.records[name]["selected"] = self._describe(algorithm)

    def save(self, path):
        """Write the collected records to ``path`` as indented JSON."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.records, f, indent=2)
        print(f">>> Tactic info saved to {path} ({len(self.records)} layers)")
class TacticForcer(trt.IAlgorithmSelector):
    """Pass 2: force user-specified tactics from a JSON file.

    The file maps layer names to objects. A layer whose object carries a
    truthy ``"force"`` entry of the form
    ``{"implementation": ..., "tactic": ...}`` is restricted to that tactic;
    every other layer is left to TensorRT's automatic selection.

    Raises:
        ValueError: if the file is not a JSON object of objects, or a
            ``"force"`` entry lacks ``implementation`` / ``tactic``.
    """

    def __init__(self, path):
        super().__init__()
        with open(path, encoding="utf-8") as f:
            self.overrides = json.load(f)
        if not isinstance(self.overrides, dict):
            raise ValueError(
                f"tactic file {path} must contain a JSON object keyed by layer name")
        # Validate up front: a malformed entry would otherwise only surface as
        # a KeyError/AttributeError inside the TensorRT builder callback.
        n = 0
        for layer, entry in self.overrides.items():
            if not isinstance(entry, dict):
                raise ValueError(
                    f"tactic entry for layer {layer!r} in {path} must be a JSON object")
            force = entry.get("force")
            if not force:
                continue
            if (not isinstance(force, dict)
                    or "implementation" not in force
                    or "tactic" not in force):
                raise ValueError(
                    f"'force' for layer {layer!r} in {path} must be an object "
                    "with 'implementation' and 'tactic'")
            n += 1
        print(f">>> Loaded tactic overrides: {n} layers with 'force' set")

    def select_algorithms(self, ctx, choices):
        """Return the index of the forced tactic, or all indices when none applies."""
        name = ctx.name
        override = self.overrides.get(name)
        if override and override.get("force"):
            target_impl = override["force"]["implementation"]
            target_tactic = override["force"]["tactic"]
            for i, c in enumerate(choices):
                v = c.algorithm_variant
                if v.implementation == target_impl and v.tactic == target_tactic:
                    return [i]
            # The requested tactic is not among this build's candidates.
            print(f" WARN: forced tactic not found for {name}, using auto")
        return list(range(len(choices)))

    def report_algorithms(self, ctx, choices):
        """Nothing to record in the forcing pass."""
        pass
def load_model(config_path, ckpt_path, device):
    """Load the model in eval mode on ``device``.

    Args:
        config_path: OmegaConf YAML whose ``model`` node describes the model.
            Only read when ``ckpt_path`` is not a ``.prepared.pt`` file.
        ckpt_path: Either a ``*.prepared.pt`` file holding the whole pickled
            model object, or a checkpoint holding a state dict (optionally
            nested under the ``'state_dict'`` key).
        device: Device the model is moved to after loading.

    Returns:
        The loaded model, switched to eval mode and moved to ``device``.
    """
    if os.fspath(ckpt_path).endswith('.prepared.pt'):
        # The file is a whole pickled module, so weights-only loading cannot
        # work; torch >= 2.6 defaults to weights_only=True and would reject it.
        # NOTE(security): full unpickling can execute arbitrary code. Only
        # load prepared checkpoints from a trusted source.
        model = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    else:
        config = OmegaConf.load(config_path)
        model = instantiate_from_config(config.model)
        state_dict = torch.load(ckpt_path, map_location='cpu')
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        report = model.load_state_dict(state_dict, strict=False)
        # strict=False hides key mismatches; surface them so an engine is not
        # silently exported from partially initialized weights.
        missing = list(getattr(report, 'missing_keys', ()))
        unexpected = list(getattr(report, 'unexpected_keys', ()))
        if missing or unexpected:
            print(f">>> WARN: non-strict checkpoint load: {len(missing)} missing, "
                  f"{len(unexpected)} unexpected keys")
    model.eval().to(device)
    return model
def _parse_args():
    """Define the command-line options and parse ``sys.argv``."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--ckpt', required=True)
    arg_parser.add_argument('--config', default='configs/inference/world_model_interaction.yaml')
    arg_parser.add_argument('--out_dir', default='trt_engines')
    arg_parser.add_argument('--gpu_id',
                            type=int,
                            default=0,
                            help='CUDA device id used for ONNX export and TRT build.')
    arg_parser.add_argument('--onnx_path',
                            default=None,
                            help='Optional explicit ONNX output path. Overrides --out_dir default name.')
    arg_parser.add_argument('--engine_path',
                            default=None,
                            help='Optional explicit TensorRT engine output path. Overrides --out_dir default name.')
    arg_parser.add_argument('--context_len', type=int, default=95)
    arg_parser.add_argument('--fp16', action='store_true', default=True,
                            help='Build with FP16 enabled (default).')
    # --fp16 defaults to True, so a separate switch is needed to turn it off.
    arg_parser.add_argument('--no-fp16', dest='fp16', action='store_false',
                            help='Build without the FP16 builder flag.')
    arg_parser.add_argument('--dump-tactics', default=None, help='Pass 1: dump tactic info to JSON')
    arg_parser.add_argument('--load-tactics', default=None, help='Pass 2: force tactics from JSON')
    return arg_parser.parse_args()


def _build_engine(onnx_path, args):
    """Parse ``onnx_path`` and build a serialized TensorRT engine.

    Returns:
        ``(engine_bytes, num_outputs)``: the serialized engine and the number
        of outputs of the parsed network.

    Raises:
        RuntimeError: if the ONNX file cannot be parsed or the build fails.
    """
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    onnx_parser = trt.OnnxParser(network, logger)
    if not onnx_parser.parse_from_file(os.path.abspath(onnx_path)):
        for i in range(onnx_parser.num_errors):
            print(f" ONNX parse error: {onnx_parser.get_error(i)}")
        raise RuntimeError("ONNX parsing failed")

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 16 << 30)  # 16 GiB
    if args.fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    # Tactic selection: recording takes precedence over forcing.
    recorder = None
    if args.dump_tactics:
        if args.load_tactics:
            print(">>> WARN: --load-tactics is ignored because --dump-tactics is set")
        recorder = TacticRecorder()
        config.algorithm_selector = recorder
    elif args.load_tactics:
        config.algorithm_selector = TacticForcer(args.load_tactics)

    engine_bytes = builder.build_serialized_network(network, config)
    # Save whatever was recorded even if the build failed; it helps debugging.
    if recorder is not None:
        recorder.save(args.dump_tactics)
    if engine_bytes is None:
        # build_serialized_network signals failure by returning None.
        raise RuntimeError("TensorRT engine build failed (see TensorRT log output above)")
    return engine_bytes, network.num_outputs


def main():
    """Export the backbone to ONNX (unless it already exists) and build the engine.

    Command-line entry point. Reuses an existing ONNX file at the target path,
    otherwise loads the model and exports it; then converts the ONNX file to a
    TensorRT engine and writes it to the engine path.
    """
    args = _parse_args()
    device = torch.device('cuda', args.gpu_id)
    torch.cuda.set_device(device)

    onnx_path = args.onnx_path or os.path.join(args.out_dir,
                                               'video_backbone.onnx')
    engine_path = args.engine_path or os.path.join(args.out_dir,
                                                   'video_backbone.engine')
    os.makedirs(os.path.dirname(os.path.abspath(onnx_path)), exist_ok=True)
    os.makedirs(os.path.dirname(os.path.abspath(engine_path)), exist_ok=True)

    if os.path.exists(onnx_path):
        print(f">>> ONNX already exists at {onnx_path}, skipping export.")
        # Unknown until the file is parsed; filled in from the network below.
        n_outputs = None
    else:
        print(">>> Loading model ...")
        model = load_model(args.config, args.ckpt, device)
        print(">>> Exporting ONNX ...")
        with torch.no_grad():
            n_outputs = export_backbone_onnx(model, onnx_path, context_len=args.context_len)
        # Free GPU memory before the TensorRT build.
        del model
        torch.cuda.empty_cache()

    print(">>> Converting ONNX -> TensorRT engine ...")
    engine_bytes, network_outputs = _build_engine(onnx_path, args)
    if n_outputs is None:
        # Report the real output count of the reused ONNX instead of a guess.
        n_outputs = network_outputs

    with open(engine_path, 'wb') as f:
        f.write(engine_bytes)
    print(f"\n>>> Done! Engine saved to {engine_path}")
    print(f" Outputs: 1 y + {n_outputs - 1} hs_a tensors")
# Run the export only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()