// RK4 stage kernels: vectorized array helpers plus scalar/complex C entry points.
#include "rungekutta4_rout.h"

#include <cstdio>
#include <cstdlib>
#include <cstddef>
#include <complex>
#include <immintrin.h>
namespace {
// Stage 0 of classic RK4: f1[i] = f0[i] + c * frhs[i].
// Vectorized with AVX-512 or AVX2+FMA when the target supports it; the
// remainder (or the whole range) runs through the scalar tail loop.
//
// n     element count of each array
// f0    state at the start of the step (read-only)
// frhs  current RHS evaluation (read-only)
// f1    output: trial state for the next RHS evaluation
// c     step coefficient (the dispatcher passes 0.5*dT here)
inline void rk4_stage0(std::size_t n,
                       const double *__restrict f0,
                       const double *__restrict frhs,
                       double *__restrict f1,
                       double c) {
  std::size_t i = 0;
#if defined(__AVX512F__)
  const __m512d vc = _mm512_set1_pd(c);
  for (; i + 7 < n; i += 8) {
    const __m512d v0 = _mm512_loadu_pd(f0 + i);
    const __m512d vr = _mm512_loadu_pd(frhs + i);
    _mm512_storeu_pd(f1 + i, _mm512_fmadd_pd(vc, vr, v0));
  }
#elif defined(__AVX2__) && defined(__FMA__)
  // _mm256_fmadd_pd needs the FMA target feature; __AVX2__ alone does not
  // guarantee it at compile time (e.g. -mavx2 without -mfma), so guard on
  // __FMA__ as well.
  const __m256d vc = _mm256_set1_pd(c);
  for (; i + 3 < n; i += 4) {
    const __m256d v0 = _mm256_loadu_pd(f0 + i);
    const __m256d vr = _mm256_loadu_pd(frhs + i);
    _mm256_storeu_pd(f1 + i, _mm256_fmadd_pd(vc, vr, v0));
  }
#endif
#pragma ivdep
  for (; i < n; ++i) {
    f1[i] = f0[i] + c * frhs[i];
  }
}
// Accumulate a doubly-weighted stage into the running RHS combination:
//   frhs[i] += 2 * f1[i]
// Used after the middle stages, which carry weight 2 in the final
// (k1 + 2*k2 + 2*k3 + k4)/6 combination of classic RK4.
//
// n     element count of each array
// f1    latest stage RHS evaluation (read-only)
// frhs  in/out running combination
inline void rk4_rhs_accum(std::size_t n,
                          const double *__restrict f1,
                          double *__restrict frhs) {
  std::size_t i = 0;
#if defined(__AVX512F__)
  const __m512d v2 = _mm512_set1_pd(2.0);
  for (; i + 7 < n; i += 8) {
    const __m512d v1 = _mm512_loadu_pd(f1 + i);
    const __m512d vrhs = _mm512_loadu_pd(frhs + i);
    _mm512_storeu_pd(frhs + i, _mm512_fmadd_pd(v2, v1, vrhs));
  }
#elif defined(__AVX2__) && defined(__FMA__)
  // _mm256_fmadd_pd needs the FMA target feature; __AVX2__ alone does not
  // guarantee it at compile time, so guard on __FMA__ as well.
  const __m256d v2 = _mm256_set1_pd(2.0);
  for (; i + 3 < n; i += 4) {
    const __m256d v1 = _mm256_loadu_pd(f1 + i);
    const __m256d vrhs = _mm256_loadu_pd(frhs + i);
    _mm256_storeu_pd(frhs + i, _mm256_fmadd_pd(v2, v1, vrhs));
  }
#endif
#pragma ivdep
  for (; i < n; ++i) {
    frhs[i] = frhs[i] + 2.0 * f1[i];
  }
}
// In-place trial-state update: f1[i] = f0[i] + c * f1[i].
// On entry f1 holds the latest RHS evaluation; on exit it holds the trial
// state used for the next stage's RHS evaluation.
//
// n   element count of each array
// f0  state at the start of the step (read-only)
// f1  in: latest RHS evaluation / out: trial state
// c   step coefficient (0.5*dT for stage 1, dT for stage 2)
inline void rk4_f1_from_f0_f1(std::size_t n,
                              const double *__restrict f0,
                              double *__restrict f1,
                              double c) {
  std::size_t i = 0;
#if defined(__AVX512F__)
  const __m512d vc = _mm512_set1_pd(c);
  for (; i + 7 < n; i += 8) {
    const __m512d v0 = _mm512_loadu_pd(f0 + i);
    const __m512d v1 = _mm512_loadu_pd(f1 + i);
    _mm512_storeu_pd(f1 + i, _mm512_fmadd_pd(vc, v1, v0));
  }
#elif defined(__AVX2__) && defined(__FMA__)
  // _mm256_fmadd_pd needs the FMA target feature; __AVX2__ alone does not
  // guarantee it at compile time, so guard on __FMA__ as well.
  const __m256d vc = _mm256_set1_pd(c);
  for (; i + 3 < n; i += 4) {
    const __m256d v0 = _mm256_loadu_pd(f0 + i);
    const __m256d v1 = _mm256_loadu_pd(f1 + i);
    _mm256_storeu_pd(f1 + i, _mm256_fmadd_pd(vc, v1, v0));
  }
#endif
#pragma ivdep
  for (; i < n; ++i) {
    f1[i] = f0[i] + c * f1[i];
  }
}
// Final RK4 combination: f1[i] = f0[i] + c * (f1[i] + frhs[i]).
// On entry f1 holds the last stage's RHS evaluation and frhs the running
// combination built by rk4_rhs_accum; the dispatcher passes c = dT/6.
//
// n     element count of each array
// f0    state at the start of the step (read-only)
// f1    in: last stage RHS / out: advanced state
// frhs  accumulated RHS combination (read-only)
// c     step coefficient (dT/6)
inline void rk4_stage3(std::size_t n,
                       const double *__restrict f0,
                       double *__restrict f1,
                       const double *__restrict frhs,
                       double c) {
  std::size_t i = 0;
#if defined(__AVX512F__)
  const __m512d vc = _mm512_set1_pd(c);
  for (; i + 7 < n; i += 8) {
    const __m512d v0 = _mm512_loadu_pd(f0 + i);
    const __m512d v1 = _mm512_loadu_pd(f1 + i);
    const __m512d vr = _mm512_loadu_pd(frhs + i);
    _mm512_storeu_pd(f1 + i, _mm512_fmadd_pd(vc, _mm512_add_pd(v1, vr), v0));
  }
#elif defined(__AVX2__) && defined(__FMA__)
  // _mm256_fmadd_pd needs the FMA target feature; __AVX2__ alone does not
  // guarantee it at compile time, so guard on __FMA__ as well.
  const __m256d vc = _mm256_set1_pd(c);
  for (; i + 3 < n; i += 4) {
    const __m256d v0 = _mm256_loadu_pd(f0 + i);
    const __m256d v1 = _mm256_loadu_pd(f1 + i);
    const __m256d vr = _mm256_loadu_pd(frhs + i);
    _mm256_storeu_pd(f1 + i, _mm256_fmadd_pd(vc, _mm256_add_pd(v1, vr), v0));
  }
#endif
#pragma ivdep
  for (; i < n; ++i) {
    f1[i] = f0[i] + c * (f1[i] + frhs[i]);
  }
}
} // namespace
extern "C" {
// Scalar (single degree of freedom) RK4 stage update, mirroring the array
// kernels above.  The caller drives the four stages by passing RK4 = 0..3
// and re-evaluating the RHS between calls.
//
// dT     time step
// f0     state at the start of the full step
// f1     in: latest RHS evaluation / out: trial (or final) state
// f_rhs  running stage combination; stages 1 and 2 add 2*f1 into it
// RK4    stage index 0..3; any other value prints an error and aborts
void f_rungekutta4_scalar(double &dT, double &f0, double &f1, double &f_rhs, int &RK4) {
  constexpr double kSixth = 1.0 / 6.0;
  constexpr double kHalf = 0.5;
  constexpr double kTwo = 2.0;

  if (RK4 == 0) {
    f1 = f0 + kHalf * dT * f_rhs;
  } else if (RK4 == 1) {
    f_rhs = f_rhs + kTwo * f1;
    f1 = f0 + kHalf * dT * f1;
  } else if (RK4 == 2) {
    f_rhs = f_rhs + kTwo * f1;
    f1 = f0 + dT * f1;
  } else if (RK4 == 3) {
    f1 = f0 + kSixth * dT * (f1 + f_rhs);
  } else {
    std::fprintf(stderr, "rungekutta4_scalar_c: invalid RK4 stage %d\n", RK4);
    std::abort();
  }
}
// Complex-scalar RK4 stage update: identical staging to the double-scalar
// routine, operating on one std::complex<double> value per argument.
//
// dT     time step
// f0     state at the start of the full step
// f1     in: latest RHS evaluation / out: trial (or final) state
// f_rhs  running stage combination; stages 1 and 2 add 2*f1 into it
// RK4    stage index 0..3; any other value prints an error and aborts
void rungekutta4_cplxscalar_(double &dT,
                             std::complex<double> &f0,
                             std::complex<double> &f1,
                             std::complex<double> &f_rhs,
                             int &RK4) {
  constexpr double one_sixth = 1.0 / 6.0;
  constexpr double half = 0.5;
  constexpr double two = 2.0;

  // Reject anything outside the four RK4 stages up front.
  if (RK4 < 0 || RK4 > 3) {
    std::fprintf(stderr, "rungekutta4_cplxscalar_c: invalid RK4 stage %d\n", RK4);
    std::abort();
  }

  if (RK4 == 0) {
    f1 = f0 + half * dT * f_rhs;
    return;
  }
  if (RK4 == 1) {
    f_rhs = f_rhs + two * f1;
    f1 = f0 + half * dT * f1;
    return;
  }
  if (RK4 == 2) {
    f_rhs = f_rhs + two * f1;
    f1 = f0 + dT * f1;
    return;
  }
  f1 = f0 + one_sixth * dT * (f1 + f_rhs);
}
int f_rungekutta4_rout(int *ex, double &dT,
|
|
double *f0, double *f1, double *f_rhs,
|
|
int &RK4) {
|
|
const std::size_t n = static_cast<std::size_t>(ex[0]) *
|
|
static_cast<std::size_t>(ex[1]) *
|
|
static_cast<std::size_t>(ex[2]);
|
|
const double *const __restrict f0r = f0;
|
|
double *const __restrict f1r = f1;
|
|
double *const __restrict frhs = f_rhs;
|
|
|
|
if (__builtin_expect(static_cast<unsigned>(RK4) > 3u, 0)) {
|
|
std::fprintf(stderr, "rungekutta4_rout_c: invalid RK4 stage %d\n", RK4);
|
|
std::abort();
|
|
}
|
|
|
|
switch (RK4) {
|
|
case 0:
|
|
rk4_stage0(n, f0r, frhs, f1r, 0.5 * dT);
|
|
break;
|
|
case 1:
|
|
rk4_rhs_accum(n, f1r, frhs);
|
|
rk4_f1_from_f0_f1(n, f0r, f1r, 0.5 * dT);
|
|
break;
|
|
case 2:
|
|
rk4_rhs_accum(n, f1r, frhs);
|
|
rk4_f1_from_f0_f1(n, f0r, f1r, dT);
|
|
break;
|
|
default:
|
|
rk4_stage3(n, f0r, f1r, frhs, (1.0 / 6.0) * dT);
|
|
break;
|
|
}
|
|
|
|
return 0;
|
|
}
} // extern "C"