Batch GPU stage downloads
This commit is contained in:
@@ -1103,6 +1103,77 @@ int bssn_gpu_stage_zero_buffer(const double *host_ptr, int count)
|
||||
return prepare_owned_buffer(host_ptr, static_cast<size_t>(count), true) ? 0 : 1;
|
||||
}
|
||||
|
||||
int bssn_gpu_download_buffer_batch(const int *ex, double **host_ptrs, int num_buffers)
|
||||
{
|
||||
if (!ex || !host_ptrs || num_buffers <= 0)
|
||||
return 1;
|
||||
|
||||
static thread_local cudaStream_t stream = nullptr;
|
||||
static thread_local cudaEvent_t ready = nullptr;
|
||||
if (!stream)
|
||||
{
|
||||
cudaError_t err = cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);
|
||||
if (err != cudaSuccess)
|
||||
{
|
||||
cerr << "cudaStreamCreateWithFlags failed: " << cudaGetErrorString(err) << endl;
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
if (!ready)
|
||||
{
|
||||
cudaError_t err = cudaEventCreateWithFlags(&ready, cudaEventDisableTiming);
|
||||
if (err != cudaSuccess)
|
||||
{
|
||||
cerr << "cudaEventCreateWithFlags failed: " << cudaGetErrorString(err) << endl;
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
cudaError_t err = cudaEventRecord(ready, 0);
|
||||
if (err != cudaSuccess)
|
||||
{
|
||||
cerr << "cudaEventRecord download readiness failed: " << cudaGetErrorString(err) << endl;
|
||||
return 1;
|
||||
}
|
||||
err = cudaStreamWaitEvent(stream, ready, 0);
|
||||
if (err != cudaSuccess)
|
||||
{
|
||||
cerr << "cudaStreamWaitEvent download readiness failed: " << cudaGetErrorString(err) << endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
int n = 1;
|
||||
for (int i = 0; i < 3; ++i)
|
||||
n *= ex[i];
|
||||
const size_t bytes = static_cast<size_t>(n) * sizeof(double);
|
||||
|
||||
for (int i = 0; i < num_buffers; ++i)
|
||||
{
|
||||
double *host_ptr = host_ptrs[i];
|
||||
if (!host_ptr)
|
||||
return 1;
|
||||
const double *device_ptr = bssn_gpu_find_device_buffer(host_ptr);
|
||||
if (!device_ptr)
|
||||
return 1;
|
||||
bssn_gpu_prepare_host_buffer(host_ptr, n);
|
||||
err = cudaMemcpyAsync(host_ptr, device_ptr, bytes, cudaMemcpyDeviceToHost, stream);
|
||||
if (err != cudaSuccess)
|
||||
{
|
||||
cerr << "cudaMemcpyAsync(D2H) buffered batch download failed: "
|
||||
<< cudaGetErrorString(err) << endl;
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
err = cudaStreamSynchronize(stream);
|
||||
if (err != cudaSuccess)
|
||||
{
|
||||
cerr << "cudaStreamSynchronize buffered batch download failed: "
|
||||
<< cudaGetErrorString(err) << endl;
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int bssn_gpu_stage_upload_region(const double *host_ptr,
|
||||
const int *full_shape,
|
||||
const double *full_llb,
|
||||
|
||||
Reference in New Issue
Block a user