#!/usr/bin/env python
"""Probe MPI rank placement and whether torch CPU ops use multiple threads.

Run this under mpirun/mpiexec to check:

* which CPUs each rank is allowed to run on,
* whether torch sees the requested intra-op thread count, and
* whether a large CPU tensor op actually consumes more CPU time than wall time.

The script is intentionally small and self-contained so it can be used to
debug MPI launcher affinity and torch OpenMP behavior independently from the
TN code path.

Both ``mpi4py`` and ``torch`` are imported inside :func:`main`, after the
thread-related environment defaults have been exported, so the helpers in
this module can also be imported without an MPI installation.
"""

from __future__ import annotations

import argparse
import os
import socket
import time
from pathlib import Path

# Element size in bytes for each dtype accepted by ``--dtype``.  Used both as
# the CLI choice list and for the per-rank memory estimate, so the two cannot
# drift apart.
_DTYPE_NBYTES = {
    "float32": 4,
    "float64": 8,
    "complex64": 8,
    "complex128": 16,
}


def _dtype_from_name(name):
    """Return the torch dtype named by one of the ``--dtype`` choices.

    Raises:
        KeyError: if ``name`` is not a supported dtype name.
    """
    import torch

    mapping = {
        "float32": torch.float32,
        "float64": torch.float64,
        "complex64": torch.complex64,
        "complex128": torch.complex128,
    }
    return mapping[name]


def _make_tensor(shape, dtype):
    """Return a standard-normal random CPU tensor of ``shape`` and ``dtype``.

    Complex dtypes are built from independent real and imaginary parts drawn
    at the matching real precision.
    """
    import torch

    if dtype in (torch.complex64, torch.complex128):
        base = torch.float32 if dtype == torch.complex64 else torch.float64
        return torch.complex(
            torch.randn(shape, dtype=base),
            torch.randn(shape, dtype=base),
        )
    return torch.randn(shape, dtype=dtype)


def _to_python_scalar(value):
    """Convert a one-element tensor to a Python number.

    For complex tensors only the real part is kept, so the result can always
    be accumulated as a ``float`` by :func:`_bench`.
    """
    return value.real.item() if value.is_complex() else value.item()


def _bench(label, fn, iters, warmup=2):
    """Time ``iters`` calls of ``fn`` and print wall time vs. process CPU time.

    ``fn`` is first called ``warmup`` times untimed.  Its return values from
    the timed calls are summed into a checksum printed with the timings.
    ``time.process_time`` measures CPU time of the whole process, so a
    ``cpu_over_wall`` ratio well above 1 indicates multi-threaded execution.
    """
    for _ in range(warmup):
        fn()

    start_wall = time.perf_counter()
    start_cpu = time.process_time()
    checksum = 0.0
    for _ in range(iters):
        checksum += float(fn())
    wall = time.perf_counter() - start_wall
    cpu = time.process_time() - start_cpu

    ratio = cpu / wall if wall > 0 else float("inf")
    print(
        f"{label} wall={wall:.3f}s cpu={cpu:.3f}s cpu_over_wall={ratio:.2f} "
        f"checksum={checksum:.6e}",
        flush=True,
    )


def _visible_numa_nodes():
    """Return ``"nodeN:<cpulist>,..."`` for each sysfs NUMA node, or ``"unknown"``."""
    nodes = []
    for path in sorted(Path("/sys/devices/system/node").glob("node[0-9]*")):
        cpulist = path / "cpulist"
        if cpulist.exists():
            nodes.append(f"{path.name}:{cpulist.read_text(encoding='utf-8').strip()}")
    return ",".join(nodes) if nodes else "unknown"


def _cpu_affinity():
    """Return the sorted CPU ids this process may run on, or ``None`` if unknown.

    ``os.sched_getaffinity`` is not available on every platform (e.g. macOS),
    so its absence is reported instead of crashing the probe.
    """
    getaffinity = getattr(os, "sched_getaffinity", None)
    if getaffinity is None:
        return None
    return sorted(getaffinity(0))


def _allowed_cpu_list():
    """Return the ``Cpus_allowed_list`` entry of ``/proc/self/status``.

    Returns an empty string when the file is missing/unreadable or has no
    such entry.
    """
    try:
        with open("/proc/self/status", encoding="utf-8") as f:
            for line in f:
                if line.startswith("Cpus_allowed_list:"):
                    return line.split(":", 1)[1].strip()
    except OSError:
        pass
    return ""


def _dtype_nbytes(name):
    """Return the element size in bytes of the ``--dtype`` choice ``name``.

    Raises:
        KeyError: if ``name`` is not a supported dtype name.
    """
    return _DTYPE_NBYTES[name]


def _format_gib(nbytes):
    """Format a byte count as GiB with two decimals, e.g. ``"1.50GiB"``."""
    return f"{nbytes / (1024 ** 3):.2f}GiB"


def _positive_int(text):
    """``argparse`` type converter that accepts only integers >= 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def main():
    """Print per-rank placement diagnostics, then optionally benchmark ops."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--threads", type=_positive_int, default=48)
    parser.add_argument("--n", type=_positive_int, default=4096)
    parser.add_argument("--iters", type=int, default=4)
    parser.add_argument("--dtype", choices=tuple(_DTYPE_NBYTES), default="float32")
    parser.add_argument(
        "--op", choices=("matmul", "tensordot", "both"), default="both"
    )
    parser.add_argument(
        "--affinity-only",
        action="store_true",
        help="Print MPI/torch placement diagnostics without allocating tensors.",
    )
    args = parser.parse_args()

    # Only fill in defaults: values already exported (e.g. by the launcher)
    # win.  They are set before the native runtimes below are imported so
    # those runtimes can pick them up when they load.
    os.environ.setdefault("OMP_NUM_THREADS", str(args.threads))
    os.environ.setdefault("MKL_NUM_THREADS", str(args.threads))
    os.environ.setdefault("OMP_PROC_BIND", "close")
    os.environ.setdefault("OMP_PLACES", "cores")

    # Deferred imports: importing mpi4py.MPI initialises MPI, and torch loads
    # its threading runtime on import; both now see the defaults above.
    from mpi4py import MPI
    import torch

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    torch.set_num_threads(args.threads)
    try:
        torch.set_num_interop_threads(1)
    except RuntimeError as exc:
        # Best effort: torch refuses once the inter-op pool has been
        # configured or used.  The effective value is still printed below.
        print(
            f"rank={rank}/{size} warning: set_num_interop_threads(1) failed: {exc}",
            flush=True,
        )

    dtype = _dtype_from_name(args.dtype)
    affinity = _cpu_affinity()
    affinity_len = "unknown" if affinity is None else len(affinity)
    torch_threads = torch.get_num_threads()

    print(
        f"rank={rank}/{size} host={socket.gethostname()} pid={os.getpid()} "
        f"affinity_len={affinity_len} allowed={_allowed_cpu_list()} "
        f"torch_threads={torch_threads} "
        f"torch_interop={torch.get_num_interop_threads()} "
        f"OMP_NUM_THREADS={os.environ.get('OMP_NUM_THREADS')} "
        f"MKL_NUM_THREADS={os.environ.get('MKL_NUM_THREADS')} "
        f"OMP_PROC_BIND={os.environ.get('OMP_PROC_BIND')} "
        f"OMP_PLACES={os.environ.get('OMP_PLACES')} "
        f"visible_numa={_visible_numa_nodes()}",
        flush=True,
    )
    if affinity is not None and torch_threads > len(affinity):
        # More intra-op threads than allowed CPUs: threads will share cores.
        print(
            f"rank={rank}/{size} warning: torch_threads={torch_threads} exceeds "
            f"affinity_len={len(affinity)} (oversubscribed)",
            flush=True,
        )

    if rank == 0:
        print(torch.__config__.parallel_info(), flush=True)
        input_bytes = args.n * args.n * _dtype_nbytes(args.dtype)
        # Rough lower bound: a, b and the product matrix are live together.
        min_live_bytes = 3 * input_bytes
        print(
            f"matrix_n={args.n} dtype={args.dtype} "
            f"one_matrix={_format_gib(input_bytes)} "
            f"approx_min_live_per_rank={_format_gib(min_live_bytes)} "
            f"approx_min_live_all_ranks={_format_gib(min_live_bytes * size)}",
            flush=True,
        )
    comm.Barrier()

    if args.affinity_only:
        return

    a = _make_tensor((args.n, args.n), dtype)
    b = _make_tensor((args.n, args.n), dtype)

    def run_matmul():
        return _to_python_scalar((a @ b).sum())

    def run_tensordot():
        # dims=1 contracts a's last axis with b's first: same product as a @ b.
        return _to_python_scalar(torch.tensordot(a, b, dims=1).sum())

    # Tag results with the rank so interleaved mpirun output is attributable.
    prefix = f"rank={rank}/{size} "
    if args.op in ("matmul", "both"):
        _bench(f"{prefix}matmul", run_matmul, args.iters)
    if args.op in ("tensordot", "both"):
        _bench(f"{prefix}tensordot", run_tensordot, args.iters)


if __name__ == "__main__":
    main()