TAD: better bit allocation using statistics

This commit is contained in:
minjaesong
2025-10-26 18:16:28 +09:00
parent 9fcb7fc95c
commit 370d511f44
5 changed files with 167 additions and 177 deletions

View File

@@ -22,9 +22,8 @@ static const float TAD32_COEFF_SCALARS[] = {64.0f, 45.255f, 32.0f, 22.627f, 16.0
static void dwt_dd4_forward_1d(float *data, int length);
static void dwt_dd4_forward_multilevel(float *data, int length, int levels);
static void ms_decorrelate_16(const float *left, const float *right, float *mid, float *side, size_t count);
static void get_quantization_weights(int quality, int dwt_levels, float *weights);
static int get_deadzone_threshold(int quality);
static void quantize_dwt_coefficients(const float *coeffs, int16_t *quantized, size_t count, int quality, int apply_deadzone, int chunk_size, int dwt_levels, int *current_subband_index);
static void get_quantization_weights(int dwt_levels, float *weights);
static void quantize_dwt_coefficients(const float *coeffs, int16_t *quantized, size_t count, int apply_deadzone, int chunk_size, int dwt_levels, int quant_bits, int *current_subband_index);
static size_t encode_sigmap_2bit(const int16_t *values, size_t count, uint8_t *output);
static inline float FCLAMP(float x, float min, float max) {
@@ -220,45 +219,44 @@ static void compress_gamma(float *left, float *right, size_t count) {
// Quantization with Frequency-Dependent Weighting
//=============================================================================
static void get_quantization_weights(int quality, int dwt_levels, float *weights) {
const float base_weights[16][16] = {
/* 0*/{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f},
/* 1*/{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f},
/* 2*/{1.0f, 1.0f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
/* 3*/{0.2f, 1.0f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
/* 4*/{0.2f, 0.8f, 1.0f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
/* 5*/{0.2f, 0.8f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
/* 6*/{0.2f, 0.2f, 0.8f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
/* 7*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
/* 8*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
/* 9*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
/*10*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
/*11*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
/*12*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
/*13*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f},
/*14*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f},
/*15*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f}
};
#define LAMBDA_FIXED 5.8f
float quality_scale = 1.0f * (1.0f + FCLAMP((5 - quality) * 0.5f, 0.0f, 1000.0f));
for (int i = 0; i < dwt_levels; i++) {
weights[i] = 1.0f;//base_weights[dwt_levels][i] * quality_scale;
// Lambda-based companding encoder (based on Laplacian distribution CDF)
// val must be normalised to [-1,1]
// Returns quantized index in range [-(2^quant_bits-1), +(2^quant_bits-1)]
static int16_t lambda_companding(float val, int quant_bits) {
// Handle zero
if (fabsf(val) < 1e-9f) {
return 0;
}
int sign = (val < 0) ? -1 : 1;
float abs_val = fabsf(val);
// Clamp to [0, 1]
if (abs_val > 1.0f) abs_val = 1.0f;
// Maximum index for the given quant_bits
int max_index = (1 << (quant_bits - 1)) - 1;
// Laplacian CDF for x >= 0: F(x) = 1 - 0.5 * exp(-λ*x)
// Map to [0.5, 1.0] range (half of CDF for positive values)
float cdf = 1.0f - 0.5f * expf(-LAMBDA_FIXED * abs_val);
// Map CDF from [0.5, 1.0] to [0, 1] for positive half
float normalized_cdf = (cdf - 0.5f) * 2.0f;
// Quantize to index
int index = (int)roundf(normalized_cdf * max_index);
// Clamp index to valid range [0, max_index]
if (index < 0) index = 0;
if (index > max_index) index = max_index;
return (int16_t)(sign * index);
}
#define QUANT_STEPS 512.0f // 64 -> [-64..64] -> 7 bits for LL
static int get_deadzone_threshold(int quality) {
const int thresholds[] = {0,0,0,0,0,0}; // Q0 to Q5
return thresholds[quality];
}
static void quantize_dwt_coefficients(const float *coeffs, int16_t *quantized, size_t count, int quality, int apply_deadzone, int chunk_size, int dwt_levels, int *current_subband_index) {
float weights[16];
get_quantization_weights(quality, dwt_levels, weights);
int deadzone = apply_deadzone ? get_deadzone_threshold(quality) : 0;
static void quantize_dwt_coefficients(const float *coeffs, int16_t *quantized, size_t count, int apply_deadzone, int chunk_size, int dwt_levels, int quant_bits, int *current_subband_index) {
int first_band_size = chunk_size >> dwt_levels;
int *sideband_starts = malloc((dwt_levels + 2) * sizeof(int));
@@ -282,19 +280,8 @@ static void quantize_dwt_coefficients(const float *coeffs, int16_t *quantized, s
current_subband_index[i] = sideband;
}
int weight_idx = (sideband == 0) ? 0 : sideband - 1;
if (weight_idx >= dwt_levels) weight_idx = dwt_levels - 1;
float weight = weights[weight_idx];
float val = (coeffs[i] / TAD32_COEFF_SCALARS[sideband]) * (QUANT_STEPS * weight);
// (coeffs[i] / TAD32_COEFF_SCALARS[sideband]) normalises coeffs to -1..1
int16_t quant_val = (int16_t)roundf(val);
if (apply_deadzone && sideband >= dwt_levels - 1) {
if (quant_val > -deadzone && quant_val < deadzone) {
quant_val = 0;
}
}
float val = (coeffs[i] / (TAD32_COEFF_SCALARS[sideband])); // val is normalised to [-1,1]
int16_t quant_val = lambda_companding(val, quant_bits);
quantized[i] = quant_val;
}
@@ -302,25 +289,44 @@ static void quantize_dwt_coefficients(const float *coeffs, int16_t *quantized, s
free(sideband_starts);
}
// idea 1: power-of-two companding
// for quant step 8:
// Q -> Float
// 0 -> 0
// 1 -> 1/128
// 2 -> 1/64
// 3 -> 1/32
// 4 -> 1/16
// 5 -> 1/8
// 6 -> 1/4
// 7 -> 1/2
// 8 -> 1/1
// for -1 to -8, just invert the sign
//=============================================================================
// Significance Map Encoding
//=============================================================================
static size_t encode_sigmap_2bit(const int16_t *values, size_t count, uint8_t *output) {
size_t map_bytes = (count * 2 + 7) / 8;
uint8_t *map = output;
memset(map, 0, map_bytes);
uint8_t *write_ptr = output + map_bytes;
int16_t *value_ptr = (int16_t*)write_ptr;
uint32_t other_count = 0;
for (size_t i = 0; i < count; i++) {
int16_t val = values[i];
uint8_t code;
if (val == 0) code = 0; // 00
else if (val == 1) code = 1; // 01
else if (val == -1) code = 2; // 10
else {
code = 3; // 11
value_ptr[other_count++] = val;
}
size_t bit_pos = i * 2;
size_t byte_idx = bit_pos / 8;
size_t bit_offset = bit_pos % 8;
map[byte_idx] |= (code << bit_offset);
if (bit_offset == 7 && byte_idx + 1 < map_bytes) {
map[byte_idx + 1] |= (code >> 1);
}
}
return map_bytes + other_count * sizeof(int16_t);
}
//=============================================================================
// Coefficient Statistics
//=============================================================================
@@ -339,6 +345,7 @@ typedef struct {
float median;
float q3;
float max;
float lambda; // Laplacian distribution parameter (1/b, where b is scale)
} CoeffStats;
typedef struct {
@@ -410,6 +417,7 @@ static void accumulate_coefficients(const float *coeffs, int dwt_levels, int chu
static void calculate_coeff_stats(const float *coeffs, size_t count, CoeffStats *stats) {
if (count == 0) {
stats->min = stats->q1 = stats->median = stats->q3 = stats->max = 0.0f;
stats->lambda = 0.0f;
return;
}
@@ -425,6 +433,16 @@ static void calculate_coeff_stats(const float *coeffs, size_t count, CoeffStats
stats->q3 = sorted[(3 * count) / 4];
free(sorted);
// Estimate Laplacian distribution parameter λ = 1/b
// For Laplacian centered at μ=0, MLE gives: b = mean(|x|)
// Therefore: λ = 1/b = 1/mean(|x|)
double sum_abs = 0.0;
for (size_t i = 0; i < count; i++) {
sum_abs += fabs(coeffs[i]);
}
double mean_abs = sum_abs / count;
stats->lambda = (mean_abs > 1e-9) ? (1.0f / mean_abs) : 0.0f;
}
#define HISTOGRAM_BINS 40
@@ -492,9 +510,9 @@ void tad32_print_statistics(void) {
// Print Mid channel statistics
fprintf(stderr, "\nMid Channel:\n");
fprintf(stderr, "%-12s %10s %10s %10s %10s %10s %10s\n",
"Subband", "Samples", "Min", "Q1", "Median", "Q3", "Max");
fprintf(stderr, "--------------------------------------------------------------------------------\n");
fprintf(stderr, "%-12s %10s %10s %10s %10s %10s %10s %10s\n",
"Subband", "Samples", "Min", "Q1", "Median", "Q3", "Max", "Lambda");
fprintf(stderr, "----------------------------------------------------------------------------------------\n");
for (int s = 0; s < num_subbands; s++) {
CoeffStats stats;
@@ -507,9 +525,9 @@ void tad32_print_statistics(void) {
snprintf(band_name, sizeof(band_name), "H (L%d)", stats_dwt_levels - s + 1);
}
fprintf(stderr, "%-12s %10zu %10.3f %10.3f %10.3f %10.3f %10.3f\n",
fprintf(stderr, "%-12s %10zu %10.3f %10.3f %10.3f %10.3f %10.3f %10.3f\n",
band_name, mid_accumulators[s].count,
stats.min, stats.q1, stats.median, stats.q3, stats.max);
stats.min, stats.q1, stats.median, stats.q3, stats.max, stats.lambda);
}
// Print Mid channel histograms
@@ -526,9 +544,9 @@ void tad32_print_statistics(void) {
// Print Side channel statistics
fprintf(stderr, "\nSide Channel:\n");
fprintf(stderr, "%-12s %10s %10s %10s %10s %10s %10s\n",
"Subband", "Samples", "Min", "Q1", "Median", "Q3", "Max");
fprintf(stderr, "--------------------------------------------------------------------------------\n");
fprintf(stderr, "%-12s %10s %10s %10s %10s %10s %10s %10s\n",
"Subband", "Samples", "Min", "Q1", "Median", "Q3", "Max", "Lambda");
fprintf(stderr, "----------------------------------------------------------------------------------------\n");
for (int s = 0; s < num_subbands; s++) {
CoeffStats stats;
@@ -541,9 +559,9 @@ void tad32_print_statistics(void) {
snprintf(band_name, sizeof(band_name), "H (L%d)", stats_dwt_levels - s + 1);
}
fprintf(stderr, "%-12s %10zu %10.3f %10.3f %10.3f %10.3f %10.3f\n",
fprintf(stderr, "%-12s %10zu %10.3f %10.3f %10.3f %10.3f %10.3f %10.3f\n",
band_name, side_accumulators[s].count,
stats.min, stats.q1, stats.median, stats.q3, stats.max);
stats.min, stats.q1, stats.median, stats.q3, stats.max, stats.lambda);
}
// Print Side channel histograms
@@ -576,46 +594,12 @@ void tad32_free_statistics(void) {
stats_initialized = 0;
}
static size_t encode_sigmap_2bit(const int16_t *values, size_t count, uint8_t *output) {
size_t map_bytes = (count * 2 + 7) / 8;
uint8_t *map = output;
memset(map, 0, map_bytes);
uint8_t *write_ptr = output + map_bytes;
int16_t *value_ptr = (int16_t*)write_ptr;
uint32_t other_count = 0;
for (size_t i = 0; i < count; i++) {
int16_t val = values[i];
uint8_t code;
if (val == 0) code = 0; // 00
else if (val == 1) code = 1; // 01
else if (val == -1) code = 2; // 10
else {
code = 3; // 11
value_ptr[other_count++] = val;
}
size_t bit_pos = i * 2;
size_t byte_idx = bit_pos / 8;
size_t bit_offset = bit_pos % 8;
map[byte_idx] |= (code << bit_offset);
if (bit_offset == 7 && byte_idx + 1 < map_bytes) {
map[byte_idx + 1] |= (code >> 1);
}
}
return map_bytes + other_count * sizeof(int16_t);
}
//=============================================================================
// Public API: Chunk Encoding
//=============================================================================
size_t tad32_encode_chunk(const float *pcm32_stereo, size_t num_samples, int quality,
int use_zstd, uint8_t *output) {
size_t tad32_encode_chunk(const float *pcm32_stereo, size_t num_samples,
int quant_bits, int use_zstd, uint8_t *output) {
// Calculate DWT levels from chunk size
int dwt_levels = calculate_dwt_levels(num_samples);
if (dwt_levels < 0) {
@@ -670,8 +654,8 @@ size_t tad32_encode_chunk(const float *pcm32_stereo, size_t num_samples, int qua
}
// Step 4: Quantize with frequency-dependent weights and dead zone
quantize_dwt_coefficients(dwt_mid, quant_mid, num_samples, quality, 1, num_samples, dwt_levels, NULL);
quantize_dwt_coefficients(dwt_side, quant_side, num_samples, quality, 1, num_samples, dwt_levels, NULL);
quantize_dwt_coefficients(dwt_mid, quant_mid, num_samples, 1, num_samples, dwt_levels, quant_bits, NULL);
quantize_dwt_coefficients(dwt_side, quant_side, num_samples, 1, num_samples, dwt_levels, quant_bits, NULL);
// Step 5: Encode with 2-bit significance map (32-bit version)
uint8_t *temp_buffer = malloc(num_samples * 4 * sizeof(int32_t));
@@ -683,9 +667,13 @@ size_t tad32_encode_chunk(const float *pcm32_stereo, size_t num_samples, int qua
// Step 6: Optional Zstd compression
uint8_t *write_ptr = output;
// Write chunk header
*((uint16_t*)write_ptr) = (uint16_t)num_samples;
write_ptr += sizeof(uint16_t);
*write_ptr = (uint8_t)quant_bits;
write_ptr += sizeof(uint8_t);
uint32_t *payload_size_ptr = (uint32_t*)write_ptr;
write_ptr += sizeof(uint32_t);