赛前稳定版
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
This commit is contained in:
208
tools/inspect_contraction_tree.py
Normal file
208
tools/inspect_contraction_tree.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""Inspect cotengra contraction trees for dominant torch matmul shapes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
import math
|
||||
import pickle
|
||||
from collections import Counter, defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _prod(values):
|
||||
out = 1
|
||||
for value in values:
|
||||
out *= int(value)
|
||||
return out
|
||||
|
||||
|
||||
def _broadcast_batch(a_batch, b_batch):
|
||||
if a_batch == b_batch:
|
||||
return _prod(a_batch)
|
||||
if not a_batch:
|
||||
return _prod(b_batch)
|
||||
if not b_batch:
|
||||
return _prod(a_batch)
|
||||
|
||||
ndim = max(len(a_batch), len(b_batch))
|
||||
a_batch = (1,) * (ndim - len(a_batch)) + tuple(a_batch)
|
||||
b_batch = (1,) * (ndim - len(b_batch)) + tuple(b_batch)
|
||||
return _prod(max(a, b) for a, b in zip(a_batch, b_batch))
|
||||
|
||||
|
||||
def _load_tree(path, index):
|
||||
with Path(path).open("rb") as f:
|
||||
payload = pickle.load(f)
|
||||
trees = payload["trees"] if isinstance(payload, dict) else payload
|
||||
if not isinstance(trees, (list, tuple)):
|
||||
trees = [trees]
|
||||
return trees[index]
|
||||
|
||||
|
||||
def _analyze_tree(tree):
    """Describe every pairwise contraction step of ``tree``.

    The steps come from ``cotengra.contract.extract_contractions``.  Steps
    whose ``left`` and ``right`` are both ``None`` are only tallied under
    ``"preprocess"``.  Every other step is parsed with cotengra's private
    tensordot/einsum-to-matmul helpers and classified as ``"mul"``, ``"mm"``
    or ``"bmm"``.

    Returns:
        A ``(contractions, ops, counts)`` tuple: the raw contraction list, one
        dict per non-preprocess step, and a ``Counter`` of step kinds.
    """
    # NOTE(review): relies on underscore-private cotengra helpers, so this may
    # break across cotengra versions -- confirm against the pinned release.
    cotengra_contract = importlib.import_module("cotengra.contract")
    steps = cotengra_contract.extract_contractions(tree)
    sizes = tree.size_dict

    records = []
    tally = Counter()

    for step_no, step in enumerate(steps):
        parent, left, right, tdot, arg, perm = step

        # Single-operand steps carry no pairwise contraction to analyse.
        if left is None and right is None:
            tally["preprocess"] += 1
            continue

        lhs_inds = tree.get_inds(left)
        rhs_inds = tree.get_inds(right)
        out_inds = tree.get_inds(parent)
        lhs_shape = tuple(sizes[ix] for ix in lhs_inds)
        rhs_shape = tuple(sizes[ix] for ix in rhs_inds)

        # ``tdot`` selects which of the two cotengra parsers understands ``arg``.
        parser = (
            cotengra_contract._parse_tensordot_axes_to_matmul
            if tdot
            else cotengra_contract._parse_eq_to_batch_matmul
        )
        _, _, reshaped_lhs, reshaped_rhs, _, _, pure_mul = parser(
            arg, lhs_shape, rhs_shape
        )

        if pure_mul:
            kind, mm_shape, mm_flops = "mul", None, 0
        else:
            # Fall back to the raw operand shape when the parser gives no
            # (or an empty) reshaped shape.
            a_dims = tuple(reshaped_lhs or lhs_shape)
            b_dims = tuple(reshaped_rhs or rhs_shape)
            batch = _broadcast_batch(a_dims[:-2], b_dims[:-2])
            m = int(a_dims[-2])
            k = int(a_dims[-1])
            n = int(b_dims[-1])
            kind = "bmm" if batch != 1 else "mm"
            mm_shape = (batch, m, k, n)
            mm_flops = batch * m * k * n

        step_flops = int(tree.get_flops(parent))
        step_out = int(tree.get_size(parent))

        records.append(
            dict(
                index=step_no,
                kind=kind,
                matmul_shape=mm_shape,
                matmul_flops=mm_flops,
                tree_flops=step_flops,
                out_size=step_out,
                left_shape=lhs_shape,
                right_shape=rhs_shape,
                left_rank=len(lhs_inds),
                right_rank=len(rhs_inds),
                out_rank=len(out_inds),
                perm=perm,
            )
        )
        tally[kind] += 1

    return steps, records, tally
|
||||
|
||||
|
||||
def _format_log(value, base):
|
||||
return "-inf" if value <= 0 else f"{math.log(value, base):.3f}"
|
||||
|
||||
|
||||
def _print_top_shapes(by_shape, top, label, column):
    """Print the ``top`` matmul shapes ranked by one aggregate column.

    Args:
        by_shape: Mapping of matmul shape -> ``[count, flops, out_size]``.
        top: Maximum number of rows to print.
        label: Suffix of the section header (``"flops"`` or ``"count"``).
        column: Index into the aggregate list used as the descending sort key.
    """
    print(f"\ntop_{top}_matmul_shapes_by_{label}")
    ranked = sorted(
        by_shape.items(),
        key=lambda item: item[1][column],
        reverse=True,
    )
    for shape, (count, flops, out_size) in ranked[:top]:
        print(
            f"shape={shape} count={count} "
            f"flops={flops:.6e} output={out_size:.6e}"
        )


def main(argv=None):
    """Command-line entry point: load a tree and print a contraction report.

    The report has a summary line, per-slice and all-slice totals, the most
    expensive individual ops, and the matmul shapes ranked by total flops and
    by occurrence count.

    Args:
        argv: Optional list of command-line arguments.  ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("tree", help="Pickle file containing one tree or {'trees': [...]}.")
    parser.add_argument("--index", type=int, default=0, help="Tree index in the file.")
    parser.add_argument("--top", type=int, default=20, help="Number of top ops to print.")
    parser.add_argument(
        "--dtype-bytes",
        type=int,
        default=8,
        help="Bytes per element for memory estimates, for example 8 for complex64.",
    )
    args = parser.parse_args(argv)

    tree = _load_tree(args.tree, args.index)
    contractions, ops, counts = _analyze_tree(tree)

    # Falls back to a single slice when the tree has no ``multiplicity``.
    nslices = int(getattr(tree, "multiplicity", 1))
    per_slice_flops = sum(op["tree_flops"] for op in ops)
    per_slice_write = sum(op["out_size"] for op in ops)
    max_out = max((op["out_size"] for op in ops), default=0)
    all_flops = per_slice_flops * nslices
    all_write = per_slice_write * nslices

    print(f"tree={args.tree} index={args.index}")
    print(
        "summary "
        f"slices={nslices} contractions={len(contractions)} "
        f"counts={dict(counts)}"
    )
    print(
        "per_slice "
        f"log10_flops={_format_log(per_slice_flops, 10)} "
        f"log10_write={_format_log(per_slice_write, 10)} "
        f"log2_max_output={_format_log(max_out, 2)} "
        f"max_output_gib={max_out * args.dtype_bytes / 1024**3:.6g}"
    )
    print(
        "all_slices "
        f"log10_flops={_format_log(all_flops, 10)} "
        f"log10_write={_format_log(all_write, 10)}"
    )

    print(f"\ntop_{args.top}_ops_by_flops")
    for op in sorted(ops, key=lambda item: item["tree_flops"], reverse=True)[: args.top]:
        print(
            f"op={op['index']} kind={op['kind']} "
            f"flops={op['tree_flops']:.6e} out={op['out_size']:.6e} "
            f"matmul={op['matmul_shape']} "
            f"ranks=({op['left_rank']},{op['right_rank']}->{op['out_rank']}) "
            f"lhs={op['left_shape']} rhs={op['right_shape']}"
        )

    # Aggregate [count, flops, out_size] per distinct matmul shape; ops without
    # a matmul shape (the "mul" kind) are left out.
    by_shape = defaultdict(lambda: [0, 0, 0])
    for op in ops:
        shape = op["matmul_shape"]
        if shape is None:
            continue
        totals = by_shape[shape]
        totals[0] += 1
        totals[1] += op["tree_flops"]
        totals[2] += op["out_size"]

    _print_top_shapes(by_shape, args.top, "flops", 1)
    _print_top_shapes(by_shape, args.top, "count", 0)
|
||||
|
||||
|
||||
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user