Add batched CUDA patch interpolation path

This commit is contained in:
2026-04-09 14:56:01 +08:00
parent ad999e4c5a
commit c47349b7a9
3 changed files with 601 additions and 105 deletions

View File

@@ -42,6 +42,12 @@ struct CachedBuffer
size_t capacity = 0;
};
struct CachedIntBuffer
{
int *ptr = nullptr;
size_t capacity = 0;
};
inline bool ensure_capacity(CachedBuffer &buffer, size_t bytes)
{
if (bytes <= buffer.capacity && buffer.ptr)
@@ -67,6 +73,31 @@ inline bool ensure_capacity(CachedBuffer &buffer, size_t bytes)
return true;
}
inline bool ensure_capacity(CachedIntBuffer &buffer, size_t bytes)
{
if (bytes <= buffer.capacity && buffer.ptr)
return true;
if (buffer.ptr)
{
cudaError_t free_err = cudaFree(buffer.ptr);
if (free_err != cudaSuccess)
report_cuda_error("cudaFree", free_err);
buffer.ptr = nullptr;
buffer.capacity = 0;
}
cudaError_t err = cudaMalloc(&buffer.ptr, bytes);
if (err != cudaSuccess)
{
report_cuda_error("cudaMalloc", err);
return false;
}
buffer.capacity = bytes;
return true;
}
inline bool copy_to_device(CachedBuffer &dst, const double *src, size_t bytes)
{
if (!ensure_capacity(dst, bytes))
@@ -264,6 +295,148 @@ __global__ void prolong3_cell_kernel(const double *funcc, double *funf,
}
}
__device__ inline int interp_idint_like(double x)
{
return static_cast<int>(x);
}
__device__ inline void lagrange_weights_ord6(double x, double *w)
{
const double denom[6] = {-120.0, 24.0, -12.0, 12.0, -24.0, 120.0};
#pragma unroll
for (int i = 0; i < 6; ++i)
{
double num = 1.0;
#pragma unroll
for (int j = 0; j < 6; ++j)
{
if (j != i)
num *= (x - static_cast<double>(j));
}
w[i] = num / denom[i];
}
}
__device__ inline bool map_interp_index(int logical_idx, int n, double soa, int *mapped_idx, double *sign)
{
int idx = logical_idx;
if (idx < 0)
{
idx = -idx;
*sign *= soa;
}
if (idx < 0 || idx >= n)
return false;
*mapped_idx = idx;
return true;
}
__device__ inline bool compute_interp_window(double x, const double *coord, int n,
int ordn, int allow_reflect,
int *start_idx, double *cx)
{
const double dx = coord[1] - coord[0];
const int center = interp_idint_like((x - coord[0]) / dx + 0.4);
const int cmin = allow_reflect ? (-ordn / 2 + 1) : 0;
const int cmax = n - 1;
int begin = center - ordn / 2 + 1;
int end = begin + ordn - 1;
if (begin < cmin)
{
begin = cmin;
end = begin + ordn - 1;
}
if (end > cmax)
{
end = cmax;
begin = end + 1 - ordn;
}
if (begin >= 0)
*cx = (x - coord[begin]) / dx;
else
*cx = (x + coord[-begin]) / dx;
*start_idx = begin;
return (begin >= cmin && end <= cmax);
}
__global__ void interp_points_ord6_kernel(int num_points, int num_var,
int nx, int ny, int nz,
const double *X, const double *Y, const double *Z,
const double *const *fields,
const double *soa_flat,
const double *px, const double *py, const double *pz,
int symmetry,
double *out,
int *error_flag)
{
const int total = num_points * num_var;
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; idx += blockDim.x * gridDim.x)
{
const int point_id = idx / num_var;
const int var_id = idx - point_id * num_var;
const double *field = fields[var_id];
const double *soa = soa_flat + 3 * var_id;
const double dx = X[1] - X[0];
const double dy = Y[1] - Y[0];
const double dz = Z[1] - Z[0];
const int allow_reflect_x = (symmetry == 2 && fabs(X[0]) < dx);
const int allow_reflect_y = (symmetry == 2 && fabs(Y[0]) < dy);
const int allow_reflect_z = (symmetry != 0 && fabs(Z[0]) < dz);
int start_x = 0, start_y = 0, start_z = 0;
double cx = 0.0, cy = 0.0, cz = 0.0;
const bool ok_x = compute_interp_window(px[point_id], X, nx, 6, allow_reflect_x, &start_x, &cx);
const bool ok_y = compute_interp_window(py[point_id], Y, ny, 6, allow_reflect_y, &start_y, &cy);
const bool ok_z = compute_interp_window(pz[point_id], Z, nz, 6, allow_reflect_z, &start_z, &cz);
if (!ok_x || !ok_y || !ok_z)
{
atomicExch(error_flag, 1);
out[idx] = 0.0;
continue;
}
double wx[6], wy[6], wz[6];
lagrange_weights_ord6(cx, wx);
lagrange_weights_ord6(cy, wy);
lagrange_weights_ord6(cz, wz);
double value = 0.0;
#pragma unroll
for (int kz = 0; kz < 6; ++kz)
{
double yz_sum = 0.0;
#pragma unroll
for (int jy = 0; jy < 6; ++jy)
{
double x_sum = 0.0;
#pragma unroll
for (int ix = 0; ix < 6; ++ix)
{
double sign = 1.0;
int sx = 0, sy = 0, sz = 0;
const bool ok_map =
map_interp_index(start_x + ix, nx, soa[0], &sx, &sign) &&
map_interp_index(start_y + jy, ny, soa[1], &sy, &sign) &&
map_interp_index(start_z + kz, nz, soa[2], &sz, &sign);
if (!ok_map)
{
atomicExch(error_flag, 1);
continue;
}
x_sum += wx[ix] * sign * field[index3(sx, sy, sz, nx, ny)];
}
yz_sum += wy[jy] * x_sum;
}
value += wz[kz] * yz_sum;
}
out[idx] = value;
}
}
__global__ void rk4_kernel(int n, double dT,
const double *f0,
double *f1,
@@ -771,6 +944,168 @@ int bssn_cuda_lowerbound(int *ex, double *chi, double tinny)
return ok ? 0 : 1;
}
int bssn_cuda_interp_points_batch(const int *ex,
const double *X, const double *Y, const double *Z,
const double *const *fields,
const double *soa_flat,
int num_var,
const double *px, const double *py, const double *pz,
int num_points,
int ordn,
int symmetry,
double *out)
{
if (!ex || !X || !Y || !Z || !fields || !soa_flat || !px || !py || !pz || !out)
return 1;
if (num_var <= 0 || num_points <= 0 || ordn != 6)
return 1;
if (ex[0] < ordn || ex[1] < ordn || ex[2] < ordn)
return 1;
struct InterpBatchCache
{
CachedBuffer X, Y, Z;
CachedBuffer px, py, pz;
CachedBuffer out;
CachedBuffer soa;
CachedBuffer field_ptrs;
CachedIntBuffer error_flag;
std::vector<CachedBuffer> host_field_copies;
const double *host_X = nullptr;
const double *host_Y = nullptr;
const double *host_Z = nullptr;
int nx = 0;
int ny = 0;
int nz = 0;
};
static thread_local InterpBatchCache cache;
const int nx = ex[0];
const int ny = ex[1];
const int nz = ex[2];
const int field_points = count_points(ex);
const size_t coord_bytes_x = static_cast<size_t>(nx) * sizeof(double);
const size_t coord_bytes_y = static_cast<size_t>(ny) * sizeof(double);
const size_t coord_bytes_z = static_cast<size_t>(nz) * sizeof(double);
const size_t field_bytes = static_cast<size_t>(field_points) * sizeof(double);
const size_t point_bytes = static_cast<size_t>(num_points) * sizeof(double);
const size_t out_bytes = static_cast<size_t>(num_points) * static_cast<size_t>(num_var) * sizeof(double);
const size_t soa_bytes = static_cast<size_t>(3 * num_var) * sizeof(double);
const size_t ptr_bytes = static_cast<size_t>(num_var) * sizeof(double *);
bool ok = true;
if (cache.host_X != X || cache.host_Y != Y || cache.host_Z != Z ||
cache.nx != nx || cache.ny != ny || cache.nz != nz)
{
ok = copy_to_device(cache.X, X, coord_bytes_x) &&
copy_to_device(cache.Y, Y, coord_bytes_y) &&
copy_to_device(cache.Z, Z, coord_bytes_z);
if (ok)
{
cache.host_X = X;
cache.host_Y = Y;
cache.host_Z = Z;
cache.nx = nx;
cache.ny = ny;
cache.nz = nz;
}
}
ok = ok &&
copy_to_device(cache.px, px, point_bytes) &&
copy_to_device(cache.py, py, point_bytes) &&
copy_to_device(cache.pz, pz, point_bytes) &&
copy_to_device(cache.soa, soa_flat, soa_bytes) &&
ensure_capacity(cache.out, out_bytes) &&
ensure_capacity(cache.field_ptrs, ptr_bytes) &&
ensure_capacity(cache.error_flag, sizeof(int));
if (!ok)
return 1;
if (static_cast<int>(cache.host_field_copies.size()) < num_var)
cache.host_field_copies.resize(num_var);
std::vector<const double *> device_fields(num_var);
for (int v = 0; v < num_var; ++v)
{
const double *device_field = bssn_gpu_find_device_buffer(fields[v]);
if (!device_field)
{
ok = copy_to_device(cache.host_field_copies[v], fields[v], field_bytes);
device_field = cache.host_field_copies[v].ptr;
}
device_fields[v] = device_field;
if (!ok || !device_fields[v])
return 1;
}
int zero = 0;
cudaError_t err = cudaMemcpy(cache.field_ptrs.ptr, device_fields.data(), ptr_bytes, cudaMemcpyHostToDevice);
if (err != cudaSuccess)
{
report_cuda_error("cudaMemcpy(H2D) field_ptrs", err);
return 1;
}
err = cudaMemcpy(cache.error_flag.ptr, &zero, sizeof(int), cudaMemcpyHostToDevice);
if (err != cudaSuccess)
{
report_cuda_error("cudaMemcpy(H2D) interp_error_flag", err);
return 1;
}
dim3 block(128);
dim3 grid(div_up(num_points * num_var, static_cast<int>(block.x)));
if (grid.x > 4096)
grid.x = 4096;
int nx_local = nx;
int ny_local = ny;
int nz_local = nz;
const double *dX = cache.X.ptr;
const double *dY = cache.Y.ptr;
const double *dZ = cache.Z.ptr;
const double *dpx = cache.px.ptr;
const double *dpy = cache.py.ptr;
const double *dpz = cache.pz.ptr;
const double *dsoa = cache.soa.ptr;
const double *const *dfields = reinterpret_cast<const double *const *>(cache.field_ptrs.ptr);
double *dout = cache.out.ptr;
int *derror = cache.error_flag.ptr;
void *args[] = {
&num_points, &num_var,
&nx_local, &ny_local, &nz_local,
&dX, &dY, &dZ,
&dfields,
&dsoa,
&dpx, &dpy, &dpz,
&symmetry,
&dout,
&derror};
ok = launch_kernel(grid, block, (const void *)interp_points_ord6_kernel, args);
if (!ok)
return 1;
int error_flag = 0;
err = cudaMemcpy(&error_flag, cache.error_flag.ptr, sizeof(int), cudaMemcpyDeviceToHost);
if (err != cudaSuccess)
{
report_cuda_error("cudaMemcpy(D2H) interp_error_flag", err);
return 1;
}
if (error_flag != 0)
return 1;
err = cudaMemcpy(out, cache.out.ptr, out_bytes, cudaMemcpyDeviceToHost);
if (err != cudaSuccess)
{
report_cuda_error("cudaMemcpy(D2H) interp_out", err);
return 1;
}
return 0;
}
int bssn_cuda_prolong3_pack(int wei,
const double *llbc, const double *uubc, const int *extc, const double *func,
const double *llbf, const double *uubf, const int *extf, double *funf,