// Predict step of the 9/7 lifting scheme: d[i] += coef * (s[i] + s[i+1]).
//   s, s_len : approximation (even-index) samples and their count
//   d, d_len : detail (odd-index) samples and their count
// When s[i+1] does not exist (last sample of an even-length signal) the right
// neighbour is mirrored, i.e. s[i] is used twice.
static inline void dwt_97_predict_step_avx512(const float *s, float *d,
                                              int s_len, int d_len, float coef) {
    const __m512 coef_vec = _mm512_set1_ps(coef);
    // Indices whose right neighbour s[i+1] is a real sample.
    const int interior = (d_len < s_len - 1) ? d_len : s_len - 1;
    int i = 0;

    // Loads touch s[i .. i+16] and d[i .. i+15]; both stay in range because
    // i + 16 <= interior <= s_len - 1 and interior <= d_len.
    for (; i + 16 <= interior; i += 16) {
        __m512 sum = _mm512_add_ps(_mm512_loadu_ps(s + i), _mm512_loadu_ps(s + i + 1));
        __m512 acc = _mm512_fmadd_ps(coef_vec, sum, _mm512_loadu_ps(d + i));
        _mm512_storeu_ps(d + i, acc);
    }
    for (; i < interior; ++i) {
        d[i] += coef * (s[i] + s[i + 1]);
    }
    // Mirrored boundary: runs at most once, for i == s_len - 1.
    for (; i < d_len; ++i) {
        d[i] += coef * (s[i] + s[i]);
    }
}

// Update step of the 9/7 lifting scheme: s[i] += coef * (d[i-1] + d[i]).
// Boundary rules (same as the scalar formulation of this transform):
//   - i == 0 has no left neighbour, so d[0] is used twice;
//   - for an odd-length signal the last s has no d[i]; 0.0f is used instead.
static inline void dwt_97_update_step_avx512(float *s, const float *d,
                                             int s_len, int d_len, float coef) {
    const __m512 coef_vec = _mm512_set1_ps(coef);
    int i = 1;

    if (d_len < 1) return; // nothing to update from

    s[0] += coef * (d[0] + d[0]);

    // Only vectorize indices for which BOTH d[i-1] and d[i] exist. Bounding
    // by d_len (not s_len) keeps the d[i] load inside the buffer when the
    // signal length is odd.
    for (; i + 16 <= d_len; i += 16) {
        __m512 sum = _mm512_add_ps(_mm512_loadu_ps(d + i - 1), _mm512_loadu_ps(d + i));
        __m512 acc = _mm512_fmadd_ps(coef_vec, sum, _mm512_loadu_ps(s + i));
        _mm512_storeu_ps(s + i, acc);
    }
    for (; i < d_len; ++i) {
        s[i] += coef * (d[i - 1] + d[i]);
    }
    // Odd length: the final approximation sample has no matching detail.
    for (; i < s_len; ++i) {
        s[i] += coef * (d[i - 1] + 0.0f);
    }
}

// Scaling step: s *= K, d *= 1/K with K = 1.230174105.
// The scalar tail of the detail band divides by K while the vector body
// multiplies by the precomputed reciprocal, so tail elements may differ from
// a pure-multiply result in the last ulp.
static inline void dwt_97_scale_step_avx512(float *s, float *d, int s_len, int d_len) {
    const __m512 K_vec = _mm512_set1_ps(1.230174105f);
    const __m512 invK_vec = _mm512_set1_ps(1.0f / 1.230174105f);
    int i;

    for (i = 0; i + 16 <= s_len; i += 16) {
        _mm512_storeu_ps(s + i, _mm512_mul_ps(_mm512_loadu_ps(s + i), K_vec));
    }
    for (; i < s_len; ++i) {
        s[i] *= 1.230174105f;
    }

    for (i = 0; i + 16 <= d_len; i += 16) {
        _mm512_storeu_ps(d + i, _mm512_mul_ps(_mm512_loadu_ps(d + i), invK_vec));
    }
    for (; i < d_len; ++i) {
        d[i] /= 1.230174105f;
    }
}

// In-place forward 1-D 9/7 lifting transform (AVX-512F).
//
//   data   : `length` floats; on return the first (length + 1) / 2 entries
//            hold the scaled approximation band and the remaining length / 2
//            entries hold the scaled detail band.
//   length : number of samples; values below 2 leave `data` untouched.
//
// A scratch buffer of `length` floats is allocated per call. If allocation
// fails the function returns without modifying `data`.
static inline void dwt_97_forward_1d_avx512(float *data, int length) {
    if (length < 2) return;

    const int s_len = (length + 1) / 2; // even-index samples
    const int d_len = length - s_len;   // odd-index samples
    const size_t bytes = (size_t)length * sizeof(float);

    // 64-byte aligned scratch buffer when available; plain malloc otherwise
    // (every vector access below is an unaligned load/store, so alignment is
    // an optimization, not a requirement).
    float *temp = NULL;
#if defined(_POSIX_C_SOURCE) || defined(_XOPEN_SOURCE)
    {
        void *raw = NULL;
        if (posix_memalign(&raw, 64, bytes) == 0) {
            temp = (float *)raw;
        }
    }
#else
    // aligned_alloc wants a size that is a multiple of the alignment.
    temp = (float *)aligned_alloc(64, (bytes + 63) & ~(size_t)63);
#endif
    if (!temp) temp = (float *)malloc(bytes);
    if (!temp) return;

    float *s = temp;
    float *d = temp + s_len;

    // De-interleave: evens to the first half, odds to the second half.
    for (int k = 0; k < d_len; ++k) {
        s[k] = data[2 * k];
        d[k] = data[2 * k + 1];
    }
    if (s_len > d_len) {
        s[d_len] = data[length - 1]; // trailing even sample of an odd-length signal
    }

    dwt_97_predict_step_avx512(s, d, s_len, d_len, -1.586134342f); // Step 1: alpha
    dwt_97_update_step_avx512(s, d, s_len, d_len, -0.052980118f);  // Step 2: beta
    dwt_97_predict_step_avx512(s, d, s_len, d_len, 0.882911076f);  // Step 3: gamma
    dwt_97_update_step_avx512(s, d, s_len, d_len, 0.443506852f);   // Step 4: delta
    dwt_97_scale_step_avx512(s, d, s_len, d_len);                  // Step 5: K, 1/K

    memcpy(data, temp, bytes);
    free(temp);
}