TAD: now processing entirely in float

This commit is contained in:
minjaesong
2025-10-24 05:31:38 +09:00
parent a9319fd812
commit 9dc71095a0
5 changed files with 139 additions and 319 deletions

View File

@@ -1,6 +1,7 @@
// Created by CuriousTorvald and Claude on 2025-10-23.
// TAD (Terrarum Advanced Audio) Encoder Library - DWT-based audio compression
// This file contains only the encoding functions for use by encoder_tad.c and encoder_tav.c
// Created by CuriousTorvald and Claude on 2025-10-24.
// TAD32 (Terrarum Advanced Audio - PCM32f version) Encoder Library
// Alternative version: PCM32f throughout encoding, PCM8 conversion only at decoder
// This file contains only the encoding functions for comparison testing
#include <stdio.h>
#include <stdlib.h>
@@ -11,12 +12,9 @@
#include "encoder_tad.h"
// Forward declarations for internal functions
static void dwt_haar_forward_1d(float *data, int length);
static void dwt_dd4_forward_1d(float *data, int length);
static void dwt_97_forward_1d(float *data, int length);
static void dwt_haar_forward_multilevel(float *data, int length, int levels);
static void ms_decorrelate(const int8_t *left, const int8_t *right, int8_t *mid, int8_t *side, size_t count);
static void convert_pcm16_to_pcm8_dithered(const int16_t *pcm16, int8_t *pcm8, int num_samples, int16_t *dither_error);
static void dwt_dd4_forward_multilevel(float *data, int length, int levels);
static void ms_decorrelate_16(const float *left, const float *right, float *mid, float *side, size_t count);
static void get_quantization_weights(int quality, int dwt_levels, float *weights);
static int get_deadzone_threshold(int quality);
static void quantize_dwt_coefficients(const float *coeffs, int16_t *quantized, size_t count, int quality, int apply_deadzone, int chunk_size, int dwt_levels);
@@ -26,15 +24,13 @@ static inline float FCLAMP(float x, float min, float max) {
return x < min ? min : (x > max ? max : x);
}
// Calculate DWT levels from chunk size (non-power-of-2 supported, >= 1024)
// Calculate DWT levels from chunk size
static int calculate_dwt_levels(int chunk_size) {
if (chunk_size < TAD_MIN_CHUNK_SIZE) {
fprintf(stderr, "Error: Chunk size %d is below minimum %d\n", chunk_size, TAD_MIN_CHUNK_SIZE);
if (chunk_size < TAD32_MIN_CHUNK_SIZE) {
fprintf(stderr, "Error: Chunk size %d is below minimum %d\n", chunk_size, TAD32_MIN_CHUNK_SIZE);
return -1;
}
// For non-power-of-2, find next power of 2 and calculate levels
// Then subtract 2 for maximum decomposition
int levels = 0;
int size = chunk_size;
while (size > 1) {
@@ -48,39 +44,13 @@ static int calculate_dwt_levels(int chunk_size) {
levels++;
}
return levels - 2; // Maximum decomposition leaves 2-sample approximation
return levels - 2; // Maximum decomposition
}
//=============================================================================
// Haar DWT Implementation
// DD-4 DWT Implementation
//=============================================================================
static void dwt_haar_forward_1d(float *data, int length) {
if (length < 2) return;
float *temp = malloc(length * sizeof(float));
int half = (length + 1) / 2;
// Haar transform: compute averages (low-pass) and differences (high-pass)
for (int i = 0; i < half; i++) {
if (2 * i + 1 < length) {
// Average of adjacent pairs (low-pass)
temp[i] = (data[2 * i] + data[2 * i + 1]) / 2.0f;
// Difference of adjacent pairs (high-pass)
temp[half + i] = (data[2 * i] - data[2 * i + 1]) / 2.0f;
} else {
// Handle odd length: last sample goes to low-pass
temp[i] = data[2 * i];
if (half + i < length) {
temp[half + i] = 0.0f;
}
}
}
memcpy(data, temp, length * sizeof(float));
free(temp);
}
// Four-point interpolating Deslauriers-Dubuc (DD-4) wavelet forward 1D transform
static void dwt_dd4_forward_1d(float *data, int length) {
if (length < 2) return;
@@ -129,76 +99,8 @@ static void dwt_dd4_forward_1d(float *data, int length) {
free(temp);
}
// 1D DWT using lifting scheme for 9/7 irreversible filter
static void dwt_97_forward_1d(float *data, int length) {
if (length < 2) return;
float *temp = malloc(length * sizeof(float));
int half = (length + 1) / 2;
// Split into even/odd samples
for (int i = 0; i < half; i++) {
temp[i] = data[2 * i]; // Even (low)
}
for (int i = 0; i < length / 2; i++) {
temp[half + i] = data[2 * i + 1]; // Odd (high)
}
// JPEG2000 9/7 forward lifting steps
const float alpha = -1.586134342f;
const float beta = -0.052980118f;
const float gamma = 0.882911076f;
const float delta = 0.443506852f;
const float K = 1.230174105f;
// Step 1: Predict α
for (int i = 0; i < length / 2; i++) {
if (half + i < length) {
float s_curr = temp[i];
float s_next = (i + 1 < half) ? temp[i + 1] : s_curr;
temp[half + i] += alpha * (s_curr + s_next);
}
}
// Step 2: Update β
for (int i = 0; i < half; i++) {
float d_curr = (half + i < length) ? temp[half + i] : 0.0f;
float d_prev = (i > 0 && half + i - 1 < length) ? temp[half + i - 1] : d_curr;
temp[i] += beta * (d_prev + d_curr);
}
// Step 3: Predict γ
for (int i = 0; i < length / 2; i++) {
if (half + i < length) {
float s_curr = temp[i];
float s_next = (i + 1 < half) ? temp[i + 1] : s_curr;
temp[half + i] += gamma * (s_curr + s_next);
}
}
// Step 4: Update δ
for (int i = 0; i < half; i++) {
float d_curr = (half + i < length) ? temp[half + i] : 0.0f;
float d_prev = (i > 0 && half + i - 1 < length) ? temp[half + i - 1] : d_curr;
temp[i] += delta * (d_prev + d_curr);
}
// Step 5: Scaling
for (int i = 0; i < half; i++) {
temp[i] *= K;
}
for (int i = 0; i < length / 2; i++) {
if (half + i < length) {
temp[half + i] /= K;
}
}
memcpy(data, temp, length * sizeof(float));
free(temp);
}
// Apply multi-level DWT (using DD-4 wavelet)
static void dwt_haar_forward_multilevel(float *data, int length, int levels) {
static void dwt_dd4_forward_multilevel(float *data, int length, int levels) {
int current_length = length;
for (int level = 0; level < levels; level++) {
dwt_dd4_forward_1d(data, current_length);
@@ -207,35 +109,16 @@ static void dwt_haar_forward_multilevel(float *data, int length, int levels) {
}
//=============================================================================
// M/S Stereo Decorrelation
// M/S Stereo Decorrelation (PCM32f version)
//=============================================================================
static void ms_decorrelate(const int8_t *left, const int8_t *right, int8_t *mid, int8_t *side, size_t count) {
static void ms_decorrelate_16(const float *left, const float *right, float *mid, float *side, size_t count) {
for (size_t i = 0; i < count; i++) {
// Mid = (L + R) / 2, Side = (L - R) / 2
int32_t l = left[i];
int32_t r = right[i];
mid[i] = (int8_t)((l + r) / 2);
side[i] = (int8_t)((l - r) / 2);
}
}
//=============================================================================
// PCM16 to Signed PCM8 Conversion with Dithering
//=============================================================================
static void convert_pcm16_to_pcm8_dithered(const int16_t *pcm16, int8_t *pcm8, int num_samples, int16_t *dither_error) {
for (int i = 0; i < num_samples; i++) {
for (int ch = 0; ch < 2; ch++) { // Stereo: L and R
int idx = i * 2 + ch;
int32_t sample = (int32_t)pcm16[idx];
sample += dither_error[ch];
int32_t quantized = sample >> 8;
if (quantized < -128) quantized = -128;
if (quantized > 127) quantized = 127;
pcm8[idx] = (int8_t)quantized;
dither_error[ch] = sample - (quantized << 8);
}
float l = left[i];
float r = right[i];
mid[i] = (l + r) / 2.0f;
side[i] = (l - r) / 2.0f;
}
}
@@ -263,15 +146,15 @@ static void get_quantization_weights(int quality, int dwt_levels, float *weights
/*15*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f}
};
float quality_scale = 1.0f + FCLAMP((3 - quality) * 0.5f, 0.0f, 1000.0f);
float quality_scale = 4.0f * (1.0f + FCLAMP((3 - quality) * 0.5f, 0.0f, 1000.0f));
for (int i = 0; i < dwt_levels; i++) {
weights[i] = FCLAMP(base_weights[dwt_levels][i] * quality_scale, 1.0f, 1000.0f);
weights[i] = base_weights[dwt_levels][i] * quality_scale;
}
}
static int get_deadzone_threshold(int quality) {
const int thresholds[] = {1,1,0,0,0,0}; // Q0 to Q5
const int thresholds[] = {1,1,1,1,1,1}; // Q0 to Q5
return thresholds[quality];
}
@@ -302,7 +185,7 @@ static void quantize_dwt_coefficients(const float *coeffs, int16_t *quantized, s
if (weight_idx >= dwt_levels) weight_idx = dwt_levels - 1;
float weight = weights[weight_idx];
float val = coeffs[i] / weight;
float val = coeffs[i] / weight * TAD32_COEFF_SCALAR;
int16_t quant_val = (int16_t)roundf(val);
if (apply_deadzone && sideband >= dwt_levels - 1) {
@@ -359,8 +242,8 @@ static size_t encode_sigmap_2bit(const int16_t *values, size_t count, uint8_t *o
// Public API: Chunk Encoding
//=============================================================================
size_t tad_encode_chunk(const int16_t *pcm16_stereo, size_t num_samples, int quality,
int use_zstd, uint8_t *output) {
size_t tad32_encode_chunk(const float *pcm32_stereo, size_t num_samples, int quality,
int use_zstd, uint8_t *output) {
// Calculate DWT levels from chunk size
int dwt_levels = calculate_dwt_levels(num_samples);
if (dwt_levels < 0) {
@@ -368,12 +251,11 @@ size_t tad_encode_chunk(const int16_t *pcm16_stereo, size_t num_samples, int qua
return 0;
}
// Allocate working buffers
int8_t *pcm8_stereo = malloc(num_samples * 2 * sizeof(int8_t));
int8_t *pcm8_left = malloc(num_samples * sizeof(int8_t));
int8_t *pcm8_right = malloc(num_samples * sizeof(int8_t));
int8_t *pcm8_mid = malloc(num_samples * sizeof(int8_t));
int8_t *pcm8_side = malloc(num_samples * sizeof(int8_t));
// Allocate working buffers (PCM32f throughout, int32 coefficients)
float *pcm32_left = malloc(num_samples * sizeof(float));
float *pcm32_right = malloc(num_samples * sizeof(float));
float *pcm32_mid = malloc(num_samples * sizeof(float));
float *pcm32_side = malloc(num_samples * sizeof(float));
float *dwt_mid = malloc(num_samples * sizeof(float));
float *dwt_side = malloc(num_samples * sizeof(float));
@@ -381,34 +263,30 @@ size_t tad_encode_chunk(const int16_t *pcm16_stereo, size_t num_samples, int qua
int16_t *quant_mid = malloc(num_samples * sizeof(int16_t));
int16_t *quant_side = malloc(num_samples * sizeof(int16_t));
// Step 1: Convert PCM16 to signed PCM8 with dithering
int16_t dither_error[2] = {0, 0};
convert_pcm16_to_pcm8_dithered(pcm16_stereo, pcm8_stereo, num_samples, dither_error);
// Deinterleave stereo
// Step 1: Deinterleave stereo
for (size_t i = 0; i < num_samples; i++) {
pcm8_left[i] = pcm8_stereo[i * 2];
pcm8_right[i] = pcm8_stereo[i * 2 + 1];
pcm32_left[i] = pcm32_stereo[i * 2];
pcm32_right[i] = pcm32_stereo[i * 2 + 1];
}
// Step 2: M/S decorrelation
ms_decorrelate(pcm8_left, pcm8_right, pcm8_mid, pcm8_side, num_samples);
ms_decorrelate_16(pcm32_left, pcm32_right, pcm32_mid, pcm32_side, num_samples);
// Step 3: Convert to float and apply DWT
for (size_t i = 0; i < num_samples; i++) {
dwt_mid[i] = (float)pcm8_mid[i];
dwt_side[i] = (float)pcm8_side[i];
dwt_mid[i] = pcm32_mid[i];
dwt_side[i] = pcm32_side[i];
}
dwt_haar_forward_multilevel(dwt_mid, num_samples, dwt_levels);
dwt_haar_forward_multilevel(dwt_side, num_samples, dwt_levels);
dwt_dd4_forward_multilevel(dwt_mid, num_samples, dwt_levels);
dwt_dd4_forward_multilevel(dwt_side, num_samples, dwt_levels);
// Step 4: Quantize with frequency-dependent weights and dead zone
quantize_dwt_coefficients(dwt_mid, quant_mid, num_samples, quality, 1, num_samples, dwt_levels);
quantize_dwt_coefficients(dwt_side, quant_side, num_samples, quality, 1, num_samples, dwt_levels);
// Step 5: Encode with 2-bit significance map
uint8_t *temp_buffer = malloc(num_samples * 4 * sizeof(int16_t));
// Step 5: Encode with 2-bit significance map (32-bit version)
uint8_t *temp_buffer = malloc(num_samples * 4 * sizeof(int32_t));
size_t mid_size = encode_sigmap_2bit(quant_mid, num_samples, temp_buffer);
size_t side_size = encode_sigmap_2bit(quant_side, num_samples, temp_buffer + mid_size);
@@ -429,13 +307,13 @@ size_t tad_encode_chunk(const int16_t *pcm16_stereo, size_t num_samples, int qua
size_t zstd_bound = ZSTD_compressBound(uncompressed_size);
uint8_t *zstd_buffer = malloc(zstd_bound);
payload_size = ZSTD_compress(zstd_buffer, zstd_bound, temp_buffer, uncompressed_size, TAD_ZSTD_LEVEL);
payload_size = ZSTD_compress(zstd_buffer, zstd_bound, temp_buffer, uncompressed_size, TAD32_ZSTD_LEVEL);
if (ZSTD_isError(payload_size)) {
fprintf(stderr, "Error: Zstd compression failed: %s\n", ZSTD_getErrorName(payload_size));
free(zstd_buffer);
free(pcm8_stereo); free(pcm8_left); free(pcm8_right);
free(pcm8_mid); free(pcm8_side); free(dwt_mid); free(dwt_side);
free(pcm32_left); free(pcm32_right);
free(pcm32_mid); free(pcm32_side); free(dwt_mid); free(dwt_side);
free(quant_mid); free(quant_side); free(temp_buffer);
return 0;
}
@@ -451,8 +329,8 @@ size_t tad_encode_chunk(const int16_t *pcm16_stereo, size_t num_samples, int qua
write_ptr += payload_size;
// Cleanup
free(pcm8_stereo); free(pcm8_left); free(pcm8_right);
free(pcm8_mid); free(pcm8_side); free(dwt_mid); free(dwt_side);
free(pcm32_left); free(pcm32_right);
free(pcm32_mid); free(pcm32_side); free(dwt_mid); free(dwt_side);
free(quant_mid); free(quant_side); free(temp_buffer);
return write_ptr - output;