Trigger-Discipline: parallelize result plotting

This commit is contained in:
2026-04-24 10:04:57 +08:00
parent 7f603f189b
commit 45e3c725f9
5 changed files with 137 additions and 66 deletions

View File

@@ -9,9 +9,19 @@
################################################################## ##################################################################
################################################################## ##################################################################
## Print program introduction ## Guard against re-execution by multiprocessing child processes.
## Without this, using 'spawn' or 'forkserver' context would cause every
## worker to re-run the entire script.
if __name__ != '__main__':
import sys as _sys
_sys.exit(0)
##################################################################
## Print program introduction
import print_information import print_information
@@ -422,31 +432,36 @@ print( " Plotting the txt and binary results data from the AMSS-NCKU simulation
print( ) print( )
import plot_xiaoqu import plot_xiaoqu
import plot_GW_strain_amplitude_xiaoqu import plot_GW_strain_amplitude_xiaoqu
from parallel_plot_helper import run_plot_tasks_parallel
## Plot black hole trajectory
plot_xiaoqu.generate_puncture_orbit_plot( binary_results_directory, figure_directory ) plot_tasks = []
plot_xiaoqu.generate_puncture_orbit_plot3D( binary_results_directory, figure_directory )
## Plot black hole trajectory
## Plot black hole separation vs. time plot_tasks.append( ( plot_xiaoqu.generate_puncture_orbit_plot, (binary_results_directory, figure_directory) ) )
plot_xiaoqu.generate_puncture_distence_plot( binary_results_directory, figure_directory ) plot_tasks.append( ( plot_xiaoqu.generate_puncture_orbit_plot3D, (binary_results_directory, figure_directory) ) )
## Plot gravitational waveforms (psi4 and strain amplitude) ## Plot black hole separation vs. time
for i in range(input_data.Detector_Number): plot_tasks.append( ( plot_xiaoqu.generate_puncture_distence_plot, (binary_results_directory, figure_directory) ) )
plot_xiaoqu.generate_gravitational_wave_psi4_plot( binary_results_directory, figure_directory, i )
plot_GW_strain_amplitude_xiaoqu.generate_gravitational_wave_amplitude_plot( binary_results_directory, figure_directory, i ) ## Plot gravitational waveforms (psi4 and strain amplitude)
for i in range(input_data.Detector_Number):
## Plot ADM mass evolution plot_tasks.append( ( plot_xiaoqu.generate_gravitational_wave_psi4_plot, (binary_results_directory, figure_directory, i) ) )
for i in range(input_data.Detector_Number): plot_tasks.append( ( plot_GW_strain_amplitude_xiaoqu.generate_gravitational_wave_amplitude_plot, (binary_results_directory, figure_directory, i) ) )
plot_xiaoqu.generate_ADMmass_plot( binary_results_directory, figure_directory, i )
## Plot ADM mass evolution
## Plot Hamiltonian constraint violation over time for i in range(input_data.Detector_Number):
for i in range(input_data.grid_level): plot_tasks.append( ( plot_xiaoqu.generate_ADMmass_plot, (binary_results_directory, figure_directory, i) ) )
plot_xiaoqu.generate_constraint_check_plot( binary_results_directory, figure_directory, i )
## Plot Hamiltonian constraint violation over time
## Plot stored binary data for i in range(input_data.grid_level):
plot_xiaoqu.generate_binary_data_plot( binary_results_directory, figure_directory ) plot_tasks.append( ( plot_xiaoqu.generate_constraint_check_plot, (binary_results_directory, figure_directory, i) ) )
run_plot_tasks_parallel(plot_tasks)
## Plot stored binary data
plot_xiaoqu.generate_binary_data_plot( binary_results_directory, figure_directory )
print( ) print( )
print( f" This Program Cost = {elapsed_time} Seconds " ) print( f" This Program Cost = {elapsed_time} Seconds " )

12
parallel_plot_helper.py Normal file
View File

@@ -0,0 +1,12 @@
import multiprocessing
def run_plot_task(task):
func, args = task
return func(*args)
def run_plot_tasks_parallel(plot_tasks):
ctx = multiprocessing.get_context('fork')
with ctx.Pool() as pool:
pool.map(run_plot_task, plot_tasks)

View File

@@ -8,11 +8,13 @@
## ##
################################################# #################################################
import numpy ## numpy for array operations import numpy ## numpy for array operations
import scipy ## scipy for interpolation and signal processing import scipy ## scipy for interpolation and signal processing
import math import math
import matplotlib.pyplot as plt ## matplotlib for plotting import matplotlib
import os ## os for system/file operations matplotlib.use('Agg') ## use non-interactive backend for multiprocessing safety
import matplotlib.pyplot as plt ## matplotlib for plotting
import os ## os for system/file operations
import AMSS_NCKU_Input as input_data import AMSS_NCKU_Input as input_data

View File

@@ -6,17 +6,22 @@
## Author: Xiaoqu ## Author: Xiaoqu
## Dates: 2024/10/01 --- 2025/09/14 ## Dates: 2024/10/01 --- 2025/09/14
## ##
################################################# #################################################
import numpy ## Restrict OpenMP to one thread per process so that parallel
import scipy ## subprocess plotting does not multiply BLAS thread counts.
import matplotlib.pyplot as plt import os
from matplotlib.colors import LogNorm os.environ.setdefault("OMP_NUM_THREADS", "1")
from mpl_toolkits.mplot3d import Axes3D
## import torch import numpy
import AMSS_NCKU_Input as input_data import scipy
import matplotlib
import os matplotlib.use('Agg') ## use non-interactive backend for multiprocessing safety
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from mpl_toolkits.mplot3d import Axes3D
## import torch
import AMSS_NCKU_Input as input_data
######################################################################################### #########################################################################################
@@ -92,9 +97,9 @@ def plot_binary_data( filename, binary_outdir, figure_outdir ):
#################################################################################### ####################################################################################
# Plot a single binary dataset (2D slices and 3D surface) # Plot a single binary dataset (2D slices and 3D surface)
def get_data_xy( Rmin, Rmax, n, data0, time, figure_title, figure_outdir ): def get_data_xy( Rmin, Rmax, n, data0, time, figure_title, figure_outdir ):
@@ -188,7 +193,15 @@ def get_data_xy( Rmin, Rmax, n, data0, time, figure_title, figure_outdir ):
plt.savefig( os.path.join(figure_surfaceplot_outdir, figure_title + " time = " + str(time) + " surface_plot.pdf") ) # save figure plt.savefig( os.path.join(figure_surfaceplot_outdir, figure_title + " time = " + str(time) + " surface_plot.pdf") ) # save figure
plt.close() plt.close()
return return
#################################################################################### ####################################################################################
## Allow standalone subprocess execution for parallel binary-data plotting.
if __name__ == '__main__':
import sys
if len(sys.argv) != 4:
print(f"Usage: {sys.argv[0]} <filename> <binary_outdir> <figure_outdir>")
sys.exit(1)
plot_binary_data(sys.argv[1], sys.argv[2], sys.argv[3])

View File

@@ -6,15 +6,20 @@
## 2024/10/01 --- 2025/09/14 ## 2024/10/01 --- 2025/09/14
## ##
################################################# #################################################
import numpy ## numpy for array operations import numpy ## numpy for array operations
import matplotlib.pyplot as plt ## matplotlib for plotting import matplotlib
from mpl_toolkits.mplot3d import Axes3D ## needed for 3D plots matplotlib.use('Agg') ## use non-interactive backend for multiprocessing safety
import glob import matplotlib.pyplot as plt ## matplotlib for plotting
import os ## operating system utilities from mpl_toolkits.mplot3d import Axes3D ## needed for 3D plots
import glob
import plot_binary_data import os ## operating system utilities
import AMSS_NCKU_Input as input_data
import plot_binary_data
import AMSS_NCKU_Input as input_data
import subprocess
import sys
import multiprocessing
# plt.rcParams['text.usetex'] = True ## enable LaTeX fonts in plots # plt.rcParams['text.usetex'] = True ## enable LaTeX fonts in plots
@@ -50,13 +55,37 @@ def generate_binary_data_plot( binary_outdir, figure_outdir ):
file_list.append(x) file_list.append(x)
print(x) print(x)
## Plot each file in the list ## Plot each file in parallel using subprocesses.
for filename in file_list: ## Each subprocess starts with BLAS thread limits in plot_binary_data.py.
print(filename) script = os.path.join( os.path.dirname(__file__), "plot_binary_data.py" )
plot_binary_data.plot_binary_data(filename, binary_outdir, figure_outdir) max_workers = min( multiprocessing.cpu_count(), len(file_list) ) if file_list else 0
print( ) running = []
print( " Binary Data Plot Has been Finished " ) failed = []
for filename in file_list:
print(filename)
proc = subprocess.Popen(
[sys.executable, script, filename, binary_outdir, figure_outdir],
)
running.append( (proc, filename) )
if len(running) >= max_workers:
p, fn = running.pop(0)
p.wait()
if p.returncode != 0:
failed.append(fn)
for p, fn in running:
p.wait()
if p.returncode != 0:
failed.append(fn)
if failed:
print( " WARNING: the following binary data plots failed:" )
for fn in failed:
print( " ", fn )
print( )
print( " Binary Data Plot Has been Finished " )
print( ) print( )
return return