From 85afe00fc5a92eb7ce53a14bb5a732a95d6aecd5 Mon Sep 17 00:00:00 2001 From: CGH0S7 <776459475@qq.com> Date: Wed, 11 Feb 2026 16:19:17 +0800 Subject: [PATCH] Merge plotting optimizations from chb-copilot-test - Implement multiprocessing-based parallel plotting - Add parallel_plot_helper.py for concurrent plot task execution - Use matplotlib 'Agg' backend for multiprocessing safety - Set OMP_NUM_THREADS=1 to prevent BLAS thread explosion - Use subprocess for binary data plots to avoid thread conflicts - Add fork bomb protection in main program This merge only includes plotting improvements and excludes MPI communication changes to preserve existing optimizations. Co-Authored-By: Claude Sonnet 4.5 --- AMSS_NCKU_Program.py | 27 +++++++++++++++------ parallel_plot_helper.py | 29 ++++++++++++++++++++++ plot_GW_strain_amplitude_xiaoqu.py | 2 ++ plot_binary_data.py | 27 +++++++++++++++++++-- plot_xiaoqu.py | 39 ++++++++++++++++++++++++++++-- 5 files changed, 113 insertions(+), 11 deletions(-) create mode 100644 parallel_plot_helper.py diff --git a/AMSS_NCKU_Program.py b/AMSS_NCKU_Program.py index 46d15f1..6a7952a 100755 --- a/AMSS_NCKU_Program.py +++ b/AMSS_NCKU_Program.py @@ -8,6 +8,14 @@ ## ################################################################## +## Guard against re-execution by multiprocessing child processes. +## Without this, using 'spawn' or 'forkserver' context would cause every +## worker to re-run the entire script, spawning exponentially more +## workers (fork bomb). +if __name__ != '__main__': + import sys as _sys + _sys.exit(0) + ################################################################## @@ -424,26 +432,31 @@ print( import plot_xiaoqu import plot_GW_strain_amplitude_xiaoqu +from parallel_plot_helper import run_plot_tasks_parallel + +plot_tasks = [] ## Plot black hole trajectory -plot_xiaoqu.generate_puncture_orbit_plot( binary_results_directory, figure_directory ) -plot_xiaoqu.generate_puncture_orbit_plot3D( binary_results_directory, figure_directory ) +plot_tasks.append( ( plot_xiaoqu.generate_puncture_orbit_plot, (binary_results_directory, figure_directory) ) ) +plot_tasks.append( ( plot_xiaoqu.generate_puncture_orbit_plot3D, (binary_results_directory, figure_directory) ) ) ## Plot black hole separation vs. time -plot_xiaoqu.generate_puncture_distence_plot( binary_results_directory, figure_directory ) +plot_tasks.append( ( plot_xiaoqu.generate_puncture_distence_plot, (binary_results_directory, figure_directory) ) ) ## Plot gravitational waveforms (psi4 and strain amplitude) for i in range(input_data.Detector_Number): - plot_xiaoqu.generate_gravitational_wave_psi4_plot( binary_results_directory, figure_directory, i ) - plot_GW_strain_amplitude_xiaoqu.generate_gravitational_wave_amplitude_plot( binary_results_directory, figure_directory, i ) + plot_tasks.append( ( plot_xiaoqu.generate_gravitational_wave_psi4_plot, (binary_results_directory, figure_directory, i) ) ) + plot_tasks.append( ( plot_GW_strain_amplitude_xiaoqu.generate_gravitational_wave_amplitude_plot, (binary_results_directory, figure_directory, i) ) ) ## Plot ADM mass evolution for i in range(input_data.Detector_Number): - plot_xiaoqu.generate_ADMmass_plot( binary_results_directory, figure_directory, i ) + plot_tasks.append( ( plot_xiaoqu.generate_ADMmass_plot, (binary_results_directory, figure_directory, i) ) ) ## Plot Hamiltonian constraint violation over time for i in range(input_data.grid_level): - plot_xiaoqu.generate_constraint_check_plot( binary_results_directory, figure_directory, i ) + plot_tasks.append( ( plot_xiaoqu.generate_constraint_check_plot, (binary_results_directory, figure_directory, i) ) ) + +run_plot_tasks_parallel(plot_tasks) ## Plot stored binary data plot_xiaoqu.generate_binary_data_plot( binary_results_directory, figure_directory ) diff --git a/parallel_plot_helper.py b/parallel_plot_helper.py new file mode 100644 index 0000000..c1168fa --- /dev/null +++ b/parallel_plot_helper.py @@ -0,0 +1,29 @@ +import multiprocessing + +def run_plot_task(task): + """Execute a single plotting task. + + Parameters + ---------- + task : tuple + A tuple of (function, args_tuple) where function is a callable + plotting function and args_tuple contains its arguments. + """ + func, args = task + return func(*args) + + +def run_plot_tasks_parallel(plot_tasks): + """Execute a list of independent plotting tasks in parallel. + + Uses the 'fork' context to create worker processes so that the main + script is NOT re-imported/re-executed in child processes. + + Parameters + ---------- + plot_tasks : list of tuples + Each element is (function, args_tuple). + """ + ctx = multiprocessing.get_context('fork') + with ctx.Pool() as pool: + pool.map(run_plot_task, plot_tasks) diff --git a/plot_GW_strain_amplitude_xiaoqu.py b/plot_GW_strain_amplitude_xiaoqu.py index 739f3d4..cf7b098 100755 --- a/plot_GW_strain_amplitude_xiaoqu.py +++ b/plot_GW_strain_amplitude_xiaoqu.py @@ -11,6 +11,8 @@ import numpy ## numpy for array operations import scipy ## scipy for interpolation and signal processing import math +import matplotlib +matplotlib.use('Agg') ## use non-interactive backend for multiprocessing safety import matplotlib.pyplot as plt ## matplotlib for plotting import os ## os for system/file operations diff --git a/plot_binary_data.py b/plot_binary_data.py index 0694f4f..2aca1c7 100755 --- a/plot_binary_data.py +++ b/plot_binary_data.py @@ -8,16 +8,23 @@ ## ################################################# +## Restrict OpenMP to one thread per process so that running +## many workers in parallel does not create an O(workers * BLAS_threads) +## thread explosion. The variable MUST be set before numpy/scipy +## are imported, because the BLAS library reads them only at load time. +import os +os.environ.setdefault("OMP_NUM_THREADS", "1") + import numpy import scipy +import matplotlib +matplotlib.use('Agg') ## use non-interactive backend for multiprocessing safety import matplotlib.pyplot as plt from matplotlib.colors import LogNorm from mpl_toolkits.mplot3d import Axes3D ## import torch import AMSS_NCKU_Input as input_data -import os - ######################################################################################### @@ -192,3 +199,19 @@ def get_data_xy( Rmin, Rmax, n, data0, time, figure_title, figure_outdir ): #################################################################################### + +#################################################################################### +## Allow this module to be run as a standalone script so that each +## binary-data plot can be executed in a fresh subprocess whose BLAS +## environment variables (set above) take effect before numpy loads. +## +## Usage: python3 plot_binary_data.py +#################################################################################### + +if __name__ == '__main__': + import sys + if len(sys.argv) != 4: + print(f"Usage: {sys.argv[0]} ") + sys.exit(1) + plot_binary_data(sys.argv[1], sys.argv[2], sys.argv[3]) + diff --git a/plot_xiaoqu.py b/plot_xiaoqu.py index 7711d5a..47970cf 100755 --- a/plot_xiaoqu.py +++ b/plot_xiaoqu.py @@ -8,6 +8,8 @@ ################################################# import numpy ## numpy for array operations +import matplotlib +matplotlib.use('Agg') ## use non-interactive backend for multiprocessing safety import matplotlib.pyplot as plt ## matplotlib for plotting from mpl_toolkits.mplot3d import Axes3D ## needed for 3D plots import glob @@ -15,6 +17,9 @@ import os ## operating system utilities import plot_binary_data import AMSS_NCKU_Input as input_data +import subprocess +import sys +import multiprocessing # plt.rcParams['text.usetex'] = True ## enable LaTeX fonts in plots @@ -50,10 +55,40 @@ def generate_binary_data_plot( binary_outdir, figure_outdir ): file_list.append(x) print(x) - ## Plot each file in the list + ## Plot each file in parallel using subprocesses. + ## Each subprocess is a fresh Python process where the BLAS thread-count + ## environment variables (set at the top of plot_binary_data.py) take + ## effect before numpy is imported. This avoids the thread explosion + ## that occurs when multiprocessing.Pool with 'fork' context inherits + ## already-initialized multi-threaded BLAS from the parent. + script = os.path.join( os.path.dirname(__file__), "plot_binary_data.py" ) + max_workers = min( multiprocessing.cpu_count(), len(file_list) ) if file_list else 0 + + running = [] + failed = [] for filename in file_list: print(filename) - plot_binary_data.plot_binary_data(filename, binary_outdir, figure_outdir) + proc = subprocess.Popen( + [sys.executable, script, filename, binary_outdir, figure_outdir], + ) + running.append( (proc, filename) ) + ## Keep at most max_workers subprocesses active at a time + if len(running) >= max_workers: + p, fn = running.pop(0) + p.wait() + if p.returncode != 0: + failed.append(fn) + + ## Wait for all remaining subprocesses to finish + for p, fn in running: + p.wait() + if p.returncode != 0: + failed.append(fn) + + if failed: + print( " WARNING: the following binary data plots failed:" ) + for fn in failed: + print( " ", fn ) print( ) print( " Binary Data Plot Has been Finished " )