"""Unit tests for the slice-distribution helpers in ``qibotn.parallel``."""

import numpy as np

from qibotn.parallel import _split_repeats, contract_tree_slices, mpi_slice_plan


def _indices_per_rank(num_slices, num_ranks, assignment):
    """Build the slice plan for every rank and return each plan's ``indices``.

    The list is ordered by rank, so element ``r`` holds the slices that
    ``mpi_slice_plan`` assigns to rank ``r``.
    """
    plans = tuple(
        mpi_slice_plan(num_slices, rank, num_ranks, assignment=assignment)
        for rank in range(num_ranks)
    )
    return [plan.indices for plan in plans]


def test_mpi_slice_plan_block_balances_contiguous_ranges():
    """Block assignment gives each rank one contiguous run of slices.

    With 10 slices over 4 ranks, the two leftover slices go to the first
    two ranks, giving runs of length 3, 3, 2, 2.
    """
    expected = [(0, 1, 2), (3, 4, 5), (6, 7), (8, 9)]
    assert _indices_per_rank(10, 4, "block") == expected


def test_mpi_slice_plan_cyclic_balances_round_robin():
    """Cyclic assignment deals slices out round-robin across the ranks."""
    # Slice s ends up on rank s % 4.
    expected = [(0, 4, 8), (1, 5, 9), (2, 6), (3, 7)]
    assert _indices_per_rank(10, 4, "cyclic") == expected


class DummyTree:
    """Stand-in for a contraction tree: slice ``i`` yields ``arrays[0] * (i + 1)``."""

    def contract_slice(self, arrays, i, backend=None):
        # Weight the first operand by a slice-dependent factor so the summed
        # result reveals exactly which slice indices were contracted.
        weight = i + 1
        return arrays[0] * weight


def test_contract_tree_slices_sums_numpy_slices():
    """``contract_tree_slices`` adds up the contraction of every requested slice."""
    operand = np.asarray([2.0 + 0.0j])
    # Slices 0, 2 and 3 carry weights 1, 3 and 4: 2 * (1 + 3 + 4) == 16.
    expected = np.asarray([16.0 + 0.0j])
    total = contract_tree_slices(DummyTree(), [operand], (0, 2, 3), backend="numpy")
    np.testing.assert_allclose(total, expected)


def test_split_repeats_balances_workers():
    """Repeats are divided as evenly as possible among the workers.

    When there are fewer repeats than workers, the result lists only the
    non-empty shares (two entries for 2 repeats over 4 workers).
    """
    cases = {
        (10, 4): [3, 3, 2, 2],
        (2, 4): [1, 1],
    }
    for (repeats, workers), expected in cases.items():
        assert _split_repeats(repeats, workers) == expected