#include "rungekutta4_rout.h"

#include <cstddef>
#include <cstdio>
#include <cstdlib>

// Only pull in the x86 intrinsics header when a SIMD path is actually
// compiled; an unconditional include breaks non-x86 builds.
#if defined(__AVX512F__) || defined(__AVX2__)
#include <immintrin.h>
#endif

// Element-wise kernels used by f_rungekutta4_rout.
//
// Each kernel has an explicit SIMD main loop (AVX-512F, or AVX2+FMA) followed
// by a scalar loop that finishes the remainder -- or handles the whole range
// when no SIMD path is compiled in.
//
// * The 256-bit path requires FMA in addition to AVX2: `_mm256_fmadd_pd` is an
//   FMA-extension intrinsic, and `-mavx2` alone does not enable it.
// * The `__restrict` qualifiers promise the arrays do not overlap, which is
//   what allows the scalar loops to auto-vectorize. The Intel-only
//   `#pragma ivdep` previously used here added nothing on top of that and
//   produced -Wunknown-pragmas warnings with GCC/Clang.
// * The SIMD loops always use a fused multiply-add; whether the scalar loop is
//   contracted to an FMA is up to the compiler, so tail elements may differ
//   from vector-body elements in the last bit.
namespace {

// f1[i] = f0[i] + c * frhs[i]   for i in [0, n)
inline void rk4_stage0(std::size_t n, const double *__restrict f0,
                       const double *__restrict frhs, double *__restrict f1,
                       double c) {
  std::size_t i = 0;
#if defined(__AVX512F__)
  const __m512d vc = _mm512_set1_pd(c);
  for (; i + 7 < n; i += 8) {
    const __m512d v0 = _mm512_loadu_pd(f0 + i);
    const __m512d vr = _mm512_loadu_pd(frhs + i);
    _mm512_storeu_pd(f1 + i, _mm512_fmadd_pd(vc, vr, v0));
  }
#elif defined(__AVX2__) && defined(__FMA__)
  const __m256d vc = _mm256_set1_pd(c);
  for (; i + 3 < n; i += 4) {
    const __m256d v0 = _mm256_loadu_pd(f0 + i);
    const __m256d vr = _mm256_loadu_pd(frhs + i);
    _mm256_storeu_pd(f1 + i, _mm256_fmadd_pd(vc, vr, v0));
  }
#endif
  for (; i < n; ++i) {
    f1[i] = f0[i] + c * frhs[i];
  }
}

// Combined update for stages 1 and 2. With k = f1[i] as it is on entry:
//   frhs[i] = frhs[i] + 2 * k
//   f1[i]   = f0[i] + c * k
// Doing both in a single pass reads f1 once instead of twice (the former
// implementation ran two separate sweeps over f1).
inline void rk4_accum_and_advance(std::size_t n, const double *__restrict f0,
                                  double *__restrict f1,
                                  double *__restrict frhs, double c) {
  std::size_t i = 0;
#if defined(__AVX512F__)
  const __m512d vc = _mm512_set1_pd(c);
  const __m512d v2 = _mm512_set1_pd(2.0);
  for (; i + 7 < n; i += 8) {
    const __m512d v0 = _mm512_loadu_pd(f0 + i);
    const __m512d v1 = _mm512_loadu_pd(f1 + i);
    const __m512d vr = _mm512_loadu_pd(frhs + i);
    _mm512_storeu_pd(frhs + i, _mm512_fmadd_pd(v2, v1, vr));
    _mm512_storeu_pd(f1 + i, _mm512_fmadd_pd(vc, v1, v0));
  }
#elif defined(__AVX2__) && defined(__FMA__)
  const __m256d vc = _mm256_set1_pd(c);
  const __m256d v2 = _mm256_set1_pd(2.0);
  for (; i + 3 < n; i += 4) {
    const __m256d v0 = _mm256_loadu_pd(f0 + i);
    const __m256d v1 = _mm256_loadu_pd(f1 + i);
    const __m256d vr = _mm256_loadu_pd(frhs + i);
    _mm256_storeu_pd(frhs + i, _mm256_fmadd_pd(v2, v1, vr));
    _mm256_storeu_pd(f1 + i, _mm256_fmadd_pd(vc, v1, v0));
  }
#endif
  for (; i < n; ++i) {
    const double k = f1[i];  // read once; both updates use the entry value
    frhs[i] = frhs[i] + 2.0 * k;
    f1[i] = f0[i] + c * k;
  }
}

// f1[i] = f0[i] + c * (f1[i] + frhs[i])   for i in [0, n)
inline void rk4_stage3(std::size_t n, const double *__restrict f0,
                       double *__restrict f1, const double *__restrict frhs,
                       double c) {
  std::size_t i = 0;
#if defined(__AVX512F__)
  const __m512d vc = _mm512_set1_pd(c);
  for (; i + 7 < n; i += 8) {
    const __m512d v0 = _mm512_loadu_pd(f0 + i);
    const __m512d v1 = _mm512_loadu_pd(f1 + i);
    const __m512d vr = _mm512_loadu_pd(frhs + i);
    _mm512_storeu_pd(f1 + i, _mm512_fmadd_pd(vc, _mm512_add_pd(v1, vr), v0));
  }
#elif defined(__AVX2__) && defined(__FMA__)
  const __m256d vc = _mm256_set1_pd(c);
  for (; i + 3 < n; i += 4) {
    const __m256d v0 = _mm256_loadu_pd(f0 + i);
    const __m256d v1 = _mm256_loadu_pd(f1 + i);
    const __m256d vr = _mm256_loadu_pd(frhs + i);
    _mm256_storeu_pd(f1 + i, _mm256_fmadd_pd(vc, _mm256_add_pd(v1, vr), v0));
  }
#endif
  for (; i < n; ++i) {
    f1[i] = f0[i] + c * (f1[i] + frhs[i]);
  }
}

}  // namespace

extern "C" {

/// Applies one stage of a fourth-order Runge-Kutta update, element-wise over
/// n = ex[0] * ex[1] * ex[2] contiguous doubles.
///
/// With every right-hand-side `f1` meaning the value of f1 on entry:
///   RK4 == 0:  f1    = f0 + (dT/2) * f_rhs
///   RK4 == 1:  f_rhs = f_rhs + 2*f1;   f1 = f0 + (dT/2) * f1
///   RK4 == 2:  f_rhs = f_rhs + 2*f1;   f1 = f0 + dT * f1
///   RK4 == 3:  f1    = f0 + (dT/6) * (f1 + f_rhs)
///
/// NOTE(review): this presumably relies on the caller replacing f1 with the
/// right-hand side evaluated at the new intermediate state between calls, so
/// that f_rhs accumulates k1 + 2*k2 + 2*k3 -- confirm against the driver.
///
/// @param ex     Three non-negative extents; their product is the element
///               count.
/// @param dT     Time step.
/// @param f0     State at the start of the step (read only).
/// @param f1     Input and output buffer as described above.
/// @param f_rhs  Input and output buffer as described above.
/// @param RK4    Stage index in [0, 3].
/// @return Always 0. An out-of-range stage or a negative extent prints a
///         diagnostic to stderr and calls std::abort().
///
/// f0, f1 and f_rhs must not overlap (they are accessed through __restrict).
int f_rungekutta4_rout(int *ex, double &dT, double *f0, double *f1,
                       double *f_rhs, int &RK4) {
  // A negative extent would wrap to an enormous size_t below and send the
  // kernels far out of bounds; fail loudly, as for a bad stage index.
  if (__builtin_expect(ex[0] < 0 || ex[1] < 0 || ex[2] < 0, 0)) {
    std::fprintf(stderr, "rungekutta4_rout_c: invalid extents %d %d %d\n",
                 ex[0], ex[1], ex[2]);
    std::abort();
  }
  // The unsigned comparison rejects negative stages as well as stages > 3.
  if (__builtin_expect(static_cast<unsigned>(RK4) > 3u, 0)) {
    std::fprintf(stderr, "rungekutta4_rout_c: invalid RK4 stage %d\n", RK4);
    std::abort();
  }

  const std::size_t n = static_cast<std::size_t>(ex[0]) *
                        static_cast<std::size_t>(ex[1]) *
                        static_cast<std::size_t>(ex[2]);
  const double h = dT;

  switch (RK4) {
  case 0:
    rk4_stage0(n, f0, f_rhs, f1, 0.5 * h);
    break;
  case 1:
    rk4_accum_and_advance(n, f0, f1, f_rhs, 0.5 * h);
    break;
  case 2:
    rk4_accum_and_advance(n, f0, f1, f_rhs, h);
    break;
  default:  // RK4 == 3, guaranteed by the check above
    rk4_stage3(n, f0, f1, f_rhs, (1.0 / 6.0) * h);
    break;
  }
  return 0;
}

}  // extern "C"