more experiments for avx512

This commit is contained in:
minjaesong
2025-11-26 02:36:49 +09:00
parent acaade1062
commit 3b401139e9

View File

@@ -188,185 +188,305 @@ static inline void dwt_53_forward_1d_avx512(float *data, int length) {
static inline void dwt_97_forward_1d_avx512(float *data, int length) {
if (length < 2) return;
float *temp = (float*)malloc(length * sizeof(float));
int half = (length + 1) / 2;
// Split into even/odd - vectorized gather
int i;
for (i = 0; i + 16 <= half; i += 16) {
float even_vals[16], odd_vals[16];
for (int j = 0; j < 16; j++) {
even_vals[j] = data[2 * (i + j)];
if (2 * (i + j) + 1 < length) {
odd_vals[j] = data[2 * (i + j) + 1];
} else {
odd_vals[j] = 0.0f;
}
// Allocate aligned temp buffer once (64-byte align for cache lines)
float *temp = NULL;
#if defined(_POSIX_C_SOURCE) || defined(_XOPEN_SOURCE)
if (posix_memalign((void**)&temp, 64, (size_t)length * sizeof(float)) != 0) {
temp = (float*)malloc((size_t)length * sizeof(float));
}
#else
temp = (float*)aligned_alloc(64, ((size_t)length * sizeof(float) + 63) & ~63);
if (!temp) temp = (float*)malloc((size_t)length * sizeof(float));
#endif
if (!temp) return; // allocation failure: bail out (preserve original behavior could be different)
// FAST SPLIT: interleave into temp: first half = evens, second half = odds
// This is simple, streaming-friendly, and much faster than per-iteration small-array gathers.
{
float *even = temp;
float *odd = temp + half;
int i = 0;
// process pairs to minimize branches and memory ops
for (; i + 1 < length; i += 2) {
even[0] = data[i];
odd[0] = data[i + 1];
++even; ++odd;
}
_mm512_storeu_ps(&temp[i], _mm512_loadu_ps(even_vals));
if (i < length / 2) {
__mmask16 mask = ((i + 16) <= length / 2) ? 0xFFFF : (1 << (length / 2 - i)) - 1;
_mm512_mask_storeu_ps(&temp[half + i], mask, _mm512_loadu_ps(odd_vals));
if (i < length) { // odd leftover
even[0] = data[i];
}
}
// Remaining scalar
for (; i < half; i++) {
temp[i] = data[2 * i];
}
for (i = 0; i < length / 2; i++) {
temp[half + i] = data[2 * i + 1];
}
// Lifting coefficients
// Lifting coefficients as vectors
const __m512 alpha_vec = _mm512_set1_ps(-1.586134342f);
const __m512 beta_vec = _mm512_set1_ps(-0.052980118f);
const __m512 beta_vec = _mm512_set1_ps(-0.052980118f);
const __m512 gamma_vec = _mm512_set1_ps(0.882911076f);
const __m512 delta_vec = _mm512_set1_ps(0.443506852f);
const __m512 K_vec = _mm512_set1_ps(1.230174105f);
const __m512 invK_vec = _mm512_set1_ps(1.0f / 1.230174105f);
const __m512 K_vec = _mm512_set1_ps(1.230174105f);
const __m512 invK_vec = _mm512_set1_ps(1.0f / 1.230174105f);
// Step 1: Predict α - d[i] += α * (s[i] + s[i+1])
for (i = 0; i + 16 <= length / 2; i += 16) {
__mmask16 mask = ((half + i + 16) <= length) ? 0xFFFF : (1 << (length - half - i)) - 1;
// Helper variables
int i;
float s_curr_vals[16], s_next_vals[16], d_vals[16];
for (int j = 0; j < 16; j++) {
s_curr_vals[j] = temp[i + j];
s_next_vals[j] = ((i + j + 1) < half) ? temp[i + j + 1] : temp[i + j];
if ((half + i + j) < length) {
d_vals[j] = temp[half + i + j];
// -----------------------
// Step 1: Predict α
// d[i] += alpha * (s[i] + s[i+1])
// -----------------------
if (half > 0) {
// handle small or trivial cases
if (half == 1) {
if (half < length) {
temp[half + 0] += -1.586134342f * (temp[0] + temp[0]);
}
} else {
// main vectorized body: ensure s_next loads (i+1) valid -> i <= half-2
int limit = (half - 1);
int n_full = (limit / 16) * 16; // process up to n_full (multiple of 16)
i = 0;
for (; i + 32 <= n_full; i += 32) {
// unroll 2x (i and i+16)
__m512 s0 = _mm512_loadu_ps(&temp[i]);
__m512 s0n = _mm512_loadu_ps(&temp[i + 1]);
__m512 d0 = _mm512_loadu_ps(&temp[half + i]);
__m512 sum0 = _mm512_add_ps(s0, s0n);
d0 = _mm512_fmadd_ps(alpha_vec, sum0, d0);
_mm512_storeu_ps(&temp[half + i], d0);
__m512 s1 = _mm512_loadu_ps(&temp[i + 16]);
__m512 s1n = _mm512_loadu_ps(&temp[i + 17]);
__m512 d1 = _mm512_loadu_ps(&temp[half + i + 16]);
__m512 sum1 = _mm512_add_ps(s1, s1n);
d1 = _mm512_fmadd_ps(alpha_vec, sum1, d1);
_mm512_storeu_ps(&temp[half + i + 16], d1);
}
for (; i + 16 <= n_full; i += 16) {
__m512 s = _mm512_loadu_ps(&temp[i]);
__m512 sn = _mm512_loadu_ps(&temp[i + 1]);
__m512 d = _mm512_loadu_ps(&temp[half + i]);
__m512 sum = _mm512_add_ps(s, sn);
d = _mm512_fmadd_ps(alpha_vec, sum, d);
_mm512_storeu_ps(&temp[half + i], d);
}
// scalar remainder up to limit (half-2 -> last vector handled below)
for (; i < limit; ++i) {
temp[half + i] += -1.586134342f * (temp[i] + temp[i + 1]);
}
// handle last index i = half-1 (mirror)
int last = half - 1;
if (half + last < length) {
float s_curr = temp[last];
float s_next = s_curr;
temp[half + last] += -1.586134342f * (s_curr + s_next);
}
}
}
// -----------------------
// Step 2: Update β
// s[i] += beta * (d[i-1] + d[i])
// -----------------------
if (half > 0) {
// handle i == 0 separately (d_prev = d_curr for boundary semantics)
if (half >= 1) {
// i == 0
if (half + 0 < length) {
float d_curr0 = temp[half + 0];
temp[0] += -0.052980118f * (d_curr0 + d_curr0);
}
}
__m512 s_curr = _mm512_loadu_ps(s_curr_vals);
__m512 s_next = _mm512_loadu_ps(s_next_vals);
__m512 d = _mm512_maskz_loadu_ps(mask, d_vals);
if (half > 1) {
// main vector loop starting from i = 1 to half-1 (we will write s[i] for i>=1)
int start = 1;
int limit = half; // exclusive
int n_elems = limit - start;
int n_full = (n_elems / 16) * 16;
i = start;
for (; i + 32 <= start + n_full; i += 32) {
// unroll 2x
__m512 s0 = _mm512_loadu_ps(&temp[i]);
__m512 dcurr0 = _mm512_loadu_ps(&temp[half + i]);
__m512 dprev0 = _mm512_loadu_ps(&temp[half + i - 1]);
__m512 sum0 = _mm512_add_ps(dprev0, dcurr0);
s0 = _mm512_fmadd_ps(beta_vec, sum0, s0);
_mm512_storeu_ps(&temp[i], s0);
__m512 sum = _mm512_add_ps(s_curr, s_next);
d = _mm512_fmadd_ps(alpha_vec, sum, d); // d += alpha * sum
_mm512_mask_storeu_ps(&temp[half + i], mask, d);
}
// Remaining scalar for step 1
for (; i < length / 2; i++) {
if (half + i < length) {
float s_curr = temp[i];
float s_next = (i + 1 < half) ? temp[i + 1] : s_curr;
temp[half + i] += -1.586134342f * (s_curr + s_next);
__m512 s1 = _mm512_loadu_ps(&temp[i + 16]);
__m512 dcurr1 = _mm512_loadu_ps(&temp[half + i + 16]);
__m512 dprev1 = _mm512_loadu_ps(&temp[half + i + 15]);
__m512 sum1 = _mm512_add_ps(dprev1, dcurr1);
s1 = _mm512_fmadd_ps(beta_vec, sum1, s1);
_mm512_storeu_ps(&temp[i + 16], s1);
}
for (; i + 16 <= start + n_full; i += 16) {
__m512 s = _mm512_loadu_ps(&temp[i]);
__m512 dcurr = _mm512_loadu_ps(&temp[half + i]);
__m512 dprev = _mm512_loadu_ps(&temp[half + i - 1]);
__m512 sum = _mm512_add_ps(dprev, dcurr);
s = _mm512_fmadd_ps(beta_vec, sum, s);
_mm512_storeu_ps(&temp[i], s);
}
// scalar remainder
for (; i < limit; ++i) {
float d_curr = (half + i < length) ? temp[half + i] : 0.0f;
float d_prev = (half + i - 1 < length && i > 0) ? temp[half + i - 1] : d_curr;
temp[i] += -0.052980118f * (d_prev + d_curr);
}
}
}
// Step 2: Update β - s[i] += β * (d[i-1] + d[i])
for (i = 0; i + 16 <= half; i += 16) {
__mmask16 mask = (i + 16 <= half) ? 0xFFFF : (1 << (half - i)) - 1;
float s_vals[16], d_curr_vals[16], d_prev_vals[16];
for (int j = 0; j < 16; j++) {
s_vals[j] = temp[i + j];
d_curr_vals[j] = ((half + i + j) < length) ? temp[half + i + j] : 0.0f;
d_prev_vals[j] = ((i + j) > 0 && (half + i + j - 1) < length) ? temp[half + i + j - 1] : d_curr_vals[j];
}
__m512 s = _mm512_loadu_ps(s_vals);
__m512 d_curr = _mm512_loadu_ps(d_curr_vals);
__m512 d_prev = _mm512_loadu_ps(d_prev_vals);
s = _mm512_fmadd_ps(beta_vec, _mm512_add_ps(d_prev, d_curr), s);
_mm512_mask_storeu_ps(&temp[i], mask, s);
}
// Scalar remainder for step 2
for (; i < half; i++) {
float d_curr = (half + i < length) ? temp[half + i] : 0.0f;
float d_prev = (i > 0 && half + i - 1 < length) ? temp[half + i - 1] : d_curr;
temp[i] += -0.052980118f * (d_prev + d_curr);
}
// -----------------------
// Step 3: Predict γ
for (i = 0; i + 16 <= length / 2; i += 16) {
__mmask16 mask = ((half + i + 16) <= length) ? 0xFFFF : (1 << (length - half - i)) - 1;
// d[i] += gamma * (s[i] + s[i+1])
// -----------------------
if (half > 0) {
if (half == 1) {
if (half < length) {
temp[half + 0] += 0.882911076f * (temp[0] + temp[0]);
}
} else {
int limit = (half - 1);
int n_full = (limit / 16) * 16;
i = 0;
for (; i + 32 <= n_full; i += 32) {
__m512 s0 = _mm512_loadu_ps(&temp[i]);
__m512 s0n = _mm512_loadu_ps(&temp[i + 1]);
__m512 d0 = _mm512_loadu_ps(&temp[half + i]);
__m512 sum0 = _mm512_add_ps(s0, s0n);
d0 = _mm512_fmadd_ps(gamma_vec, sum0, d0);
_mm512_storeu_ps(&temp[half + i], d0);
float s_curr_vals[16], s_next_vals[16], d_vals[16];
for (int j = 0; j < 16; j++) {
s_curr_vals[j] = temp[i + j];
s_next_vals[j] = ((i + j + 1) < half) ? temp[i + j + 1] : temp[i + j];
if ((half + i + j) < length) {
d_vals[j] = temp[half + i + j];
__m512 s1 = _mm512_loadu_ps(&temp[i + 16]);
__m512 s1n = _mm512_loadu_ps(&temp[i + 17]);
__m512 d1 = _mm512_loadu_ps(&temp[half + i + 16]);
__m512 sum1 = _mm512_add_ps(s1, s1n);
d1 = _mm512_fmadd_ps(gamma_vec, sum1, d1);
_mm512_storeu_ps(&temp[half + i + 16], d1);
}
for (; i + 16 <= n_full; i += 16) {
__m512 s = _mm512_loadu_ps(&temp[i]);
__m512 sn = _mm512_loadu_ps(&temp[i + 1]);
__m512 d = _mm512_loadu_ps(&temp[half + i]);
__m512 sum = _mm512_add_ps(s, sn);
d = _mm512_fmadd_ps(gamma_vec, sum, d);
_mm512_storeu_ps(&temp[half + i], d);
}
for (; i < limit; ++i) {
temp[half + i] += 0.882911076f * (temp[i] + temp[i + 1]);
}
// last index mirror
int last = half - 1;
if (half + last < length) {
float s_curr = temp[last];
float s_next = s_curr;
temp[half + last] += 0.882911076f * (s_curr + s_next);
}
}
}
// -----------------------
// Step 4: Update δ
// s[i] += delta * (d[i-1] + d[i])
// -----------------------
if (half > 0) {
// i == 0
if (half >= 1) {
if (half + 0 < length) {
float d_curr0 = temp[half + 0];
temp[0] += 0.443506852f * (d_curr0 + d_curr0);
}
}
__m512 s_curr = _mm512_loadu_ps(s_curr_vals);
__m512 s_next = _mm512_loadu_ps(s_next_vals);
__m512 d = _mm512_maskz_loadu_ps(mask, d_vals);
if (half > 1) {
int start = 1;
int limit = half; // exclusive
int n_elems = limit - start;
int n_full = (n_elems / 16) * 16;
i = start;
for (; i + 32 <= start + n_full; i += 32) {
__m512 s0 = _mm512_loadu_ps(&temp[i]);
__m512 dcurr0 = _mm512_loadu_ps(&temp[half + i]);
__m512 dprev0 = _mm512_loadu_ps(&temp[half + i - 1]);
__m512 sum0 = _mm512_add_ps(dprev0, dcurr0);
s0 = _mm512_fmadd_ps(delta_vec, sum0, s0);
_mm512_storeu_ps(&temp[i], s0);
d = _mm512_fmadd_ps(gamma_vec, _mm512_add_ps(s_curr, s_next), d);
_mm512_mask_storeu_ps(&temp[half + i], mask, d);
}
// Scalar remainder for step 3
for (; i < length / 2; i++) {
if (half + i < length) {
float s_curr = temp[i];
float s_next = (i + 1 < half) ? temp[i + 1] : s_curr;
temp[half + i] += 0.882911076f * (s_curr + s_next);
__m512 s1 = _mm512_loadu_ps(&temp[i + 16]);
__m512 dcurr1 = _mm512_loadu_ps(&temp[half + i + 16]);
__m512 dprev1 = _mm512_loadu_ps(&temp[half + i + 15]);
__m512 sum1 = _mm512_add_ps(dprev1, dcurr1);
s1 = _mm512_fmadd_ps(delta_vec, sum1, s1);
_mm512_storeu_ps(&temp[i + 16], s1);
}
for (; i + 16 <= start + n_full; i += 16) {
__m512 s = _mm512_loadu_ps(&temp[i]);
__m512 dcurr = _mm512_loadu_ps(&temp[half + i]);
__m512 dprev = _mm512_loadu_ps(&temp[half + i - 1]);
__m512 sum = _mm512_add_ps(dprev, dcurr);
s = _mm512_fmadd_ps(delta_vec, sum, s);
_mm512_storeu_ps(&temp[i], s);
}
for (; i < limit; ++i) {
float d_curr = (half + i < length) ? temp[half + i] : 0.0f;
float d_prev = (half + i - 1 < length && i > 0) ? temp[half + i - 1] : d_curr;
temp[i] += 0.443506852f * (d_prev + d_curr);
}
}
}
// Step 4: Update δ
for (i = 0; i + 16 <= half; i += 16) {
__mmask16 mask = (i + 16 <= half) ? 0xFFFF : (1 << (half - i)) - 1;
// -----------------------
// Step 5: Scaling
// s *= K, d *= invK
// -----------------------
// s (first half)
{
int n_full = (half / 16) * 16;
i = 0;
for (; i + 32 <= n_full; i += 32) {
__m512 s0 = _mm512_loadu_ps(&temp[i]);
s0 = _mm512_mul_ps(s0, K_vec);
_mm512_storeu_ps(&temp[i], s0);
float s_vals[16], d_curr_vals[16], d_prev_vals[16];
for (int j = 0; j < 16; j++) {
s_vals[j] = temp[i + j];
d_curr_vals[j] = ((half + i + j) < length) ? temp[half + i + j] : 0.0f;
d_prev_vals[j] = ((i + j) > 0 && (half + i + j - 1) < length) ? temp[half + i + j - 1] : d_curr_vals[j];
__m512 s1 = _mm512_loadu_ps(&temp[i + 16]);
s1 = _mm512_mul_ps(s1, K_vec);
_mm512_storeu_ps(&temp[i + 16], s1);
}
__m512 s = _mm512_loadu_ps(s_vals);
__m512 d_curr = _mm512_loadu_ps(d_curr_vals);
__m512 d_prev = _mm512_loadu_ps(d_prev_vals);
s = _mm512_fmadd_ps(delta_vec, _mm512_add_ps(d_prev, d_curr), s);
_mm512_mask_storeu_ps(&temp[i], mask, s);
for (; i + 16 <= n_full; i += 16) {
__m512 s = _mm512_loadu_ps(&temp[i]);
s = _mm512_mul_ps(s, K_vec);
_mm512_storeu_ps(&temp[i], s);
}
for (; i < half; ++i) temp[i] *= 1.230174105f;
}
// Scalar remainder for step 4
for (; i < half; i++) {
float d_curr = (half + i < length) ? temp[half + i] : 0.0f;
float d_prev = (i > 0 && half + i - 1 < length) ? temp[half + i - 1] : d_curr;
temp[i] += 0.443506852f * (d_prev + d_curr);
}
// d (second half)
{
int dlen = length - half;
int n_full = (dlen / 16) * 16;
i = 0;
for (; i + 32 <= n_full; i += 32) {
__m512 d0 = _mm512_loadu_ps(&temp[half + i]);
d0 = _mm512_mul_ps(d0, invK_vec);
_mm512_storeu_ps(&temp[half + i], d0);
// Step 5: Scaling - vectorized
for (i = 0; i + 16 <= half; i += 16) {
__mmask16 mask = (i + 16 <= half) ? 0xFFFF : (1 << (half - i)) - 1;
__m512 s = _mm512_maskz_loadu_ps(mask, &temp[i]);
s = _mm512_mul_ps(s, K_vec);
_mm512_mask_storeu_ps(&temp[i], mask, s);
}
for (; i < half; i++) {
temp[i] *= 1.230174105f;
}
for (i = 0; i + 16 <= length / 2; i += 16) {
__mmask16 mask = ((half + i + 16) <= length) ? 0xFFFF : (1 << (length - half - i)) - 1;
__m512 d = _mm512_maskz_loadu_ps(mask, &temp[half + i]);
d = _mm512_mul_ps(d, invK_vec);
_mm512_mask_storeu_ps(&temp[half + i], mask, d);
}
for (; i < length / 2; i++) {
if (half + i < length) {
temp[half + i] /= 1.230174105f;
__m512 d1 = _mm512_loadu_ps(&temp[half + i + 16]);
d1 = _mm512_mul_ps(d1, invK_vec);
_mm512_storeu_ps(&temp[half + i + 16], d1);
}
for (; i + 16 <= n_full; i += 16) {
__m512 d = _mm512_loadu_ps(&temp[half + i]);
d = _mm512_mul_ps(d, invK_vec);
_mm512_storeu_ps(&temp[half + i], d);
}
for (; i < dlen; ++i) {
if (half + i < length) temp[half + i] /= 1.230174105f;
}
}
memcpy(data, temp, length * sizeof(float));
// Copy back and free
memcpy(data, temp, (size_t)length * sizeof(float));
free(temp);
}