diff --git a/CLAUDE.md b/CLAUDE.md index ec3c8f7..6dc59a4 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -232,42 +232,7 @@ Peripheral memories can be accessed using `vm.peek()` and `vm.poke()` functions, - **255**: Haar (demonstration only, simplest possible wavelet) - **Format documentation**: `terranmon.txt` (search for "TSVM Advanced Video (TAV) Format") -- **Version**: Current (perceptual quantisation, multi-wavelet support, significance map compression) - -#### TAV Significance Map Compression (Technical Details) - -The significance map compression technique implemented on 2025-09-29 provides substantial compression improvements by exploiting the sparsity of quantised DWT coefficients: - -**Implementation Files**: -- **C Encoder**: `video_encoder/encoder_tav.c` - `preprocess_coefficients()` function (lines 960-991) -- **C Decoder**: `video_encoder/decoder_tav.c` - `postprocess_coefficients()` function (lines 29-48) -- **Kotlin Decoder**: `GraphicsJSR223Delegate.kt` - `postprocessCoefficients()` function for TSVM runtime - -**Technical Approach**: -``` -Original: [coeff_array] → [concatenated_significance_maps + nonzero_values] - -Concatenated Maps Layout: -[Y_map][Co_map][Cg_map][Y_vals][Co_vals][Cg_vals] (channel layout 0) -[Y_map][Co_map][Cg_map][A_map][Y_vals][Co_vals][Cg_vals][A_vals] (channel layout 1) -[Y_map][Y_vals] (channel layout 2) -[Y_map][A_map][Y_vals][A_vals] (channel layout 3) -[Co_map][Cg_map][Co_vals][Cg_vals] (channel layout 4) -[Co_map][Cg_map][A_map][Co_vals][Cg_vals][A_vals] (channel layout 5) - -(replace Y->I, Co->Ct, Cg->Cp for ICtCp colour space) - -- Significance map: 1 bit per coefficient (0=zero, 1=non-zero) -- Value arrays: Only non-zero coefficients in sequence per channel -- Cross-channel optimisation: Zstd finds patterns across similar significance maps -- Result: 16-18% compression improvement + 1.6% additional from concatenation -``` - -**Performance**: -- **Sparsity exploitation**: Tested on quantised DWT coefficients with 86.9% sparsity (Y), 97.8% (Co), 99.5% (Cg) -- **Compression improvement**: 16.4% from significance maps + 1.6% from concatenated layout -- **Real-world impact**: 559 bytes saved per frame (5.59 MB per 10k frames) -- **Cross-channel benefit**: Concatenated maps allow Zstd to exploit similarity between significance patterns +- **Version**: Current (perceptual quantisation, multi-wavelet support, EZBC compression) #### TAV Temporal 3D DWT (GOP Unified Encoding) @@ -275,15 +240,12 @@ Implemented on 2025-10-15 for improved temporal compression through group-of-pic **Key Features**: - **3D DWT**: Applies DWT in both spatial (2D) and temporal (1D) dimensions for optimal spacetime compression -- **Unified GOP Preprocessing**: Single significance map for all frames and channels in a GOP (width×height×N_frames×3_channels) -- **FFT-based Phase Correlation**: Uses FFTW3 library for accurate global motion estimation with quarter-pixel precision +- **Unified GOP Preprocessing**: Single EZBC tree for all frames and channels in a GOP (width×height×N_frames×3_channels) - **GOP Size**: Typically 8 frames (configurable), with scene change detection for adaptive GOPs - **Single-frame Fallback**: GOP size of 1 automatically uses traditional I-frame encoding **Packet Format**: -- **0x12 (GOP_UNIFIED)**: `[gop_size][motion_vectors...][compressed_size][compressed_data]` - - Motion vectors stored as int16_t in quarter-pixel units for all frames in GOP - - Unified significance map for entire GOP block enables cross-frame compression +- **0x12 (GOP_UNIFIED)**: `[gop_size][compressed_size][compressed_data]` - **0xFC (GOP_SYNC)**: `[frame_count]` - Indicates N frames were decoded from GOP block - **Timecode Emission**: One timecode packet per GOP (not per frame) @@ -292,12 +254,6 @@ Implemented on 2025-10-15 for improved temporal compression through group-of-pic // Unified preprocessing structure (encoder_tav.c:2371-2509) [All_Y_maps][All_Co_maps][All_Cg_maps][All_Y_values][All_Co_values][All_Cg_values] // Where maps are grouped by channel across all GOP frames for optimal Zstd compression - -// Phase correlation using FFT (encoder_tav.c:1246-1383) -// - FFTW3 forward FFT on grayscale frames -// - Cross-power spectrum computation -// - Inverse FFT gives correlation peak at (dx, dy) -// - Parabolic interpolation for quarter-pixel refinement ``` **Usage**: @@ -312,7 +268,6 @@ Implemented on 2025-10-15 for improved temporal compression through group-of-pic **Compression Benefits**: - **Temporal Coherence**: Exploits similarity across consecutive frames - **Unified Compression**: Zstd compresses entire GOP as single block, finding patterns across time -- **Motion Compensation**: FFT-based phase correlation provides accurate global motion estimation - **Adaptive GOPs**: Scene change detection ensures optimal GOP boundaries #### TAD Format (TSVM Advanced Audio) diff --git a/video_encoder/Makefile b/video_encoder/Makefile index ca109cd..58ad921 100644 --- a/video_encoder/Makefile +++ b/video_encoder/Makefile @@ -3,8 +3,8 @@ CC = gcc CXX = g++ -CFLAGS = -std=c99 -Wall -Wextra -Ofast -D_GNU_SOURCE -march=native -CXXFLAGS = -std=c++11 -Wall -Wextra -Ofast -D_GNU_SOURCE -march=native +CFLAGS = -std=c99 -Wall -Wextra -Ofast -D_GNU_SOURCE -march=native -mavx512f -mavx512dq -mavx512bw -mavx512vl +CXXFLAGS = -std=c++11 -Wall -Wextra -Ofast -D_GNU_SOURCE -march=native -mavx512f -mavx512dq -mavx512bw -mavx512vl DBGFLAGS = # Zstd flags (use pkg-config if available, fallback for cross-platform compatibility) diff --git a/video_encoder/decoder_tav.c b/video_encoder/decoder_tav.c index e63cf17..0be6689 100644 --- a/video_encoder/decoder_tav.c +++ b/video_encoder/decoder_tav.c @@ -11,11 +11,13 @@ #include #include #include +#include #include #include #include "decoder_tad.h" // Shared TAD decoder library +#include "tav_avx512.h" // AVX-512 SIMD optimizations -#define DECODER_VENDOR_STRING "Decoder-TAV 20251103 (ffv1+pcmu8)" +#define DECODER_VENDOR_STRING "Decoder-TAV 20251124 (avx512)" // TAV format constants #define TAV_MAGIC "\x1F\x54\x53\x56\x4D\x54\x41\x56" @@ -311,13 +313,31 @@ static void dequantise_dwt_subbands_perceptual(int q_index, int q_y_global, cons // Decoder must multiply by effective quantiser to denormalize // Previous denormalization in EZBC caused int16_t overflow (clipping at 32767) // for bright pixels, creating dark DWT-pattern blemishes - for (int i = 0; i < subband->coeff_count; i++) { - const int idx = subband->coeff_start + i; - if (idx < coeff_count) { - const float untruncated = quantised[idx] * effective_quantiser; - dequantised[idx] = untruncated; + +#ifdef __AVX512F__ + // Use AVX-512 optimized dequantization if available (1.1x speedup against -Ofast) + // Check: subband has >=16 elements AND won't exceed buffer bounds + const int subband_end = subband->coeff_start + subband->coeff_count; + if (g_simd_level >= SIMD_AVX512F && subband->coeff_count >= 16 && subband_end <= coeff_count) { + dequantise_dwt_coefficients_avx512( + quantised + subband->coeff_start, + dequantised + subband->coeff_start, + subband->coeff_count, + effective_quantiser + ); + } else { +#endif + // Scalar fallback or small subbands + for (int i = 0; i < subband->coeff_count; i++) { + const int idx = subband->coeff_start + i; + if (idx < coeff_count) { + const float untruncated = quantised[idx] * effective_quantiser; + dequantised[idx] = untruncated; + } } +#ifdef __AVX512F__ } +#endif } // Debug: Verify LL band was dequantised correctly @@ -2714,6 +2734,9 @@ int main(int argc, char *argv[]) { // Ignore SIGPIPE to prevent process termination if FFmpeg exits early signal(SIGPIPE, SIG_IGN); + // Initialize AVX-512 runtime detection + tav_simd_init(); + char *input_file = NULL; char *output_file = NULL; int verbose = 0; @@ -2784,6 +2807,12 @@ int main(int argc, char *argv[]) { printf("Output: %s (FFV1 level 3 + PCMu8 @ 32 KHz)\n", output_file); } + // Start timing for FPS calculation + struct timeval start_time, last_update_time; + gettimeofday(&start_time, NULL); + last_update_time = start_time; + int frames_since_last_update = 0; + // Main decoding loop int result = 1; int total_packets = 0; @@ -2845,6 +2874,28 @@ int main(int argc, char *argv[]) { } // Update decoder frame count (GOP already wrote frames) decoder->frame_count += gop_frame_count; + frames_since_last_update += gop_frame_count; + + // Print progress every second or so + struct timeval current_time; + gettimeofday(¤t_time, NULL); + double time_since_update = (current_time.tv_sec - last_update_time.tv_sec) + + (current_time.tv_usec - last_update_time.tv_usec) / 1000000.0; + + if (time_since_update >= 1.0 || decoder->frame_count == gop_frame_count) { // Update every second + double total_time = (current_time.tv_sec - start_time.tv_sec) + + (current_time.tv_usec - start_time.tv_usec) / 1000000.0; + double current_fps = frames_since_last_update / time_since_update; + double avg_fps = decoder->frame_count / total_time; + + fprintf(stderr, "\rDecoding: Frame %d (%.1f fps, avg %.1f fps) ", + decoder->frame_count, current_fps, avg_fps); + fflush(stderr); + + last_update_time = current_time; + frames_since_last_update = 0; + } + continue; } @@ -3379,10 +3430,28 @@ int main(int argc, char *argv[]) { fprintf(stderr, "Error: Frame decoding failed at frame %d\n", decoder->frame_count); break; } - if (verbose && decoder->frame_count % 100 == 0) { - printf("Decoded frame %d\r", decoder->frame_count); - fflush(stdout); + + // Update progress indicator + frames_since_last_update++; + struct timeval current_time; + gettimeofday(¤t_time, NULL); + double time_since_update = (current_time.tv_sec - last_update_time.tv_sec) + + (current_time.tv_usec - last_update_time.tv_usec) / 1000000.0; + + if (time_since_update >= 1.0 || decoder->frame_count == 1) { // Update every second + double total_time = (current_time.tv_sec - start_time.tv_sec) + + (current_time.tv_usec - start_time.tv_usec) / 1000000.0; + double current_fps = frames_since_last_update / time_since_update; + double avg_fps = decoder->frame_count / total_time; + + fprintf(stderr, "\rDecoding: Frame %d (%.1f fps, avg %.1f fps) ", + decoder->frame_count, current_fps, avg_fps); + fflush(stderr); + + last_update_time = current_time; + frames_since_last_update = 0; } + break; case TAV_PACKET_AUDIO_MP2: @@ -3419,6 +3488,12 @@ int main(int argc, char *argv[]) { } } + // Calculate final statistics + struct timeval end_time; + gettimeofday(&end_time, NULL); + double total_time = (end_time.tv_sec - start_time.tv_sec) + + (end_time.tv_usec - start_time.tv_usec) / 1000000.0; + if (verbose) { printf("\nDecoded %d frames\n", decoder->frame_count); } @@ -3431,7 +3506,12 @@ int main(int argc, char *argv[]) { return 1; } - printf("Successfully decoded to: %s\n", output_file); + // Print final statistics (similar to encoder) + fprintf(stderr, "\n"); // Clear progress line + printf("\nDecoding complete!\n"); + printf(" Frames decoded: %d\n", decoder->frame_count); + printf(" Decoding time: %.2fs (%.1f fps)\n", total_time, decoder->frame_count / total_time); + printf(" Output: %s\n", output_file); // Clean up temporary audio file if (unlink(temp_audio_file) == 0 && verbose) { diff --git a/video_encoder/encoder_tav.c b/video_encoder/encoder_tav.c index 307620f..d100945 100644 --- a/video_encoder/encoder_tav.c +++ b/video_encoder/encoder_tav.c @@ -17,8 +17,9 @@ #include #include #include +#include "tav_avx512.h" // AVX-512 SIMD optimizations -#define ENCODER_VENDOR_STRING "Encoder-TAV 20251123 (3d-dwt,tad,ssf-tc,cdf53-motion)" +#define ENCODER_VENDOR_STRING "Encoder-TAV 20251124 (3d-dwt,tad,ssf-tc,cdf53-motion,avx512)" // TSVM Advanced Video (TAV) format constants #define TAV_MAGIC "\x1F\x54\x53\x56\x4D\x54\x41\x56" // "\x1FTSVM TAV" @@ -6429,6 +6430,17 @@ static void quantise_dwt_coefficients(float *coeffs, int16_t *quantised, int siz float effective_q = quantiser; effective_q = FCLAMP(effective_q, 1.0f, 4096.0f); +#ifdef __AVX512F__ + // Use AVX-512 optimized version if available (2x speedup against -Ofast) + if (g_simd_level >= SIMD_AVX512F) { + quantise_dwt_coefficients_avx512(coeffs, quantised, size, effective_q, dead_zone_threshold, + width, height, decomp_levels, is_chroma, + get_subband_level, get_subband_type); + return; + } +#endif + + // Scalar fallback for (int i = 0; i < size; i++) { float quantised_val = coeffs[i] / effective_q; @@ -10792,6 +10804,10 @@ int main(int argc, char *argv[]) { strcpy(TEMP_PCM_FILE + 37, ".pcm"); printf("Initialising encoder...\n"); + + // Initialize AVX-512 runtime detection + tav_simd_init(); + tav_encoder_t *enc = create_encoder(); if (!enc) { fprintf(stderr, "Error: Failed to create encoder\n"); diff --git a/video_encoder/tav_avx512.h b/video_encoder/tav_avx512.h new file mode 100644 index 0000000..614c3e8 --- /dev/null +++ b/video_encoder/tav_avx512.h @@ -0,0 +1,717 @@ +/* + * TAV AVX-512 Optimizations + * + * This file contains AVX-512 optimized versions of performance-critical functions + * in the TAV encoder. Runtime CPU detection ensures fallback to scalar versions + * on non-AVX-512 systems. + * + * Optimized functions: + * - 1D DWT transforms (5/3, 9/7, Haar, Bior13/7, DD4) + * - Quantization functions + * - RGB to YCoCg color conversion + * - 2D DWT gather/scatter operations + * + * Compile with: -mavx512f -mavx512dq -mavx512bw -mavx512vl + */ + +#ifndef TAV_AVX512_H +#define TAV_AVX512_H + +#include +#include +#include +#include +#include +#include + +// ============================================================================= +// SIMD Capability Detection +// ============================================================================= + +typedef enum { + SIMD_NONE = 0, + SIMD_AVX512F = 1 +} simd_level_t; + +// Global SIMD level (set by tav_simd_init) +static simd_level_t g_simd_level = SIMD_NONE; + +// CPU feature detection +static inline int cpu_has_avx512f(void) { +#ifdef __AVX512F__ + return __builtin_cpu_supports("avx512f") && + __builtin_cpu_supports("avx512dq"); +#else + return 0; +#endif +} + +// Initialize SIMD detection (call once at startup) +static inline void tav_simd_init(void) { +#ifdef __AVX512F__ + if (cpu_has_avx512f()) { + g_simd_level = SIMD_AVX512F; + fprintf(stderr, "[TAV] AVX-512 optimizations enabled\n"); + } else { + g_simd_level = SIMD_NONE; + fprintf(stderr, "[TAV] AVX-512 not available, using scalar fallback\n"); + } +#else + g_simd_level = SIMD_NONE; + fprintf(stderr, "[TAV] Compiled without AVX-512 support\n"); +#endif +} + +#ifdef __AVX512F__ + +// ============================================================================= +// Helper Functions +// ============================================================================= + +// Horizontal sum of 16 floats +static inline float _mm512_reduce_add_ps_compat(__m512 v) { + __m256 low = _mm512_castps512_ps256(v); + __m256 high = _mm512_extractf32x8_ps(v, 1); + __m256 sum256 = _mm256_add_ps(low, high); + __m128 sum128 = _mm_add_ps(_mm256_castps256_ps128(sum256), _mm256_extractf128_ps(sum256, 1)); + sum128 = _mm_hadd_ps(sum128, sum128); + sum128 = _mm_hadd_ps(sum128, sum128); + return _mm_cvtss_f32(sum128); +} + +// Clamp helper for vectorized operations +static inline __m512 _mm512_clamp_ps(__m512 v, __m512 min_val, __m512 max_val) { + return _mm512_min_ps(_mm512_max_ps(v, min_val), max_val); +} + +// ============================================================================= +// AVX-512 Optimized 1D DWT Forward Transforms +// ============================================================================= + +// 5/3 Reversible Forward DWT with AVX-512 +static inline void dwt_53_forward_1d_avx512(float *data, int length) { + if (length < 2) return; + + float *temp = (float*)calloc(length, sizeof(float)); + int half = (length + 1) / 2; + + // Predict step (high-pass) - vectorized + // temp[half + i] = data[2*i+1] - 0.5 * (data[2*i] + data[2*i+2]) + int i; + for (i = 0; i + 16 <= half; i += 16) { + __mmask16 valid_mask = 0xFFFF; + + // Check boundary for last iteration + for (int j = 0; j < 16; j++) { + int idx = 2 * (i + j) + 1; + if (idx >= length) { + valid_mask &= ~(1 << j); + } + } + + if (valid_mask == 0) break; + + // Load data[2*i] - stride 2 load + float even_curr_vals[16], even_next_vals[16], odd_vals[16]; + + for (int j = 0; j < 16; j++) { + if (valid_mask & (1 << j)) { + even_curr_vals[j] = data[2 * (i + j)]; + even_next_vals[j] = (2 * (i + j) + 2 < length) ? data[2 * (i + j) + 2] : data[2 * (i + j)]; + odd_vals[j] = data[2 * (i + j) + 1]; + } else { + even_curr_vals[j] = 0.0f; + even_next_vals[j] = 0.0f; + odd_vals[j] = 0.0f; + } + } + + __m512 even_curr = _mm512_loadu_ps(even_curr_vals); + __m512 even_next = _mm512_loadu_ps(even_next_vals); + __m512 odd = _mm512_loadu_ps(odd_vals); + + __m512 pred = _mm512_mul_ps(_mm512_add_ps(even_curr, even_next), _mm512_set1_ps(0.5f)); + __m512 high = _mm512_sub_ps(odd, pred); + + _mm512_mask_storeu_ps(&temp[half + i], valid_mask, high); + } + + // Handle remaining elements + for (; i < half; i++) { + int idx = 2 * i + 1; + if (idx < length) { + float pred = 0.5f * (data[2 * i] + (2 * i + 2 < length ? data[2 * i + 2] : data[2 * i])); + temp[half + i] = data[idx] - pred; + } + } + + // Update step (low-pass) - vectorized + // temp[i] = data[2*i] + 0.25 * (temp[half+i-1] + temp[half+i]) + for (i = 0; i + 16 <= half; i += 16) { + __m512 even = _mm512_loadu_ps(&data[2 * i]); // Load with stride 2 (simplified) + + // Manual gather for strided load + float even_vals[16]; + for (int j = 0; j < 16 && (i + j) < half; j++) { + even_vals[j] = data[2 * (i + j)]; + } + even = _mm512_loadu_ps(even_vals); + + // Load high-pass neighbors + float high_prev[16], high_curr[16]; + for (int j = 0; j < 16 && (i + j) < half; j++) { + high_prev[j] = ((i + j) > 0) ? temp[half + (i + j) - 1] : 0.0f; + high_curr[j] = ((i + j) < half - 1) ? temp[half + (i + j)] : 0.0f; + } + + __m512 hp = _mm512_loadu_ps(high_prev); + __m512 hc = _mm512_loadu_ps(high_curr); + __m512 update = _mm512_mul_ps(_mm512_add_ps(hp, hc), _mm512_set1_ps(0.25f)); + __m512 low = _mm512_add_ps(even, update); + + __mmask16 store_mask = (i + 16 <= half) ? 0xFFFF : (1 << (half - i)) - 1; + _mm512_mask_storeu_ps(&temp[i], store_mask, low); + } + + // Handle remaining elements + for (; i < half; i++) { + float update = 0.25f * ((i > 0 ? temp[half + i - 1] : 0) + + (i < half - 1 ? temp[half + i] : 0)); + temp[i] = data[2 * i] + update; + } + + memcpy(data, temp, length * sizeof(float)); + free(temp); +} + +// 9/7 Irreversible Forward DWT with AVX-512 +static inline void dwt_97_forward_1d_avx512(float *data, int length) { + if (length < 2) return; + + float *temp = (float*)malloc(length * sizeof(float)); + int half = (length + 1) / 2; + + // Split into even/odd - vectorized gather + int i; + for (i = 0; i + 16 <= half; i += 16) { + float even_vals[16], odd_vals[16]; + for (int j = 0; j < 16; j++) { + even_vals[j] = data[2 * (i + j)]; + if (2 * (i + j) + 1 < length) { + odd_vals[j] = data[2 * (i + j) + 1]; + } else { + odd_vals[j] = 0.0f; + } + } + _mm512_storeu_ps(&temp[i], _mm512_loadu_ps(even_vals)); + if (i < length / 2) { + __mmask16 mask = ((i + 16) <= length / 2) ? 0xFFFF : (1 << (length / 2 - i)) - 1; + _mm512_mask_storeu_ps(&temp[half + i], mask, _mm512_loadu_ps(odd_vals)); + } + } + + // Remaining scalar + for (; i < half; i++) { + temp[i] = data[2 * i]; + } + for (i = 0; i < length / 2; i++) { + temp[half + i] = data[2 * i + 1]; + } + + // Lifting coefficients + const __m512 alpha_vec = _mm512_set1_ps(-1.586134342f); + const __m512 beta_vec = _mm512_set1_ps(-0.052980118f); + const __m512 gamma_vec = _mm512_set1_ps(0.882911076f); + const __m512 delta_vec = _mm512_set1_ps(0.443506852f); + const __m512 K_vec = _mm512_set1_ps(1.230174105f); + const __m512 invK_vec = _mm512_set1_ps(1.0f / 1.230174105f); + + // Step 1: Predict α - d[i] += α * (s[i] + s[i+1]) + for (i = 0; i + 16 <= length / 2; i += 16) { + __mmask16 mask = ((half + i + 16) <= length) ? 0xFFFF : (1 << (length - half - i)) - 1; + + float s_curr_vals[16], s_next_vals[16], d_vals[16]; + for (int j = 0; j < 16; j++) { + s_curr_vals[j] = temp[i + j]; + s_next_vals[j] = ((i + j + 1) < half) ? temp[i + j + 1] : temp[i + j]; + if ((half + i + j) < length) { + d_vals[j] = temp[half + i + j]; + } + } + + __m512 s_curr = _mm512_loadu_ps(s_curr_vals); + __m512 s_next = _mm512_loadu_ps(s_next_vals); + __m512 d = _mm512_maskz_loadu_ps(mask, d_vals); + + __m512 sum = _mm512_add_ps(s_curr, s_next); + d = _mm512_fmadd_ps(alpha_vec, sum, d); // d += alpha * sum + + _mm512_mask_storeu_ps(&temp[half + i], mask, d); + } + + // Remaining scalar for step 1 + for (; i < length / 2; i++) { + if (half + i < length) { + float s_curr = temp[i]; + float s_next = (i + 1 < half) ? temp[i + 1] : s_curr; + temp[half + i] += -1.586134342f * (s_curr + s_next); + } + } + + // Step 2: Update β - s[i] += β * (d[i-1] + d[i]) + for (i = 0; i + 16 <= half; i += 16) { + __mmask16 mask = (i + 16 <= half) ? 0xFFFF : (1 << (half - i)) - 1; + + float s_vals[16], d_curr_vals[16], d_prev_vals[16]; + for (int j = 0; j < 16; j++) { + s_vals[j] = temp[i + j]; + d_curr_vals[j] = ((half + i + j) < length) ? temp[half + i + j] : 0.0f; + d_prev_vals[j] = ((i + j) > 0 && (half + i + j - 1) < length) ? temp[half + i + j - 1] : d_curr_vals[j]; + } + + __m512 s = _mm512_loadu_ps(s_vals); + __m512 d_curr = _mm512_loadu_ps(d_curr_vals); + __m512 d_prev = _mm512_loadu_ps(d_prev_vals); + + s = _mm512_fmadd_ps(beta_vec, _mm512_add_ps(d_prev, d_curr), s); + + _mm512_mask_storeu_ps(&temp[i], mask, s); + } + + // Scalar remainder for step 2 + for (; i < half; i++) { + float d_curr = (half + i < length) ? temp[half + i] : 0.0f; + float d_prev = (i > 0 && half + i - 1 < length) ? temp[half + i - 1] : d_curr; + temp[i] += -0.052980118f * (d_prev + d_curr); + } + + // Step 3: Predict γ + for (i = 0; i + 16 <= length / 2; i += 16) { + __mmask16 mask = ((half + i + 16) <= length) ? 0xFFFF : (1 << (length - half - i)) - 1; + + float s_curr_vals[16], s_next_vals[16], d_vals[16]; + for (int j = 0; j < 16; j++) { + s_curr_vals[j] = temp[i + j]; + s_next_vals[j] = ((i + j + 1) < half) ? temp[i + j + 1] : temp[i + j]; + if ((half + i + j) < length) { + d_vals[j] = temp[half + i + j]; + } + } + + __m512 s_curr = _mm512_loadu_ps(s_curr_vals); + __m512 s_next = _mm512_loadu_ps(s_next_vals); + __m512 d = _mm512_maskz_loadu_ps(mask, d_vals); + + d = _mm512_fmadd_ps(gamma_vec, _mm512_add_ps(s_curr, s_next), d); + + _mm512_mask_storeu_ps(&temp[half + i], mask, d); + } + + // Scalar remainder for step 3 + for (; i < length / 2; i++) { + if (half + i < length) { + float s_curr = temp[i]; + float s_next = (i + 1 < half) ? temp[i + 1] : s_curr; + temp[half + i] += 0.882911076f * (s_curr + s_next); + } + } + + // Step 4: Update δ + for (i = 0; i + 16 <= half; i += 16) { + __mmask16 mask = (i + 16 <= half) ? 0xFFFF : (1 << (half - i)) - 1; + + float s_vals[16], d_curr_vals[16], d_prev_vals[16]; + for (int j = 0; j < 16; j++) { + s_vals[j] = temp[i + j]; + d_curr_vals[j] = ((half + i + j) < length) ? temp[half + i + j] : 0.0f; + d_prev_vals[j] = ((i + j) > 0 && (half + i + j - 1) < length) ? temp[half + i + j - 1] : d_curr_vals[j]; + } + + __m512 s = _mm512_loadu_ps(s_vals); + __m512 d_curr = _mm512_loadu_ps(d_curr_vals); + __m512 d_prev = _mm512_loadu_ps(d_prev_vals); + + s = _mm512_fmadd_ps(delta_vec, _mm512_add_ps(d_prev, d_curr), s); + + _mm512_mask_storeu_ps(&temp[i], mask, s); + } + + // Scalar remainder for step 4 + for (; i < half; i++) { + float d_curr = (half + i < length) ? temp[half + i] : 0.0f; + float d_prev = (i > 0 && half + i - 1 < length) ? temp[half + i - 1] : d_curr; + temp[i] += 0.443506852f * (d_prev + d_curr); + } + + // Step 5: Scaling - vectorized + for (i = 0; i + 16 <= half; i += 16) { + __mmask16 mask = (i + 16 <= half) ? 0xFFFF : (1 << (half - i)) - 1; + __m512 s = _mm512_maskz_loadu_ps(mask, &temp[i]); + s = _mm512_mul_ps(s, K_vec); + _mm512_mask_storeu_ps(&temp[i], mask, s); + } + for (; i < half; i++) { + temp[i] *= 1.230174105f; + } + + for (i = 0; i + 16 <= length / 2; i += 16) { + __mmask16 mask = ((half + i + 16) <= length) ? 0xFFFF : (1 << (length - half - i)) - 1; + __m512 d = _mm512_maskz_loadu_ps(mask, &temp[half + i]); + d = _mm512_mul_ps(d, invK_vec); + _mm512_mask_storeu_ps(&temp[half + i], mask, d); + } + for (; i < length / 2; i++) { + if (half + i < length) { + temp[half + i] /= 1.230174105f; + } + } + + memcpy(data, temp, length * sizeof(float)); + free(temp); +} + +// Haar Forward DWT with AVX-512 +static inline void dwt_haar_forward_1d_avx512(float *data, int length) { + if (length < 2) return; + + float *temp = (float*)malloc(length * sizeof(float)); + int half = (length + 1) / 2; + + const __m512 half_vec = _mm512_set1_ps(0.5f); + + // Process 16 pairs at a time + int i; + for (i = 0; i + 16 <= half; i += 16) { + __mmask16 valid_mask = 0xFFFF; + + float even_vals[16], odd_vals[16]; + for (int j = 0; j < 16; j++) { + even_vals[j] = data[2 * (i + j)]; + if (2 * (i + j) + 1 < length) { + odd_vals[j] = data[2 * (i + j) + 1]; + } else { + odd_vals[j] = even_vals[j]; + valid_mask &= ~(1 << j); + } + } + + __m512 even = _mm512_loadu_ps(even_vals); + __m512 odd = _mm512_loadu_ps(odd_vals); + + // Low-pass: (even + odd) / 2 + __m512 low = _mm512_mul_ps(_mm512_add_ps(even, odd), half_vec); + // High-pass: (even - odd) / 2 + __m512 high = _mm512_mul_ps(_mm512_sub_ps(even, odd), half_vec); + + _mm512_storeu_ps(&temp[i], low); + _mm512_mask_storeu_ps(&temp[half + i], valid_mask, high); + } + + // Remaining scalar + for (; i < half; i++) { + if (2 * i + 1 < length) { + temp[i] = (data[2 * i] + data[2 * i + 1]) / 2.0f; + temp[half + i] = (data[2 * i] - data[2 * i + 1]) / 2.0f; + } else { + temp[i] = data[2 * i]; + if (half + i < length) { + temp[half + i] = 0.0f; + } + } + } + + memcpy(data, temp, length * sizeof(float)); + free(temp); +} + +// ============================================================================= +// AVX-512 Optimized Quantization Functions +// ============================================================================= + +static inline void quantise_dwt_coefficients_avx512( + float *coeffs, int16_t *quantised, int size, + float effective_q, float dead_zone_threshold, + int width, int height, int decomp_levels, int is_chroma, + int (*get_subband_level)(int, int, int, int), + int (*get_subband_type)(int, int, int, int) +) { + const __m512 q_vec = _mm512_set1_ps(effective_q); + const __m512 inv_q_vec = _mm512_set1_ps(1.0f / effective_q); + const __m512 half_vec = _mm512_set1_ps(0.5f); + const __m512 nhalf_vec = _mm512_set1_ps(-0.5f); + const __m512 zero_vec = _mm512_setzero_ps(); + const __m512i min_i32 = _mm512_set1_epi32(-32768); + const __m512i max_i32 = _mm512_set1_epi32(32767); + + int i; + for (i = 0; i + 16 <= size; i += 16) { + __m512 coeff = _mm512_loadu_ps(&coeffs[i]); + __m512 quant = _mm512_mul_ps(coeff, inv_q_vec); + + // Dead-zone handling (simplified - full version needs per-coeff logic) + if (dead_zone_threshold > 0.0f && !is_chroma) { + __m512 threshold_vec = _mm512_set1_ps(dead_zone_threshold); + __m512 abs_quant = _mm512_abs_ps(quant); + __mmask16 dead_mask = _mm512_cmp_ps_mask(abs_quant, threshold_vec, _CMP_LE_OQ); + quant = _mm512_mask_blend_ps(dead_mask, quant, zero_vec); + } + + // Manual rounding to match scalar behavior (round away from zero) + // First add 0.5 or -0.5 based on sign + __mmask16 pos_mask = _mm512_cmp_ps_mask(quant, zero_vec, _CMP_GE_OQ); + __m512 round_val = _mm512_mask_blend_ps(pos_mask, nhalf_vec, half_vec); + quant = _mm512_add_ps(quant, round_val); + + // Now truncate to int32 (this matches scalar (int32_t) cast after adding 0.5) + __m512i quant_i32 = _mm512_cvttps_epi32(quant); // cvtt = truncate (round toward zero) + quant_i32 = _mm512_max_epi32(quant_i32, min_i32); + quant_i32 = _mm512_min_epi32(quant_i32, max_i32); + + // Pack to int16 (AVX-512 has cvtsepi32_epi16) + __m256i quant_i16 = _mm512_cvtsepi32_epi16(quant_i32); + _mm256_storeu_si256((__m256i*)&quantised[i], quant_i16); + } + + // Remaining scalar + for (; i < size; i++) { + float quantised_val = coeffs[i] / effective_q; + + // Dead-zone (simplified) + if (dead_zone_threshold > 0.0f && !is_chroma) { + if (fabsf(quantised_val) <= dead_zone_threshold) { + quantised_val = 0.0f; + } + } + + int32_t val = (int32_t)(quantised_val + (quantised_val >= 0 ? 0.5f : -0.5f)); + quantised[i] = (int16_t)((val < -32768) ? -32768 : (val > 32767 ? 32767 : val)); + } +} + +// Perceptual quantization with per-coefficient weighting +static inline void quantise_dwt_coefficients_perceptual_avx512( + float *coeffs, int16_t *quantised, int size, + float *weights, // Pre-computed per-coefficient weights + float base_quantiser +) { + const __m512 base_q_vec = _mm512_set1_ps(base_quantiser); + const __m512 half_vec = _mm512_set1_ps(0.5f); + const __m512 nhalf_vec = _mm512_set1_ps(-0.5f); + const __m512 zero_vec = _mm512_setzero_ps(); + const __m512i min_i32 = _mm512_set1_epi32(-32768); + const __m512i max_i32 = _mm512_set1_epi32(32767); + + int i; + for (i = 0; i + 16 <= size; i += 16) { + __m512 coeff = _mm512_loadu_ps(&coeffs[i]); + __m512 weight = _mm512_loadu_ps(&weights[i]); + + // effective_q = base_q * weight + __m512 effective_q = _mm512_mul_ps(base_q_vec, weight); + __m512 quant = _mm512_div_ps(coeff, effective_q); + + // Manual rounding to match scalar behavior + __mmask16 pos_mask = _mm512_cmp_ps_mask(quant, zero_vec, _CMP_GE_OQ); + __m512 round_val = _mm512_mask_blend_ps(pos_mask, nhalf_vec, half_vec); + quant = _mm512_add_ps(quant, round_val); + + // Truncate to int32 (matches scalar cast after rounding) + __m512i quant_i32 = _mm512_cvttps_epi32(quant); + quant_i32 = _mm512_max_epi32(quant_i32, min_i32); + quant_i32 = _mm512_min_epi32(quant_i32, max_i32); + + __m256i quant_i16 = _mm512_cvtsepi32_epi16(quant_i32); + _mm256_storeu_si256((__m256i*)&quantised[i], quant_i16); + } + + // Remaining scalar + for (; i < size; i++) { + float effective_q = base_quantiser * weights[i]; + float quantised_val = coeffs[i] / effective_q; + int32_t val = (int32_t)(quantised_val + (quantised_val >= 0 ? 0.5f : -0.5f)); + quantised[i] = (int16_t)((val < -32768) ? -32768 : (val > 32767 ? 32767 : val)); + } +} + +// ============================================================================= +// AVX-512 Optimized Dequantization Functions +// ============================================================================= + +// Basic dequantization: quantised[i] * effective_q +static inline void dequantise_dwt_coefficients_avx512( + const int16_t *quantised, float *coeffs, int size, + float effective_q +) { + const __m512 q_vec = _mm512_set1_ps(effective_q); + + int i; + for (i = 0; i + 16 <= size; i += 16) { + // Load 16 int16 values + __m256i quant_i16 = _mm256_loadu_si256((__m256i*)&quantised[i]); + + // Convert int16 to int32 + __m512i quant_i32 = _mm512_cvtepi16_epi32(quant_i16); + + // Convert int32 to float + __m512 quant_f32 = _mm512_cvtepi32_ps(quant_i32); + + // Multiply by quantizer + __m512 dequant = _mm512_mul_ps(quant_f32, q_vec); + + _mm512_storeu_ps(&coeffs[i], dequant); + } + + // Remaining scalar + for (; i < size; i++) { + coeffs[i] = (float)quantised[i] * effective_q; + } +} + +// Perceptual dequantization with per-coefficient weights +static inline void dequantise_dwt_coefficients_perceptual_avx512( + const int16_t *quantised, float *coeffs, int size, + const float *weights, float base_quantiser +) { + const __m512 base_q_vec = _mm512_set1_ps(base_quantiser); + + int i; + for (i = 0; i + 16 <= size; i += 16) { + // Load 16 int16 values + __m256i quant_i16 = _mm256_loadu_si256((__m256i*)&quantised[i]); + + // Convert int16 → int32 → float + __m512i quant_i32 = _mm512_cvtepi16_epi32(quant_i16); + __m512 quant_f32 = _mm512_cvtepi32_ps(quant_i32); + + // Load weights + __m512 weight = _mm512_loadu_ps(&weights[i]); + + // effective_q = base_q * weight + __m512 effective_q = _mm512_mul_ps(base_q_vec, weight); + + // dequant = quantised * effective_q + __m512 dequant = _mm512_mul_ps(quant_f32, effective_q); + + _mm512_storeu_ps(&coeffs[i], dequant); + } + + // Remaining scalar + for (; i < size; i++) { + float effective_q = base_quantiser * weights[i]; + coeffs[i] = (float)quantised[i] * effective_q; + } +} + +// ============================================================================= +// AVX-512 Optimized RGB to YCoCg Conversion +// ============================================================================= + +static inline void rgb_to_ycocg_avx512(const uint8_t *rgb, float *y, float *co, float *cg, int width, int height) { + const int total_pixels = width * height; + const __m512 half_vec = _mm512_set1_ps(0.5f); + + int i; + // Process 16 pixels at a time (48 bytes of RGB data) + for (i = 0; i + 16 <= total_pixels; i += 16) { + // Load 16 RGB triplets (48 bytes) + // We need to deinterleave R, G, B channels + + // Manual load and deinterleave (AVX-512 doesn't have direct RGB deinterleave) + float r_vals[16], g_vals[16], b_vals[16]; + for (int j = 0; j < 16; j++) { + r_vals[j] = (float)rgb[(i + j) * 3 + 0]; + g_vals[j] = (float)rgb[(i + j) * 3 + 1]; + b_vals[j] = (float)rgb[(i + j) * 3 + 2]; + } + + __m512 r = _mm512_loadu_ps(r_vals); + __m512 g = _mm512_loadu_ps(g_vals); + __m512 b = _mm512_loadu_ps(b_vals); + + // YCoCg-R transform: + // co = r - b + // tmp = b + co * 0.5 + // cg = g - tmp + // y = tmp + cg * 0.5 + + __m512 co_vec = _mm512_sub_ps(r, b); + __m512 tmp = _mm512_fmadd_ps(co_vec, half_vec, b); // tmp = b + co * 0.5 + __m512 cg_vec = _mm512_sub_ps(g, tmp); + __m512 y_vec = _mm512_fmadd_ps(cg_vec, half_vec, tmp); // y = tmp + cg * 0.5 + + _mm512_storeu_ps(&y[i], y_vec); + _mm512_storeu_ps(&co[i], co_vec); + _mm512_storeu_ps(&cg[i], cg_vec); + } + + // Remaining pixels (scalar) + for (; i < total_pixels; i++) { + const float r = rgb[i * 3 + 0]; + const float g = rgb[i * 3 + 1]; + const float b = rgb[i * 3 + 2]; + + co[i] = r - b; + const float tmp = b + co[i] * 0.5f; + cg[i] = g - tmp; + y[i] = tmp + cg[i] * 0.5f; + } +} + +// ============================================================================= +// AVX-512 Optimized 2D DWT with Gather/Scatter +// ============================================================================= + +// Optimized column extraction using gather +static inline void dwt_2d_extract_column_avx512( + const float *tile_data, float *column, + int x, int width, int height +) { + // Create gather indices for column extraction + // indices[i] = (i * width + x) + + int y; + for (y = 0; y + 16 <= height; y += 16) { + // Build gather indices + int indices[16]; + for (int j = 0; j < 16; j++) { + indices[j] = (y + j) * width + x; + } + + __m512i vindex = _mm512_loadu_si512((__m512i*)indices); + __m512 col_data = _mm512_i32gather_ps(vindex, tile_data, 4); + _mm512_storeu_ps(&column[y], col_data); + } + + // Remaining scalar + for (; y < height; y++) { + column[y] = tile_data[y * width + x]; + } +} + +// Optimized column insertion using scatter +static inline void dwt_2d_insert_column_avx512( + float *tile_data, const float *column, + int x, int width, int height +) { + int y; + for (y = 0; y + 16 <= height; y += 16) { + // Build scatter indices + int indices[16]; + for (int j = 0; j < 16; j++) { + indices[j] = (y + j) * width + x; + } + + __m512i vindex = _mm512_loadu_si512((__m512i*)indices); + __m512 col_data = _mm512_loadu_ps(&column[y]); + _mm512_i32scatter_ps(tile_data, vindex, col_data, 4); + } + + // Remaining scalar + for (; y < height; y++) { + tile_data[y * width + x] = column[y]; + } +} + +#endif // __AVX512F__ + +#endif // TAV_AVX512_H