Checkpoint Z4C CUDA optimization progress

This commit is contained in:
2026-05-02 08:55:25 +08:00
parent a5c8188305
commit fcd98649f6
4 changed files with 180 additions and 24 deletions

View File

@@ -462,25 +462,28 @@ struct StepContext {
std::array<double *, BSSN_STATE_COUNT> d_state_curr;
std::array<double *, BSSN_STATE_COUNT> d_state_next;
std::array<double *, BSSN_MATTER_COUNT> d_matter;
std::array<double *, BSSN_STATE_COUNT> resident_host;
size_t cap_all;
size_t cap_comm;
bool h_comm_pinned;
size_t cap_h_comm;
bool matter_ready;
bool state_ready;
bool resident_host_valid;
StepContext()
: d_state0_mem(nullptr), d_accum_mem(nullptr),
d_state_curr_mem(nullptr), d_state_next_mem(nullptr),
d_matter_mem(nullptr), d_comm_mem(nullptr), h_comm_mem(nullptr),
cap_all(0), cap_comm(0), h_comm_pinned(false), cap_h_comm(0),
matter_ready(false), state_ready(false)
matter_ready(false), state_ready(false), resident_host_valid(false)
{
d_state0.fill(nullptr);
d_accum.fill(nullptr);
d_state_curr.fill(nullptr);
d_state_next.fill(nullptr);
d_matter.fill(nullptr);
resident_host.fill(nullptr);
}
};
@@ -544,6 +547,8 @@ static StepAllocation detach_step_allocation(StepContext &ctx)
ctx.d_state_curr.fill(nullptr);
ctx.d_state_next.fill(nullptr);
ctx.d_matter.fill(nullptr);
ctx.resident_host.fill(nullptr);
ctx.resident_host_valid = false;
return alloc;
}
@@ -562,6 +567,8 @@ static void attach_step_allocation(StepContext &ctx, const StepAllocation &alloc
ctx.cap_h_comm = alloc.cap_h_comm;
ctx.matter_ready = false;
ctx.state_ready = false;
ctx.resident_host.fill(nullptr);
ctx.resident_host_valid = false;
}
static void recycle_step_allocation(StepAllocation &alloc)
@@ -5794,6 +5801,37 @@ static bool has_resident_state(void *block_tag)
return it != g_step_ctx.end() && it->second.state_ready;
}
static bool resident_key_usable(double **host_key)
{
if (!host_key) return false;
for (int i = 0; i < BSSN_STATE_COUNT; ++i) {
if (!host_key[i]) return false;
}
return true;
}
static bool resident_key_matches(const StepContext &ctx, double **host_key)
{
if (!ctx.state_ready || !ctx.resident_host_valid || !resident_key_usable(host_key))
return false;
for (int i = 0; i < BSSN_STATE_COUNT; ++i) {
if (ctx.resident_host[i] != host_key[i]) return false;
}
return true;
}
static void set_resident_key(StepContext &ctx, double **host_key)
{
if (!resident_key_usable(host_key)) {
ctx.resident_host.fill(nullptr);
ctx.resident_host_valid = false;
return;
}
for (int i = 0; i < BSSN_STATE_COUNT; ++i)
ctx.resident_host[i] = host_key[i];
ctx.resident_host_valid = true;
}
#define pow2(x) ((x) * (x))
@@ -7786,10 +7824,13 @@ extern "C" int z4c_cuda_rk4_substep(void *block_tag,
bind_state_input_slots(ctx.d_state_curr);
bind_state_output_slots(ctx.d_state_next);
}
double t0 = profile ? cuda_profile_now_ms() : 0.0;
if (!use_resident_state || !ctx.state_ready) {
upload_state_inputs(state_host_in, all);
if (use_resident_state) {
ctx.state_ready = true;
set_resident_key(ctx, state_host_in);
}
}
if (apply_enforce_ga) {
kern_enforce_ga_cuda<<<grid(all), BLK>>>(g_buf.slot[S_dxx], g_buf.slot[S_gxy], g_buf.slot[S_gxz],
@@ -7849,6 +7890,7 @@ extern "C" int z4c_cuda_rk4_substep(void *block_tag,
std::swap(ctx.d_state_curr_mem, ctx.d_state_next_mem);
ctx.d_state_curr.swap(ctx.d_state_next);
ctx.state_ready = true;
set_resident_key(ctx, state_host_out);
} else {
download_state_outputs(state_host_out, all);
}
@@ -8154,6 +8196,17 @@ extern "C" int z4c_cuda_has_resident_state(void *block_tag)
return has_resident_state(block_tag) ? 1 : 0;
}
extern "C" int z4c_cuda_resident_state_matches(void *block_tag,
double **state_host_key)
{
using namespace z4c_cuda;
init_gpu_dispatch();
CUDA_CHECK(cudaSetDevice(g_dispatch.my_device));
auto it = g_step_ctx.find(block_tag);
if (it == g_step_ctx.end()) return 0;
return resident_key_matches(it->second, state_host_key) ? 1 : 0;
}
extern "C" void z4c_cuda_release_step_ctx(void *block_tag)
{
using namespace z4c_cuda;