From 06fa6433653924c3ae1f816255397aa324f2db58 Mon Sep 17 00:00:00 2001 From: CGH0S7 <776459475@qq.com> Date: Thu, 9 Apr 2026 15:06:11 +0800 Subject: [PATCH] Refine batched CUDA interpolation kernel --- AMSS_NCKU_source/bssn_cuda_ops.cu | 291 +++++++++++++++--------------- 1 file changed, 144 insertions(+), 147 deletions(-) diff --git a/AMSS_NCKU_source/bssn_cuda_ops.cu b/AMSS_NCKU_source/bssn_cuda_ops.cu index e292efe..e8f6aa4 100644 --- a/AMSS_NCKU_source/bssn_cuda_ops.cu +++ b/AMSS_NCKU_source/bssn_cuda_ops.cu @@ -98,6 +98,20 @@ inline bool ensure_capacity(CachedIntBuffer &buffer, size_t bytes) return true; } +inline bool copy_to_device(CachedIntBuffer &dst, const int *src, size_t bytes) +{ + if (!ensure_capacity(dst, bytes)) + return false; + + cudaError_t err = cudaMemcpy(dst.ptr, src, bytes, cudaMemcpyHostToDevice); + if (err != cudaSuccess) + { + report_cuda_error("cudaMemcpy(H2D int)", err); + return false; + } + return true; +} + inline bool copy_to_device(CachedBuffer &dst, const double *src, size_t bytes) { if (!ensure_capacity(dst, bytes)) @@ -135,6 +149,72 @@ inline bool copy_to_device_preferring_device(CachedBuffer &dst, const double *sr return true; } +inline int interp_idint_like_host(double x) +{ + return static_cast(x); +} + +inline void lagrange_weights_ord6_host(double x, double *w) +{ + static const double denom[6] = {-120.0, 24.0, -12.0, 12.0, -24.0, 120.0}; + for (int i = 0; i < 6; ++i) + { + double num = 1.0; + for (int j = 0; j < 6; ++j) + { + if (j != i) + num *= (x - static_cast(j)); + } + w[i] = num / denom[i]; + } +} + +inline bool map_interp_index_host(int logical_idx, int n, int *mapped_idx, int *reflected) +{ + int idx = logical_idx; + *reflected = 0; + if (idx < 0) + { + idx = -idx; + *reflected = 1; + } + if (idx < 0 || idx >= n) + return false; + *mapped_idx = idx; + return true; +} + +inline bool compute_interp_window_host(double x, const double *coord, int n, + int ordn, int allow_reflect, + int *start_idx, double *cx) +{ + const double dx = coord[1] - coord[0]; + const int center = interp_idint_like_host((x - coord[0]) / dx + 0.4); + const int cmin = allow_reflect ? (-ordn / 2 + 1) : 0; + const int cmax = n - 1; + + int begin = center - ordn / 2 + 1; + int end = begin + ordn - 1; + if (begin < cmin) + { + begin = cmin; + end = begin + ordn - 1; + } + if (end > cmax) + { + end = cmax; + begin = end + 1 - ordn; + } + + if (begin >= 0) + *cx = (x - coord[begin]) / dx; + else + *cx = (x + coord[-begin]) / dx; + + *start_idx = begin; + return (begin >= cmin && end <= cmax); +} + __global__ void enforce_ga_kernel(int n, double *dxx, double *gxy, double *gxz, double *dyy, double *gyz, double *dzz, @@ -295,80 +375,13 @@ __global__ void prolong3_cell_kernel(const double *funcc, double *funf, } } -__device__ inline int interp_idint_like(double x) -{ - return static_cast(x); -} - -__device__ inline void lagrange_weights_ord6(double x, double *w) -{ - const double denom[6] = {-120.0, 24.0, -12.0, 12.0, -24.0, 120.0}; - #pragma unroll - for (int i = 0; i < 6; ++i) - { - double num = 1.0; - #pragma unroll - for (int j = 0; j < 6; ++j) - { - if (j != i) - num *= (x - static_cast(j)); - } - w[i] = num / denom[i]; - } -} - -__device__ inline bool map_interp_index(int logical_idx, int n, double soa, int *mapped_idx, double *sign) -{ - int idx = logical_idx; - if (idx < 0) - { - idx = -idx; - *sign *= soa; - } - if (idx < 0 || idx >= n) - return false; - *mapped_idx = idx; - return true; -} - -__device__ inline bool compute_interp_window(double x, const double *coord, int n, - int ordn, int allow_reflect, - int *start_idx, double *cx) -{ - const double dx = coord[1] - coord[0]; - const int center = interp_idint_like((x - coord[0]) / dx + 0.4); - const int cmin = allow_reflect ? (-ordn / 2 + 1) : 0; - const int cmax = n - 1; - - int begin = center - ordn / 2 + 1; - int end = begin + ordn - 1; - if (begin < cmin) - { - begin = cmin; - end = begin + ordn - 1; - } - if (end > cmax) - { - end = cmax; - begin = end + 1 - ordn; - } - - if (begin >= 0) - *cx = (x - coord[begin]) / dx; - else - *cx = (x + coord[-begin]) / dx; - - *start_idx = begin; - return (begin >= cmin && end <= cmax); -} - __global__ void interp_points_ord6_kernel(int num_points, int num_var, - int nx, int ny, int nz, - const double *X, const double *Y, const double *Z, + int nx, int ny, const double *const *fields, const double *soa_flat, - const double *px, const double *py, const double *pz, - int symmetry, + const int *stencil_idx, + const int *stencil_reflect, + const double *stencil_weights, double *out, int *error_flag) { @@ -379,62 +392,38 @@ __global__ void interp_points_ord6_kernel(int num_points, int num_var, const int var_id = idx - point_id * num_var; const double *field = fields[var_id]; const double *soa = soa_flat + 3 * var_id; - - const double dx = X[1] - X[0]; - const double dy = Y[1] - Y[0]; - const double dz = Z[1] - Z[0]; - const int allow_reflect_x = (symmetry == 2 && fabs(X[0]) < dx); - const int allow_reflect_y = (symmetry == 2 && fabs(Y[0]) < dy); - const int allow_reflect_z = (symmetry != 0 && fabs(Z[0]) < dz); - - int start_x = 0, start_y = 0, start_z = 0; - double cx = 0.0, cy = 0.0, cz = 0.0; - const bool ok_x = compute_interp_window(px[point_id], X, nx, 6, allow_reflect_x, &start_x, &cx); - const bool ok_y = compute_interp_window(py[point_id], Y, ny, 6, allow_reflect_y, &start_y, &cy); - const bool ok_z = compute_interp_window(pz[point_id], Z, nz, 6, allow_reflect_z, &start_z, &cz); - if (!ok_x || !ok_y || !ok_z) - { - atomicExch(error_flag, 1); - out[idx] = 0.0; - continue; - } - - double wx[6], wy[6], wz[6]; - lagrange_weights_ord6(cx, wx); - lagrange_weights_ord6(cy, wy); - lagrange_weights_ord6(cz, wz); + const int *idxp = stencil_idx + point_id * 18; + const int *refp = stencil_reflect + point_id * 18; + const double *wp = stencil_weights + point_id * 18; double value = 0.0; #pragma unroll for (int kz = 0; kz < 6; ++kz) { + const int z_idx = idxp[12 + kz]; + const double z_sign = refp[12 + kz] ? soa[2] : 1.0; double yz_sum = 0.0; #pragma unroll for (int jy = 0; jy < 6; ++jy) { + const int y_idx = idxp[6 + jy]; + const double y_sign = refp[6 + jy] ? soa[1] : 1.0; double x_sum = 0.0; #pragma unroll for (int ix = 0; ix < 6; ++ix) { - double sign = 1.0; - int sx = 0, sy = 0, sz = 0; - const bool ok_map = - map_interp_index(start_x + ix, nx, soa[0], &sx, &sign) && - map_interp_index(start_y + jy, ny, soa[1], &sy, &sign) && - map_interp_index(start_z + kz, nz, soa[2], &sz, &sign); - if (!ok_map) - { - atomicExch(error_flag, 1); - continue; - } - x_sum += wx[ix] * sign * field[index3(sx, sy, sz, nx, ny)]; + const int x_idx = idxp[ix]; + const double x_sign = refp[ix] ? soa[0] : 1.0; + const double sign = x_sign * y_sign * z_sign; + x_sum += wp[ix] * sign * field[index3(x_idx, y_idx, z_idx, nx, ny)]; } - yz_sum += wy[jy] * x_sum; + yz_sum += wp[6 + jy] * x_sum; } - value += wz[kz] * yz_sum; + value += wp[12 + kz] * yz_sum; } out[idx] = value; } + (void)error_flag; } __global__ void rk4_kernel(int n, double dT, @@ -964,19 +953,14 @@ int bssn_cuda_interp_points_batch(const int *ex, struct InterpBatchCache { - CachedBuffer X, Y, Z; - CachedBuffer px, py, pz; CachedBuffer out; CachedBuffer soa; CachedBuffer field_ptrs; + CachedBuffer weights; + CachedIntBuffer indices; + CachedIntBuffer reflect; CachedIntBuffer error_flag; std::vector host_field_copies; - const double *host_X = nullptr; - const double *host_Y = nullptr; - const double *host_Z = nullptr; - int nx = 0; - int ny = 0; - int nz = 0; }; static thread_local InterpBatchCache cache; @@ -984,38 +968,55 @@ int bssn_cuda_interp_points_batch(const int *ex, const int ny = ex[1]; const int nz = ex[2]; const int field_points = count_points(ex); - const size_t coord_bytes_x = static_cast(nx) * sizeof(double); - const size_t coord_bytes_y = static_cast(ny) * sizeof(double); - const size_t coord_bytes_z = static_cast(nz) * sizeof(double); const size_t field_bytes = static_cast(field_points) * sizeof(double); - const size_t point_bytes = static_cast(num_points) * sizeof(double); const size_t out_bytes = static_cast(num_points) * static_cast(num_var) * sizeof(double); const size_t soa_bytes = static_cast(3 * num_var) * sizeof(double); const size_t ptr_bytes = static_cast(num_var) * sizeof(double *); + const size_t point_stencil_doubles = static_cast(num_points) * 18; + const size_t point_stencil_ints = static_cast(num_points) * 18; + const size_t weights_bytes = point_stencil_doubles * sizeof(double); + const size_t indices_bytes = point_stencil_ints * sizeof(int); bool ok = true; - if (cache.host_X != X || cache.host_Y != Y || cache.host_Z != Z || - cache.nx != nx || cache.ny != ny || cache.nz != nz) + std::vector host_weights(point_stencil_doubles); + std::vector host_indices(point_stencil_ints); + std::vector host_reflect(point_stencil_ints); + const double dx = X[1] - X[0]; + const double dy = Y[1] - Y[0]; + const double dz = Z[1] - Z[0]; + const int allow_reflect_x = (symmetry == 2 && std::fabs(X[0]) < dx); + const int allow_reflect_y = (symmetry == 2 && std::fabs(Y[0]) < dy); + const int allow_reflect_z = (symmetry != 0 && std::fabs(Z[0]) < dz); + for (int p = 0; p < num_points; ++p) { - ok = copy_to_device(cache.X, X, coord_bytes_x) && - copy_to_device(cache.Y, Y, coord_bytes_y) && - copy_to_device(cache.Z, Z, coord_bytes_z); - if (ok) + int start_x = 0, start_y = 0, start_z = 0; + double cx = 0.0, cy = 0.0, cz = 0.0; + const bool ok_x = compute_interp_window_host(px[p], X, nx, ordn, allow_reflect_x, &start_x, &cx); + const bool ok_y = compute_interp_window_host(py[p], Y, ny, ordn, allow_reflect_y, &start_y, &cy); + const bool ok_z = compute_interp_window_host(pz[p], Z, nz, ordn, allow_reflect_z, &start_z, &cz); + if (!ok_x || !ok_y || !ok_z) + return 1; + + lagrange_weights_ord6_host(cx, host_weights.data() + p * 18); + lagrange_weights_ord6_host(cy, host_weights.data() + p * 18 + 6); + lagrange_weights_ord6_host(cz, host_weights.data() + p * 18 + 12); + + for (int i = 0; i < 6; ++i) { - cache.host_X = X; - cache.host_Y = Y; - cache.host_Z = Z; - cache.nx = nx; - cache.ny = ny; - cache.nz = nz; + if (!map_interp_index_host(start_x + i, nx, &host_indices[p * 18 + i], &host_reflect[p * 18 + i])) + return 1; + if (!map_interp_index_host(start_y + i, ny, &host_indices[p * 18 + 6 + i], &host_reflect[p * 18 + 6 + i])) + return 1; + if (!map_interp_index_host(start_z + i, nz, &host_indices[p * 18 + 12 + i], &host_reflect[p * 18 + 12 + i])) + return 1; } } ok = ok && - copy_to_device(cache.px, px, point_bytes) && - copy_to_device(cache.py, py, point_bytes) && - copy_to_device(cache.pz, pz, point_bytes) && copy_to_device(cache.soa, soa_flat, soa_bytes) && + copy_to_device(cache.weights, host_weights.data(), weights_bytes) && + copy_to_device(cache.indices, host_indices.data(), indices_bytes) && + copy_to_device(cache.reflect, host_reflect.data(), indices_bytes) && ensure_capacity(cache.out, out_bytes) && ensure_capacity(cache.field_ptrs, ptr_bytes) && ensure_capacity(cache.error_flag, sizeof(int)); @@ -1060,26 +1061,22 @@ int bssn_cuda_interp_points_batch(const int *ex, int nx_local = nx; int ny_local = ny; - int nz_local = nz; - const double *dX = cache.X.ptr; - const double *dY = cache.Y.ptr; - const double *dZ = cache.Z.ptr; - const double *dpx = cache.px.ptr; - const double *dpy = cache.py.ptr; - const double *dpz = cache.pz.ptr; const double *dsoa = cache.soa.ptr; const double *const *dfields = reinterpret_cast(cache.field_ptrs.ptr); + const double *dweights = cache.weights.ptr; + const int *dindices = cache.indices.ptr; + const int *dreflect = cache.reflect.ptr; double *dout = cache.out.ptr; int *derror = cache.error_flag.ptr; void *args[] = { &num_points, &num_var, - &nx_local, &ny_local, &nz_local, - &dX, &dY, &dZ, + &nx_local, &ny_local, &dfields, &dsoa, - &dpx, &dpy, &dpz, - &symmetry, + &dindices, + &dreflect, + &dweights, &dout, &derror}; ok = launch_kernel(grid, block, (const void *)interp_points_ord6_kernel, args);