mirror of
https://github.com/curioustorvald/tsvm.git
synced 2026-03-11 05:31:51 +09:00
TAD: better bit allocation using statistics
This commit is contained in:
@@ -22,9 +22,8 @@ static const float TAD32_COEFF_SCALARS[] = {64.0f, 45.255f, 32.0f, 22.627f, 16.0
|
||||
static void dwt_dd4_forward_1d(float *data, int length);
|
||||
static void dwt_dd4_forward_multilevel(float *data, int length, int levels);
|
||||
static void ms_decorrelate_16(const float *left, const float *right, float *mid, float *side, size_t count);
|
||||
static void get_quantization_weights(int quality, int dwt_levels, float *weights);
|
||||
static int get_deadzone_threshold(int quality);
|
||||
static void quantize_dwt_coefficients(const float *coeffs, int16_t *quantized, size_t count, int quality, int apply_deadzone, int chunk_size, int dwt_levels, int *current_subband_index);
|
||||
static void get_quantization_weights(int dwt_levels, float *weights);
|
||||
static void quantize_dwt_coefficients(const float *coeffs, int16_t *quantized, size_t count, int apply_deadzone, int chunk_size, int dwt_levels, int quant_bits, int *current_subband_index);
|
||||
static size_t encode_sigmap_2bit(const int16_t *values, size_t count, uint8_t *output);
|
||||
|
||||
static inline float FCLAMP(float x, float min, float max) {
|
||||
@@ -220,45 +219,44 @@ static void compress_gamma(float *left, float *right, size_t count) {
|
||||
// Quantization with Frequency-Dependent Weighting
|
||||
//=============================================================================
|
||||
|
||||
static void get_quantization_weights(int quality, int dwt_levels, float *weights) {
|
||||
const float base_weights[16][16] = {
|
||||
/* 0*/{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f},
|
||||
/* 1*/{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f},
|
||||
/* 2*/{1.0f, 1.0f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
|
||||
/* 3*/{0.2f, 1.0f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
|
||||
/* 4*/{0.2f, 0.8f, 1.0f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
|
||||
/* 5*/{0.2f, 0.8f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
|
||||
/* 6*/{0.2f, 0.2f, 0.8f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
|
||||
/* 7*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
|
||||
/* 8*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
|
||||
/* 9*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
|
||||
/*10*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
|
||||
/*11*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
|
||||
/*12*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
|
||||
/*13*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f},
|
||||
/*14*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f},
|
||||
/*15*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f}
|
||||
};
|
||||
#define LAMBDA_FIXED 5.8f
|
||||
|
||||
float quality_scale = 1.0f * (1.0f + FCLAMP((5 - quality) * 0.5f, 0.0f, 1000.0f));
|
||||
|
||||
for (int i = 0; i < dwt_levels; i++) {
|
||||
weights[i] = 1.0f;//base_weights[dwt_levels][i] * quality_scale;
|
||||
// Lambda-based companding encoder (based on Laplacian distribution CDF)
|
||||
// val must be normalised to [-1,1]
|
||||
// Returns quantized index in range [-(2^quant_bits-1), +(2^quant_bits-1)]
|
||||
static int16_t lambda_companding(float val, int quant_bits) {
|
||||
// Handle zero
|
||||
if (fabsf(val) < 1e-9f) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int sign = (val < 0) ? -1 : 1;
|
||||
float abs_val = fabsf(val);
|
||||
|
||||
// Clamp to [0, 1]
|
||||
if (abs_val > 1.0f) abs_val = 1.0f;
|
||||
|
||||
// Maximum index for the given quant_bits
|
||||
int max_index = (1 << (quant_bits - 1)) - 1;
|
||||
|
||||
// Laplacian CDF for x >= 0: F(x) = 1 - 0.5 * exp(-λ*x)
|
||||
// Map to [0.5, 1.0] range (half of CDF for positive values)
|
||||
float cdf = 1.0f - 0.5f * expf(-LAMBDA_FIXED * abs_val);
|
||||
|
||||
// Map CDF from [0.5, 1.0] to [0, 1] for positive half
|
||||
float normalized_cdf = (cdf - 0.5f) * 2.0f;
|
||||
|
||||
// Quantize to index
|
||||
int index = (int)roundf(normalized_cdf * max_index);
|
||||
|
||||
// Clamp index to valid range [0, max_index]
|
||||
if (index < 0) index = 0;
|
||||
if (index > max_index) index = max_index;
|
||||
|
||||
return (int16_t)(sign * index);
|
||||
}
|
||||
|
||||
#define QUANT_STEPS 512.0f // 64 -> [-64..64] -> 7 bits for LL
|
||||
|
||||
static int get_deadzone_threshold(int quality) {
|
||||
const int thresholds[] = {0,0,0,0,0,0}; // Q0 to Q5
|
||||
return thresholds[quality];
|
||||
}
|
||||
|
||||
static void quantize_dwt_coefficients(const float *coeffs, int16_t *quantized, size_t count, int quality, int apply_deadzone, int chunk_size, int dwt_levels, int *current_subband_index) {
|
||||
float weights[16];
|
||||
get_quantization_weights(quality, dwt_levels, weights);
|
||||
int deadzone = apply_deadzone ? get_deadzone_threshold(quality) : 0;
|
||||
|
||||
static void quantize_dwt_coefficients(const float *coeffs, int16_t *quantized, size_t count, int apply_deadzone, int chunk_size, int dwt_levels, int quant_bits, int *current_subband_index) {
|
||||
int first_band_size = chunk_size >> dwt_levels;
|
||||
|
||||
int *sideband_starts = malloc((dwt_levels + 2) * sizeof(int));
|
||||
@@ -282,19 +280,8 @@ static void quantize_dwt_coefficients(const float *coeffs, int16_t *quantized, s
|
||||
current_subband_index[i] = sideband;
|
||||
}
|
||||
|
||||
int weight_idx = (sideband == 0) ? 0 : sideband - 1;
|
||||
if (weight_idx >= dwt_levels) weight_idx = dwt_levels - 1;
|
||||
|
||||
float weight = weights[weight_idx];
|
||||
float val = (coeffs[i] / TAD32_COEFF_SCALARS[sideband]) * (QUANT_STEPS * weight);
|
||||
// (coeffs[i] / TAD32_COEFF_SCALARS[sideband]) normalises coeffs to -1..1
|
||||
int16_t quant_val = (int16_t)roundf(val);
|
||||
|
||||
if (apply_deadzone && sideband >= dwt_levels - 1) {
|
||||
if (quant_val > -deadzone && quant_val < deadzone) {
|
||||
quant_val = 0;
|
||||
}
|
||||
}
|
||||
float val = (coeffs[i] / (TAD32_COEFF_SCALARS[sideband])); // val is normalised to [-1,1]
|
||||
int16_t quant_val = lambda_companding(val, quant_bits);
|
||||
|
||||
quantized[i] = quant_val;
|
||||
}
|
||||
@@ -302,25 +289,44 @@ static void quantize_dwt_coefficients(const float *coeffs, int16_t *quantized, s
|
||||
free(sideband_starts);
|
||||
}
|
||||
|
||||
// idea 1: power-of-two companding
|
||||
// for quant step 8:
|
||||
// Q -> Float
|
||||
// 0 -> 0
|
||||
// 1 -> 1/128
|
||||
// 2 -> 1/64
|
||||
// 3 -> 1/32
|
||||
// 4 -> 1/16
|
||||
// 5 -> 1/8
|
||||
// 6 -> 1/4
|
||||
// 7 -> 1/2
|
||||
// 8 -> 1/1
|
||||
// for -1 to -8, just invert the sign
|
||||
|
||||
|
||||
//=============================================================================
|
||||
// Significance Map Encoding
|
||||
//=============================================================================
|
||||
|
||||
static size_t encode_sigmap_2bit(const int16_t *values, size_t count, uint8_t *output) {
|
||||
size_t map_bytes = (count * 2 + 7) / 8;
|
||||
uint8_t *map = output;
|
||||
memset(map, 0, map_bytes);
|
||||
|
||||
uint8_t *write_ptr = output + map_bytes;
|
||||
int16_t *value_ptr = (int16_t*)write_ptr;
|
||||
uint32_t other_count = 0;
|
||||
|
||||
for (size_t i = 0; i < count; i++) {
|
||||
int16_t val = values[i];
|
||||
uint8_t code;
|
||||
|
||||
if (val == 0) code = 0; // 00
|
||||
else if (val == 1) code = 1; // 01
|
||||
else if (val == -1) code = 2; // 10
|
||||
else {
|
||||
code = 3; // 11
|
||||
value_ptr[other_count++] = val;
|
||||
}
|
||||
|
||||
size_t bit_pos = i * 2;
|
||||
size_t byte_idx = bit_pos / 8;
|
||||
size_t bit_offset = bit_pos % 8;
|
||||
|
||||
map[byte_idx] |= (code << bit_offset);
|
||||
if (bit_offset == 7 && byte_idx + 1 < map_bytes) {
|
||||
map[byte_idx + 1] |= (code >> 1);
|
||||
}
|
||||
}
|
||||
|
||||
return map_bytes + other_count * sizeof(int16_t);
|
||||
}
|
||||
|
||||
//=============================================================================
|
||||
// Coefficient Statistics
|
||||
//=============================================================================
|
||||
@@ -339,6 +345,7 @@ typedef struct {
|
||||
float median;
|
||||
float q3;
|
||||
float max;
|
||||
float lambda; // Laplacian distribution parameter (1/b, where b is scale)
|
||||
} CoeffStats;
|
||||
|
||||
typedef struct {
|
||||
@@ -410,6 +417,7 @@ static void accumulate_coefficients(const float *coeffs, int dwt_levels, int chu
|
||||
static void calculate_coeff_stats(const float *coeffs, size_t count, CoeffStats *stats) {
|
||||
if (count == 0) {
|
||||
stats->min = stats->q1 = stats->median = stats->q3 = stats->max = 0.0f;
|
||||
stats->lambda = 0.0f;
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -425,6 +433,16 @@ static void calculate_coeff_stats(const float *coeffs, size_t count, CoeffStats
|
||||
stats->q3 = sorted[(3 * count) / 4];
|
||||
|
||||
free(sorted);
|
||||
|
||||
// Estimate Laplacian distribution parameter λ = 1/b
|
||||
// For Laplacian centered at μ=0, MLE gives: b = mean(|x|)
|
||||
// Therefore: λ = 1/b = 1/mean(|x|)
|
||||
double sum_abs = 0.0;
|
||||
for (size_t i = 0; i < count; i++) {
|
||||
sum_abs += fabs(coeffs[i]);
|
||||
}
|
||||
double mean_abs = sum_abs / count;
|
||||
stats->lambda = (mean_abs > 1e-9) ? (1.0f / mean_abs) : 0.0f;
|
||||
}
|
||||
|
||||
#define HISTOGRAM_BINS 40
|
||||
@@ -492,9 +510,9 @@ void tad32_print_statistics(void) {
|
||||
|
||||
// Print Mid channel statistics
|
||||
fprintf(stderr, "\nMid Channel:\n");
|
||||
fprintf(stderr, "%-12s %10s %10s %10s %10s %10s %10s\n",
|
||||
"Subband", "Samples", "Min", "Q1", "Median", "Q3", "Max");
|
||||
fprintf(stderr, "--------------------------------------------------------------------------------\n");
|
||||
fprintf(stderr, "%-12s %10s %10s %10s %10s %10s %10s %10s\n",
|
||||
"Subband", "Samples", "Min", "Q1", "Median", "Q3", "Max", "Lambda");
|
||||
fprintf(stderr, "----------------------------------------------------------------------------------------\n");
|
||||
|
||||
for (int s = 0; s < num_subbands; s++) {
|
||||
CoeffStats stats;
|
||||
@@ -507,9 +525,9 @@ void tad32_print_statistics(void) {
|
||||
snprintf(band_name, sizeof(band_name), "H (L%d)", stats_dwt_levels - s + 1);
|
||||
}
|
||||
|
||||
fprintf(stderr, "%-12s %10zu %10.3f %10.3f %10.3f %10.3f %10.3f\n",
|
||||
fprintf(stderr, "%-12s %10zu %10.3f %10.3f %10.3f %10.3f %10.3f %10.3f\n",
|
||||
band_name, mid_accumulators[s].count,
|
||||
stats.min, stats.q1, stats.median, stats.q3, stats.max);
|
||||
stats.min, stats.q1, stats.median, stats.q3, stats.max, stats.lambda);
|
||||
}
|
||||
|
||||
// Print Mid channel histograms
|
||||
@@ -526,9 +544,9 @@ void tad32_print_statistics(void) {
|
||||
|
||||
// Print Side channel statistics
|
||||
fprintf(stderr, "\nSide Channel:\n");
|
||||
fprintf(stderr, "%-12s %10s %10s %10s %10s %10s %10s\n",
|
||||
"Subband", "Samples", "Min", "Q1", "Median", "Q3", "Max");
|
||||
fprintf(stderr, "--------------------------------------------------------------------------------\n");
|
||||
fprintf(stderr, "%-12s %10s %10s %10s %10s %10s %10s %10s\n",
|
||||
"Subband", "Samples", "Min", "Q1", "Median", "Q3", "Max", "Lambda");
|
||||
fprintf(stderr, "----------------------------------------------------------------------------------------\n");
|
||||
|
||||
for (int s = 0; s < num_subbands; s++) {
|
||||
CoeffStats stats;
|
||||
@@ -541,9 +559,9 @@ void tad32_print_statistics(void) {
|
||||
snprintf(band_name, sizeof(band_name), "H (L%d)", stats_dwt_levels - s + 1);
|
||||
}
|
||||
|
||||
fprintf(stderr, "%-12s %10zu %10.3f %10.3f %10.3f %10.3f %10.3f\n",
|
||||
fprintf(stderr, "%-12s %10zu %10.3f %10.3f %10.3f %10.3f %10.3f %10.3f\n",
|
||||
band_name, side_accumulators[s].count,
|
||||
stats.min, stats.q1, stats.median, stats.q3, stats.max);
|
||||
stats.min, stats.q1, stats.median, stats.q3, stats.max, stats.lambda);
|
||||
}
|
||||
|
||||
// Print Side channel histograms
|
||||
@@ -576,46 +594,12 @@ void tad32_free_statistics(void) {
|
||||
stats_initialized = 0;
|
||||
}
|
||||
|
||||
static size_t encode_sigmap_2bit(const int16_t *values, size_t count, uint8_t *output) {
|
||||
size_t map_bytes = (count * 2 + 7) / 8;
|
||||
uint8_t *map = output;
|
||||
memset(map, 0, map_bytes);
|
||||
|
||||
uint8_t *write_ptr = output + map_bytes;
|
||||
int16_t *value_ptr = (int16_t*)write_ptr;
|
||||
uint32_t other_count = 0;
|
||||
|
||||
for (size_t i = 0; i < count; i++) {
|
||||
int16_t val = values[i];
|
||||
uint8_t code;
|
||||
|
||||
if (val == 0) code = 0; // 00
|
||||
else if (val == 1) code = 1; // 01
|
||||
else if (val == -1) code = 2; // 10
|
||||
else {
|
||||
code = 3; // 11
|
||||
value_ptr[other_count++] = val;
|
||||
}
|
||||
|
||||
size_t bit_pos = i * 2;
|
||||
size_t byte_idx = bit_pos / 8;
|
||||
size_t bit_offset = bit_pos % 8;
|
||||
|
||||
map[byte_idx] |= (code << bit_offset);
|
||||
if (bit_offset == 7 && byte_idx + 1 < map_bytes) {
|
||||
map[byte_idx + 1] |= (code >> 1);
|
||||
}
|
||||
}
|
||||
|
||||
return map_bytes + other_count * sizeof(int16_t);
|
||||
}
|
||||
|
||||
//=============================================================================
|
||||
// Public API: Chunk Encoding
|
||||
//=============================================================================
|
||||
|
||||
size_t tad32_encode_chunk(const float *pcm32_stereo, size_t num_samples, int quality,
|
||||
int use_zstd, uint8_t *output) {
|
||||
size_t tad32_encode_chunk(const float *pcm32_stereo, size_t num_samples,
|
||||
int quant_bits, int use_zstd, uint8_t *output) {
|
||||
// Calculate DWT levels from chunk size
|
||||
int dwt_levels = calculate_dwt_levels(num_samples);
|
||||
if (dwt_levels < 0) {
|
||||
@@ -670,8 +654,8 @@ size_t tad32_encode_chunk(const float *pcm32_stereo, size_t num_samples, int qua
|
||||
}
|
||||
|
||||
// Step 4: Quantize with frequency-dependent weights and dead zone
|
||||
quantize_dwt_coefficients(dwt_mid, quant_mid, num_samples, quality, 1, num_samples, dwt_levels, NULL);
|
||||
quantize_dwt_coefficients(dwt_side, quant_side, num_samples, quality, 1, num_samples, dwt_levels, NULL);
|
||||
quantize_dwt_coefficients(dwt_mid, quant_mid, num_samples, 1, num_samples, dwt_levels, quant_bits, NULL);
|
||||
quantize_dwt_coefficients(dwt_side, quant_side, num_samples, 1, num_samples, dwt_levels, quant_bits, NULL);
|
||||
|
||||
// Step 5: Encode with 2-bit significance map (32-bit version)
|
||||
uint8_t *temp_buffer = malloc(num_samples * 4 * sizeof(int32_t));
|
||||
@@ -683,9 +667,13 @@ size_t tad32_encode_chunk(const float *pcm32_stereo, size_t num_samples, int qua
|
||||
// Step 6: Optional Zstd compression
|
||||
uint8_t *write_ptr = output;
|
||||
|
||||
// Write chunk header
|
||||
*((uint16_t*)write_ptr) = (uint16_t)num_samples;
|
||||
write_ptr += sizeof(uint16_t);
|
||||
|
||||
*write_ptr = (uint8_t)quant_bits;
|
||||
write_ptr += sizeof(uint8_t);
|
||||
|
||||
uint32_t *payload_size_ptr = (uint32_t*)write_ptr;
|
||||
write_ptr += sizeof(uint32_t);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user