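r"""Export the video UNet backbone to ONNX, then convert it to a TensorRT engine.

Usage:
    # Default output names (video_backbone.onnx / video_backbone.engine) under --out_dir:
    python scripts/export_trt.py \
        --ckpt ckpts/unifolm_wma_dual.ckpt.prepared.pt \
        --config configs/inference/world_model_interaction.yaml \
        --out_dir trt_engines

    # Explicit ONNX / engine output paths:
    python scripts/export_trt.py \
        --ckpt ckpts/unifolm_wma_dual.ckpt.prepared.pt \
        --config configs/inference/world_model_interaction.yaml \
        --engine_path trt_engines/video_backbone_multigpu.engine \
        --onnx_path trt_engines/video_backbone_multigpu.onnx

An ONNX file that already exists at the target path is reused rather than
re-exported.

Tactic tuning is a two-pass workflow:
  1. Build once with ``--dump-tactics FILE`` to record each layer's candidate
     tactics.
  2. Add ``"force": {"implementation": ..., "tactic": ...}`` entries to FILE.
  3. Rebuild with ``--load-tactics FILE``.
"""
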
import os
|
|
import sys
|
|
import argparse
|
|
import json
|
|
|
|
import torch
|
|
import tensorrt as trt
|
|
from omegaconf import OmegaConf
|
|
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
|
from unifolm_wma.utils.utils import instantiate_from_config
|
|
from unifolm_wma.trt_utils import export_backbone_onnx
|
|
|
|
|
|
class TacticRecorder(trt.IAlgorithmSelector):
    """Pass 1: record all candidate tactics and the auto-selected winner.

    Install this as ``config.algorithm_selector`` for an engine build.

    For every layer TensorRT asks about, the recorder:
      * stores the layer's input/output shapes and every candidate algorithm;
      * returns all candidates, so TensorRT's own choice is left untouched.

    The algorithm TensorRT later reports through ``report_algorithms`` is
    stored under ``"selected"``. ``save`` writes everything as JSON keyed by
    layer name, which is the same layout ``TacticForcer`` reads back.
    """

    def __init__(self):
        super().__init__()
        # layer_name -> {input_shapes, output_shapes, candidates: [...], selected: ...}
        self.records = {}

    @staticmethod
    def _shape_to_list(shape):
        """Convert a TensorRT shape into plain (JSON-serializable) ints.

        ``IAlgorithmContext.get_shape`` returns a list of ``Dims`` (the
        min/opt/max shapes), so calling ``int()`` on each element fails.
        Each element becomes an ``int`` when it is a scalar. Otherwise it
        becomes a list of ints, which also covers a bare ``Dims``.
        """
        converted = []
        for item in shape:
            try:
                converted.append(int(item))
            except TypeError:
                # item is itself a Dims (e.g. the min/opt/max entry)
                converted.append([int(d) for d in item])
        return converted

    def _tensor_shape(self, ctx, index):
        """Return the shape of tensor ``index`` in ``ctx``, or None if unavailable."""
        try:
            return self._shape_to_list(ctx.get_shape(index))
        except Exception:  # recording is diagnostic only; never fail the build
            return None

    @staticmethod
    def _algorithm_info(alg):
        """Summarize an IAlgorithm as a JSON-serializable dict."""
        variant = alg.algorithm_variant
        return {
            "implementation": variant.implementation,
            "tactic": variant.tactic,
            "timing_msec": alg.timing_msec,
            "workspace_size": alg.workspace_size,
        }

    def select_algorithms(self, ctx, choices):
        """Record the layer's shapes and candidates, then allow every candidate."""
        num_inputs = ctx.num_inputs
        # Output tensors are indexed after the inputs in the algorithm context.
        self.records[ctx.name] = {
            "input_shapes": [self._tensor_shape(ctx, j) for j in range(num_inputs)],
            "output_shapes": [
                self._tensor_shape(ctx, num_inputs + j) for j in range(ctx.num_outputs)
            ],
            "candidates": [
                {"index": i, **self._algorithm_info(c)} for i, c in enumerate(choices)
            ],
            "selected": None,
        }
        # return all indices -> let TRT auto-pick the fastest
        return list(range(len(choices)))

    def report_algorithms(self, ctx, choices):
        """Store the algorithm TensorRT reports for each recorded layer."""
        # Both ctx and choices are lists in report_algorithms
        for layer_ctx, alg in zip(ctx, choices):
            record = self.records.get(layer_ctx.name)
            if record is not None:
                record["selected"] = self._algorithm_info(alg)

    def save(self, path):
        """Write the recorded tactics to ``path`` as indented JSON."""
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.records, f, indent=2)
        print(f">>> Tactic info saved to {path} ({len(self.records)} layers)")


class TacticForcer(trt.IAlgorithmSelector):
    """Pass 2: force user-specified tactics from a JSON file.

    The JSON maps layer names to dicts. An entry whose ``"force"`` value is a
    truthy dict with ``"implementation"`` and ``"tactic"`` pins that layer to
    the matching candidate. Layers fall back to TensorRT's automatic choice
    (all candidates allowed) in two cases:
      * the layer has no ``"force"`` entry;
      * the forced combination is not among the offered candidates.
    """

    def __init__(self, path):
        super().__init__()
        with open(path) as fh:
            self.overrides = json.load(fh)
        forced_count = len([entry for entry in self.overrides.values() if entry.get("force")])
        print(f">>> Loaded tactic overrides: {forced_count} layers with 'force' set")

    def select_algorithms(self, ctx, choices):
        """Return only the forced candidate's index, or every index as a fallback."""
        layer = ctx.name
        entry = self.overrides.get(layer)
        every_index = list(range(len(choices)))
        if not entry or not entry.get("force"):
            return every_index

        pinned = entry["force"]
        wanted = (pinned["implementation"], pinned["tactic"])
        for idx, choice in enumerate(choices):
            variant = choice.algorithm_variant
            if (variant.implementation, variant.tactic) == wanted:
                return [idx]

        print(f" WARN: forced tactic not found for {layer}, using auto")
        return every_index

    def report_algorithms(self, ctx, choices):
        """Ignore TensorRT's report; forcing needs no post-build bookkeeping."""


def _torch_load_trusted(path):
    """Fully unpickle a checkpoint at ``path`` onto the CPU.

    torch >= 2.6 defaults ``torch.load`` to ``weights_only=True``. That
    rejects pickled ``nn.Module`` objects (``*.prepared.pt``) and
    checkpoints carrying non-tensor objects, so full unpickling is requested
    explicitly.

    SECURITY: full unpickling can execute arbitrary code stored in the file.
    Only pass checkpoints you trust.
    """
    try:
        return torch.load(path, map_location='cpu', weights_only=False)
    except TypeError as err:
        # torch < 1.13 has no ``weights_only`` kwarg (and always fully unpickles).
        if 'weights_only' not in str(err):
            raise
        return torch.load(path, map_location='cpu')


def load_model(config_path, ckpt_path, device):
    """Load the model to export, in eval mode, on ``device``.

    Args:
        config_path: OmegaConf YAML path. Read only when ``ckpt_path`` is not
            a ``*.prepared.pt`` file; its ``model`` node is instantiated.
        ckpt_path: either a ``*.prepared.pt`` file holding a whole pickled
            model, or a checkpoint holding a state dict. The state dict may
            be nested under a ``"state_dict"`` key.
        device: target ``torch.device``.

    Returns:
        The model, switched to eval mode and moved to ``device``.
    """
    if ckpt_path.endswith('.prepared.pt'):
        model = _torch_load_trusted(ckpt_path)
    else:
        config = OmegaConf.load(config_path)
        model = instantiate_from_config(config.model)
        state_dict = _torch_load_trusted(ckpt_path)
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        result = model.load_state_dict(state_dict, strict=False)
        # strict=False keeps partial loads working, but report mismatches.
        # A silent mismatch would export an engine with untrained weights.
        missing = getattr(result, 'missing_keys', [])
        unexpected = getattr(result, 'unexpected_keys', [])
        if missing or unexpected:
            print(f"  WARN: checkpoint mismatch: {len(missing)} missing keys, "
                  f"{len(unexpected)} unexpected keys")
    model.eval().to(device)
    return model


def _parse_args(argv=None):
    """Parse command-line options (``argv`` defaults to ``sys.argv[1:]``)."""
    arg_parser = argparse.ArgumentParser(
        description='Export the video UNet backbone to ONNX and build a TensorRT engine.')
    arg_parser.add_argument('--ckpt', required=True)
    arg_parser.add_argument('--config', default='configs/inference/world_model_interaction.yaml')
    arg_parser.add_argument('--out_dir', default='trt_engines')
    arg_parser.add_argument('--gpu_id',
                            type=int,
                            default=0,
                            help='CUDA device id used for ONNX export and TRT build.')
    arg_parser.add_argument('--onnx_path',
                            default=None,
                            help='Optional explicit ONNX output path. Overrides --out_dir default name.')
    arg_parser.add_argument('--engine_path',
                            default=None,
                            help='Optional explicit TensorRT engine output path. Overrides --out_dir default name.')
    arg_parser.add_argument('--force_export',
                            action='store_true',
                            help='Re-export ONNX even if the ONNX file already exists.')
    arg_parser.add_argument('--context_len', type=int, default=95)
    # --fp16 is on by default; --no-fp16 is the only way to turn it off.
    arg_parser.add_argument('--fp16', action='store_true', default=True,
                            help='Build with the FP16 builder flag (default).')
    arg_parser.add_argument('--no-fp16', '--no_fp16', dest='fp16', action='store_false',
                            help='Disable the FP16 builder flag.')
    # Only one algorithm selector can be installed on the builder config.
    tactics = arg_parser.add_mutually_exclusive_group()
    tactics.add_argument('--dump-tactics', default=None, help='Pass 1: dump tactic info to JSON')
    tactics.add_argument('--load-tactics', default=None, help='Pass 2: force tactics from JSON')
    return arg_parser.parse_args(argv)


def _export_onnx(args, onnx_path, device):
    """Export the backbone to ``onnx_path`` unless a reusable file exists.

    Returns the output count reported by ``export_backbone_onnx``. Returns
    None when an existing ONNX file is reused; the caller then reads the
    count from the parsed network.
    """
    if os.path.exists(onnx_path) and not args.force_export:
        print(f">>> ONNX already exists at {onnx_path}, skipping export "
              f"(pass --force_export to regenerate).")
        return None

    print(">>> Loading model ...")
    model = load_model(args.config, args.ckpt, device)
    print(">>> Exporting ONNX ...")
    with torch.no_grad():
        n_outputs = export_backbone_onnx(model, onnx_path, context_len=args.context_len)
    # Release the PyTorch model before TensorRT claims GPU memory for the build.
    del model
    torch.cuda.empty_cache()
    return n_outputs


def _build_engine(args, onnx_path):
    """Parse ``onnx_path`` and build a serialized TensorRT engine.

    Returns:
        ``(engine_bytes, num_network_outputs)``.

    Raises:
        RuntimeError: if ONNX parsing or the engine build fails.
    """
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    onnx_parser = trt.OnnxParser(network, logger)

    if not onnx_parser.parse_from_file(os.path.abspath(onnx_path)):
        for i in range(onnx_parser.num_errors):
            print(f" ONNX parse error: {onnx_parser.get_error(i)}")
        raise RuntimeError("ONNX parsing failed")

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 16 << 30)  # 16 GiB
    if args.fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    # Tactic selection
    recorder = None
    if args.dump_tactics:
        recorder = TacticRecorder()
        config.algorithm_selector = recorder
    elif args.load_tactics:
        config.algorithm_selector = TacticForcer(args.load_tactics)

    engine_bytes = builder.build_serialized_network(network, config)

    # Save recorded tactics even if the build failed; they help diagnose it.
    if recorder is not None:
        recorder.save(args.dump_tactics)
    if engine_bytes is None:
        raise RuntimeError("TensorRT engine build failed (see TensorRT log output above)")
    return engine_bytes, network.num_outputs


def main():
    """Export the backbone to ONNX (if needed) and build the TensorRT engine."""
    args = _parse_args()
    device = torch.device('cuda', args.gpu_id)
    torch.cuda.set_device(device)

    onnx_path = args.onnx_path or os.path.join(args.out_dir, 'video_backbone.onnx')
    engine_path = args.engine_path or os.path.join(args.out_dir, 'video_backbone.engine')
    for path in (onnx_path, engine_path):
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

    n_outputs = _export_onnx(args, onnx_path, device)

    print(">>> Converting ONNX -> TensorRT engine ...")
    engine_bytes, network_outputs = _build_engine(args, onnx_path)
    if n_outputs is None:
        # The ONNX file was reused: read the output count from the parsed
        # network instead of assuming a fixed value.
        n_outputs = network_outputs

    with open(engine_path, 'wb') as f:
        f.write(engine_bytes)

    print(f"\n>>> Done! Engine saved to {engine_path}")
    print(f" Outputs: 1 y + {n_outputs - 1} hs_a tensors")


# Build only when run as a script, never on import.
if __name__ == '__main__':
    main()