Stabilize device AMR restrict across regrid

This commit is contained in:
2026-04-30 20:01:18 +08:00
parent be9033f449
commit b1974ef146
3 changed files with 247 additions and 0 deletions

View File

@@ -6404,6 +6404,45 @@ static void download_resident_state(void *block_tag, int *ex, double **state_hos
}
}
static bool download_resident_state_if_present(void *block_tag, int *ex, double **state_host_out)
{
auto it = g_step_ctx.find(block_tag);
if (it == g_step_ctx.end()) return false;
StepContext &ctx = it->second;
const int bank = find_resident_bank(ctx, state_host_out);
if (bank < 0 || !ctx.resident_valid[bank])
return false;
const size_t all = (size_t)ex[0] * ex[1] * ex[2];
const size_t bytes = all * sizeof(double);
mark_resident_current_bank(ctx, bank);
if (resident_host_subset_clean(ctx, bank, BSSN_STATE_COUNT, nullptr))
return true;
static int direct_download = -1;
if (direct_download < 0) {
const char *env = getenv("AMSS_CUDA_DIRECT_STATE_DOWNLOAD");
direct_download = env ? ((atoi(env) != 0) ? 1 : 0) : 1;
}
if (direct_download) {
for (int i = 0; i < BSSN_STATE_COUNT; ++i) {
CUDA_CHECK(cudaMemcpyAsync(state_host_out[i], ctx.d_resident[bank][i],
bytes, cudaMemcpyDeviceToHost));
}
CUDA_CHECK(cudaDeviceSynchronize());
} else {
CUDA_CHECK(cudaMemcpy(g_buf.h_stage, ctx.d_resident_mem[bank],
(size_t)BSSN_STATE_COUNT * bytes,
cudaMemcpyDeviceToHost));
for (int i = 0; i < BSSN_STATE_COUNT; ++i) {
std::memcpy(state_host_out[i], g_buf.h_stage + (size_t)i * all, bytes);
}
}
set_resident_host_clean(ctx, bank, true);
return true;
}
static void copy_state_subset(void *block_tag,
int *ex,
int subset_count,
@@ -7056,6 +7095,18 @@ int bssn_cuda_download_resident_state(void *block_tag,
return 0;
}
extern "C"
int bssn_cuda_download_resident_state_if_present(void *block_tag,
int *ex,
double **state_host_out)
{
init_gpu_dispatch();
CUDA_CHECK(cudaSetDevice(g_dispatch.my_device));
if (!block_tag || !ex || !state_host_out) return 1;
download_resident_state_if_present(block_tag, ex, state_host_out);
return 0;
}
extern "C"
int bssn_cuda_download_constraint_outputs(int *ex,
double **constraint_host_out)