[WIP] Add CUDA support for Z4C

Rewritten done by Codex.
This still has errors, do not pick this one now.
This commit is contained in:
2026-04-27 08:27:24 +08:00
parent 60fee8f1c1
commit c689cc8dc9
6 changed files with 8282 additions and 73 deletions

View File

@@ -6,10 +6,20 @@
#include "parameters.h"
#include <cstdlib>
#include <cstdio>
#if USE_CUDA_BSSN
#ifndef USE_CUDA_Z4C
#define USE_CUDA_Z4C 0
#endif
#if USE_CUDA_BSSN || USE_CUDA_Z4C
#include <cuda_runtime_api.h>
#endif
#if USE_CUDA_BSSN
#include "bssn_rhs_cuda.h"
#endif
#if USE_CUDA_Z4C
#include "z4c_rhs_cuda.h"
#endif
namespace {
@@ -80,7 +90,7 @@ bool cuda_sync_pinned_enabled()
if (enabled < 0)
{
const char *env = getenv("AMSS_CUDA_PINNED_SYNC");
#if USE_CUDA_BSSN
#if USE_CUDA_BSSN || USE_CUDA_Z4C
enabled = (!env || atoi(env) != 0) ? 1 : 0;
#else
enabled = 0;
@@ -93,7 +103,7 @@ void free_comm_buffer(double *&ptr, unsigned char &is_pinned)
{
if (!ptr)
return;
#if USE_CUDA_BSSN
#if USE_CUDA_BSSN || USE_CUDA_Z4C
if (is_pinned)
cudaFreeHost(ptr);
else
@@ -110,7 +120,7 @@ double *alloc_comm_buffer(int length, unsigned char &is_pinned)
is_pinned = 0;
if (length <= 0)
return 0;
#if USE_CUDA_BSSN
#if USE_CUDA_BSSN || USE_CUDA_Z4C
if (cuda_sync_pinned_enabled())
{
double *ptr = 0;
@@ -157,19 +167,43 @@ int cuda_state_var_count(MyList<var> *src_vars, MyList<var> *dst_vars)
return (src_vars || dst_vars) ? -1 : count;
}
#if USE_CUDA_BSSN
#if USE_CUDA_BSSN || USE_CUDA_Z4C
bool cuda_state_count_direct_supported(int state_count)
{
#if USE_CUDA_Z4C && (ABEtype == 2)
return state_count == Z4C_CUDA_STATE_COUNT;
#elif USE_CUDA_BSSN
return state_count > 0 && state_count <= BSSN_CUDA_STATE_COUNT;
#else
(void)state_count;
return false;
#endif
}
bool cuda_can_direct_pack(const Parallel::gridseg *src, const Parallel::gridseg *dst, int type)
{
if (type != 1 || !src || !dst || !src->Bg)
return false;
#if USE_CUDA_Z4C && (ABEtype == 2)
return z4c_cuda_has_resident_state(src->Bg) != 0;
#elif USE_CUDA_BSSN
return bssn_cuda_has_resident_state(src->Bg) != 0;
#else
return false;
#endif
}
bool cuda_can_direct_unpack(const Parallel::gridseg *dst, int type)
{
if (type != 1 || !dst || !dst->Bg)
return false;
#if USE_CUDA_Z4C && (ABEtype == 2)
return z4c_cuda_has_resident_state(dst->Bg) != 0;
#elif USE_CUDA_BSSN
return bssn_cuda_has_resident_state(dst->Bg) != 0;
#else
return false;
#endif
}
bool cuda_direct_pack_segment(double *buffer,
@@ -177,15 +211,28 @@ bool cuda_direct_pack_segment(double *buffer,
const Parallel::gridseg *dst,
int state_count)
{
#if USE_CUDA_Z4C && (ABEtype == 2)
if (state_count != Z4C_CUDA_STATE_COUNT)
return false;
#elif USE_CUDA_BSSN
if (state_count <= 0 || state_count > BSSN_CUDA_STATE_COUNT)
return false;
#else
return false;
#endif
const double t0 = sync_profile_enabled() ? MPI_Wtime() : 0.0;
const int i0 = cuda_seg_begin(dst, src->Bg, 0);
const int j0 = cuda_seg_begin(dst, src->Bg, 1);
const int k0 = cuda_seg_begin(dst, src->Bg, 2);
#if USE_CUDA_Z4C && (ABEtype == 2)
const bool ok = z4c_cuda_pack_state_batch_to_host_buffer(src->Bg, state_count, buffer, src->Bg->shape,
i0, j0, k0,
dst->shape[0], dst->shape[1], dst->shape[2]) == 0;
#else
const bool ok = bssn_cuda_pack_state_batch_to_host_buffer(src->Bg, state_count, buffer, src->Bg->shape,
i0, j0, k0,
dst->shape[0], dst->shape[1], dst->shape[2]) == 0;
#endif
if (sync_profile_enabled())
sync_profile_stats().direct_pack_sec += MPI_Wtime() - t0;
return ok;
@@ -195,15 +242,28 @@ bool cuda_direct_unpack_segment(double *buffer,
const Parallel::gridseg *dst,
int state_count)
{
#if USE_CUDA_Z4C && (ABEtype == 2)
if (state_count != Z4C_CUDA_STATE_COUNT)
return false;
#elif USE_CUDA_BSSN
if (state_count <= 0 || state_count > BSSN_CUDA_STATE_COUNT)
return false;
#else
return false;
#endif
const double t0 = sync_profile_enabled() ? MPI_Wtime() : 0.0;
const int i0 = cuda_seg_begin(dst, dst->Bg, 0);
const int j0 = cuda_seg_begin(dst, dst->Bg, 1);
const int k0 = cuda_seg_begin(dst, dst->Bg, 2);
#if USE_CUDA_Z4C && (ABEtype == 2)
const bool ok = z4c_cuda_unpack_state_batch_from_host_buffer(dst->Bg, state_count, buffer, dst->Bg->shape,
i0, j0, k0,
dst->shape[0], dst->shape[1], dst->shape[2]) == 0;
#else
const bool ok = bssn_cuda_unpack_state_batch_from_host_buffer(dst->Bg, state_count, buffer, dst->Bg->shape,
i0, j0, k0,
dst->shape[0], dst->shape[1], dst->shape[2]) == 0;
#endif
if (sync_profile_enabled())
sync_profile_stats().direct_unpack_sec += MPI_Wtime() - t0;
return ok;
@@ -3966,9 +4026,10 @@ int Parallel::data_packer(double *data, MyList<Parallel::gridseg> *src, MyList<P
{
if (data)
{
#if USE_CUDA_BSSN
#if USE_CUDA_BSSN || USE_CUDA_Z4C
bool handled_by_cuda = false;
if (dir == PACK && cuda_can_direct_pack(src->data, dst->data, type))
if (dir == PACK && cuda_state_count_direct_supported(state_count) &&
cuda_can_direct_pack(src->data, dst->data, type))
{
handled_by_cuda = cuda_direct_pack_segment(data + size_out, src->data, dst->data, state_count);
if (!handled_by_cuda)
@@ -3977,7 +4038,8 @@ int Parallel::data_packer(double *data, MyList<Parallel::gridseg> *src, MyList<P
MPI_Abort(MPI_COMM_WORLD, 1);
}
}
else if (dir == UNPACK && cuda_can_direct_unpack(dst->data, type))
else if (dir == UNPACK && cuda_state_count_direct_supported(state_count) &&
cuda_can_direct_unpack(dst->data, type))
{
handled_by_cuda = cuda_direct_unpack_segment(data + size_out, dst->data, state_count);
if (!handled_by_cuda)
@@ -4012,7 +4074,7 @@ int Parallel::data_packer(double *data, MyList<Parallel::gridseg> *src, MyList<P
f_copy(DIM, dst->data->Bg->bbox, dst->data->Bg->bbox + dim, dst->data->Bg->shape, dst->data->Bg->fgfs[varld->data->sgfn],
dst->data->llb, dst->data->uub, dst->data->shape, data + size_out,
dst->data->llb, dst->data->uub);
#if USE_CUDA_BSSN
#if USE_CUDA_BSSN || USE_CUDA_Z4C
}
else
{
@@ -4593,7 +4655,7 @@ void Parallel::SyncCache::destroy()
{
if (send_bufs && send_bufs[i])
{
#if USE_CUDA_BSSN
#if USE_CUDA_BSSN || USE_CUDA_Z4C
free_comm_buffer(send_bufs[i], send_buf_pinned[i]);
#else
delete[] send_bufs[i];
@@ -4601,7 +4663,7 @@ void Parallel::SyncCache::destroy()
}
if (recv_bufs && recv_bufs[i])
{
#if USE_CUDA_BSSN
#if USE_CUDA_BSSN || USE_CUDA_Z4C
free_comm_buffer(recv_bufs[i], recv_buf_pinned[i]);
#else
delete[] recv_bufs[i];