Accelerate Shell-Patch interpolation fast paths

This commit is contained in:
2026-05-08 13:26:16 +08:00
parent 063f28b3b4
commit 39450228f5
3 changed files with 906 additions and 150 deletions

View File

@@ -9463,6 +9463,197 @@ int bssn_cuda_interp_host_two_fields(void *block_tag,
return 0;
}
__global__ void kern_shell_pack_host_fields(double **fields,
const int *block_shapes,
const int *point_block,
const int *point_dimh,
const int *point_dumyd,
const int *point_sind,
const double *point_coef,
double *out,
int npoints,
int nvars,
int ordn)
{
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const int total = npoints * nvars;
if (tid >= total) return;
const int p = tid / nvars;
const int v = tid - p * nvars;
const int b = point_block[p];
const int *shape = block_shapes + 3 * b;
const int *s = point_sind + 3 * p;
const double *coef = point_coef + 3 * ordn * p;
const double *f = fields[b * nvars + v];
const int nx = shape[0];
const int ny = shape[1];
const int nz = shape[2];
const int dimh = point_dimh[p];
const int dumyd = point_dumyd[p];
double sum = 0.0;
if (dimh == 3) {
const double *cx = coef;
const double *cy = coef + ordn;
const double *cz = coef + 2 * ordn;
for (int kk = 0; kk < ordn; ++kk)
for (int jj = 0; jj < ordn; ++jj)
for (int ii = 0; ii < ordn; ++ii) {
const int idx = (s[0] + ii) + nx * ((s[1] + jj) + ny * (s[2] + kk));
sum += cx[ii] * cy[jj] * cz[kk] * f[idx];
}
} else if (dimh == 1 && dumyd == 1) {
for (int ii = 0; ii < ordn; ++ii) {
const int idx = (s[0] + ii) + nx * (s[1] + ny * s[2]);
sum += coef[ii] * f[idx];
}
} else if (dimh == 1 && dumyd == 0) {
for (int jj = 0; jj < ordn; ++jj) {
const int idx = s[1] + nx * ((s[0] + jj) + ny * s[2]);
sum += coef[jj] * f[idx];
}
}
out[tid] = sum;
}
struct ShellPackCachedField {
double *device;
size_t bytes;
int generation;
};
static std::unordered_map<const double *, ShellPackCachedField> g_shell_pack_cache;
static int g_shell_pack_generation = 0;
extern "C"
void bssn_cuda_shell_pack_cache_begin()
{
init_gpu_dispatch();
CUDA_CHECK(cudaSetDevice(g_dispatch.my_device));
for (auto &kv : g_shell_pack_cache)
cudaFree(kv.second.device);
g_shell_pack_cache.clear();
++g_shell_pack_generation;
}
extern "C"
void bssn_cuda_shell_pack_cache_end()
{
init_gpu_dispatch();
CUDA_CHECK(cudaSetDevice(g_dispatch.my_device));
for (auto &kv : g_shell_pack_cache)
cudaFree(kv.second.device);
g_shell_pack_cache.clear();
}
extern "C"
int bssn_cuda_shell_pack_host_fields(int npoints,
int nvars,
int nblocks,
int ordn,
double **block_var_fields,
int *block_shapes,
int *point_block,
int *point_dimh,
int *point_dumyd,
int *point_sind,
double *point_coef,
double *out)
{
init_gpu_dispatch();
CUDA_CHECK(cudaSetDevice(g_dispatch.my_device));
if (npoints <= 0 || nvars <= 0 || nblocks <= 0 || ordn <= 0 || ordn > 8 ||
!block_var_fields || !block_shapes || !point_block || !point_dimh ||
!point_dumyd || !point_sind || !point_coef || !out)
return 1;
const int field_count = nblocks * nvars;
std::vector<double *> h_device_fields((size_t)field_count, nullptr);
double **d_fields = nullptr;
int *d_block_shapes = nullptr;
int *d_point_block = nullptr;
int *d_point_dimh = nullptr;
int *d_point_dumyd = nullptr;
int *d_point_sind = nullptr;
double *d_point_coef = nullptr;
double *d_out = nullptr;
for (int b = 0; b < nblocks; ++b) {
const size_t all = (size_t)block_shapes[3 * b] *
(size_t)block_shapes[3 * b + 1] *
(size_t)block_shapes[3 * b + 2];
const size_t bytes = all * sizeof(double);
for (int v = 0; v < nvars; ++v) {
const int idx = b * nvars + v;
double *host_ptr = block_var_fields[idx];
if (!host_ptr) return 1;
auto it = g_shell_pack_cache.find(host_ptr);
if (it != g_shell_pack_cache.end() &&
it->second.bytes == bytes &&
it->second.generation == g_shell_pack_generation) {
h_device_fields[idx] = it->second.device;
} else {
double *device_ptr = nullptr;
CUDA_CHECK(cudaMalloc(&device_ptr, bytes));
CUDA_CHECK(cudaMemcpy(device_ptr, host_ptr, bytes, cudaMemcpyHostToDevice));
g_shell_pack_cache[host_ptr] = {device_ptr, bytes, g_shell_pack_generation};
h_device_fields[idx] = device_ptr;
}
}
}
CUDA_CHECK(cudaMalloc(&d_fields, (size_t)field_count * sizeof(double *)));
CUDA_CHECK(cudaMemcpy(d_fields, h_device_fields.data(),
(size_t)field_count * sizeof(double *),
cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMalloc(&d_block_shapes, (size_t)nblocks * 3 * sizeof(int)));
CUDA_CHECK(cudaMemcpy(d_block_shapes, block_shapes,
(size_t)nblocks * 3 * sizeof(int),
cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMalloc(&d_point_block, (size_t)npoints * sizeof(int)));
CUDA_CHECK(cudaMalloc(&d_point_dimh, (size_t)npoints * sizeof(int)));
CUDA_CHECK(cudaMalloc(&d_point_dumyd, (size_t)npoints * sizeof(int)));
CUDA_CHECK(cudaMalloc(&d_point_sind, (size_t)npoints * 3 * sizeof(int)));
CUDA_CHECK(cudaMalloc(&d_point_coef, (size_t)npoints * 3 * ordn * sizeof(double)));
CUDA_CHECK(cudaMalloc(&d_out, (size_t)npoints * nvars * sizeof(double)));
CUDA_CHECK(cudaMemcpy(d_point_block, point_block,
(size_t)npoints * sizeof(int), cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(d_point_dimh, point_dimh,
(size_t)npoints * sizeof(int), cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(d_point_dumyd, point_dumyd,
(size_t)npoints * sizeof(int), cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(d_point_sind, point_sind,
(size_t)npoints * 3 * sizeof(int), cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(d_point_coef, point_coef,
(size_t)npoints * 3 * ordn * sizeof(double),
cudaMemcpyHostToDevice));
const int total = npoints * nvars;
const int threads = 256;
const int blocks = (total + threads - 1) / threads;
kern_shell_pack_host_fields<<<blocks, threads>>>(
d_fields, d_block_shapes, d_point_block, d_point_dimh,
d_point_dumyd, d_point_sind, d_point_coef, d_out,
npoints, nvars, ordn);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaMemcpy(out, d_out, (size_t)total * sizeof(double),
cudaMemcpyDeviceToHost));
cudaFree(d_out);
cudaFree(d_point_coef);
cudaFree(d_point_sind);
cudaFree(d_point_dumyd);
cudaFree(d_point_dimh);
cudaFree(d_point_block);
cudaFree(d_block_shapes);
cudaFree(d_fields);
return 0;
}
extern "C"
int bssn_cuda_unpack_state_region_from_host_buffer(void *block_tag,
int state_index,