TAD: bringing coeff weight back

This commit is contained in:
minjaesong
2025-10-29 01:47:14 +09:00
parent 86864c4b7a
commit f06f339d99
4 changed files with 93 additions and 99 deletions

View File

@@ -18,6 +18,22 @@
// Index 0 = LL band, Index 1-9 = H bands (L9 to L1)
static const float TAD32_COEFF_SCALARS[] = {64.0f, 45.255f, 32.0f, 22.627f, 16.0f, 11.314f, 8.0f, 5.657f, 4.0f, 2.828f};
// Base quantiser weight table (10 subbands: LL + 9 H bands)
// Linearly spaced from 1.0 (LL) to 2.0 (H9)
// These weights are multiplied by quantiser_scale during dequantization
static const float BASE_QUANTISER_WEIGHTS[] = {
1.0f, // LL (L9) - finest preservation
1.111f, // H (L9)
1.222f, // H (L8)
1.333f, // H (L7)
1.444f, // H (L6)
1.556f, // H (L5)
1.667f, // H (L4)
1.778f, // H (L3)
1.889f, // H (L2)
2.0f // H (L1) - coarsest quantization
};
#define TAD_DEFAULT_CHUNK_SIZE 32768
#define TAD_MIN_CHUNK_SIZE 1024
#define TAD_SAMPLE_RATE 32000
@@ -333,11 +349,11 @@ static void pcm32f_to_pcm8(const float *fleft, const float *fright, uint8_t *lef
//=============================================================================
#define LAMBDA_FIXED 5.0f
#define LAMBDA_FIXED 6.0f
// Lambda-based decompanding decoder (inverse of Laplacian CDF-based encoder)
// Converts quantized index back to normalized float in [-1, 1]
static float lambda_decompanding(int16_t quant_val, int max_index) {
static float lambda_decompanding(int8_t quant_val, int max_index) {
// Handle zero
if (quant_val == 0) {
return 0.0f;
@@ -366,7 +382,7 @@ static float lambda_decompanding(int16_t quant_val, int max_index) {
return sign * abs_val;
}
static void dequantize_dwt_coefficients(const int16_t *quantized, float *coeffs, size_t count, int chunk_size, int dwt_levels, int max_index) {
static void dequantize_dwt_coefficients(const int8_t *quantized, float *coeffs, size_t count, int chunk_size, int dwt_levels, int max_index, float quantiser_scale) {
// Calculate sideband boundaries dynamically
int first_band_size = chunk_size >> dwt_levels;
@@ -390,63 +406,14 @@ static void dequantize_dwt_coefficients(const int16_t *quantized, float *coeffs,
// Decode using lambda companding
float normalized_val = lambda_decompanding(quantized[i], max_index);
// Denormalize using the subband scalar
coeffs[i] = normalized_val * TAD32_COEFF_SCALARS[sideband];
// Denormalize using the subband scalar and apply base weight + quantiser scaling
float weight = BASE_QUANTISER_WEIGHTS[sideband] * quantiser_scale;
coeffs[i] = normalized_val * TAD32_COEFF_SCALARS[sideband] * weight;
}
free(sideband_starts);
}
//=============================================================================
// Bitplane Decoding with Delta Prediction
//=============================================================================
// Pure bitplane decoding with delta prediction: each coefficient uses exactly (quant_bits + 1) bits
// Bit layout: 1 sign bit + quant_bits magnitude bits
// Sign bit: 0 = positive/zero, 1 = negative
// Magnitude: unsigned value [0, 2^quant_bits - 1]
// Delta prediction: plane[i] ^= plane[i-1] (reversed by same operation)
static size_t decode_bitplanes(const uint8_t *input, int16_t *values, size_t count, int max_index) {
int bits_per_coeff = ((int)ceilf(log2f(max_index))) + 1; // 1 sign bit + quant_bits magnitude bits
size_t plane_bytes = (count + 7) / 8; // Bytes needed for one bitplane
size_t input_bytes = plane_bytes * bits_per_coeff;
// Allocate temporary bitplanes
uint8_t **bitplanes = malloc(bits_per_coeff * sizeof(uint8_t*));
for (int plane = 0; plane < bits_per_coeff; plane++) {
bitplanes[plane] = malloc(plane_bytes);
memcpy(bitplanes[plane], input + (plane * plane_bytes), plane_bytes);
}
// Reconstruct coefficients from bitplanes
for (size_t i = 0; i < count; i++) {
size_t byte_idx = i / 8;
size_t bit_offset = i % 8;
// Read sign bit (plane 0)
uint8_t sign_bit = (bitplanes[0][byte_idx] >> bit_offset) & 0x01;
// Read magnitude bits (planes 1 to quant_bits)
uint16_t magnitude = 0;
for (int b = 0; b < bits_per_coeff - 1; b++) {
if (bitplanes[b + 1][byte_idx] & (1 << bit_offset)) {
magnitude |= (1 << b);
}
}
// Reconstruct signed value
values[i] = sign_bit ? -(int16_t)magnitude : (int16_t)magnitude;
}
// Free temporary bitplanes
for (int plane = 0; plane < bits_per_coeff; plane++) {
free(bitplanes[plane]);
}
free(bitplanes);
return input_bytes;
}
//=============================================================================
// Chunk Decoding
//=============================================================================
@@ -477,7 +444,7 @@ static int decode_chunk(const uint8_t *input, size_t input_size, uint8_t *pcmu8_
uint8_t *decompressed = NULL;
// Estimate decompressed size (generous upper bound)
size_t decompressed_size = sample_count * 4 * sizeof(int16_t);
size_t decompressed_size = sample_count * 4 * sizeof(int8_t);
decompressed = malloc(decompressed_size);
size_t actual_size = ZSTD_decompress(decompressed, decompressed_size, read_ptr, payload_size);
@@ -488,15 +455,13 @@ static int decode_chunk(const uint8_t *input, size_t input_size, uint8_t *pcmu8_
return -1;
}
payload = decompressed;
read_ptr += payload_size;
*bytes_consumed = read_ptr - input;
*samples_decoded = sample_count;
// Allocate working buffers
int16_t *quant_mid = malloc(sample_count * sizeof(int16_t));
int16_t *quant_side = malloc(sample_count * sizeof(int16_t));
int8_t *quant_mid = malloc(sample_count * sizeof(int8_t));
int8_t *quant_side = malloc(sample_count * sizeof(int8_t));
float *dwt_mid = malloc(sample_count * sizeof(float));
float *dwt_side = malloc(sample_count * sizeof(float));
float *pcm32_left = malloc(sample_count * sizeof(float));
@@ -504,16 +469,16 @@ static int decode_chunk(const uint8_t *input, size_t input_size, uint8_t *pcmu8_
uint8_t *pcm8_left = malloc(sample_count * sizeof(uint8_t));
uint8_t *pcm8_right = malloc(sample_count * sizeof(uint8_t));
// Decode bitplanes
const uint8_t *payload_ptr = payload;
size_t mid_bytes, side_bytes;
// Separate Mid/Side
memcpy(quant_mid, decompressed, sample_count);
memcpy(quant_side, decompressed + sample_count, sample_count);
mid_bytes = decode_bitplanes(payload_ptr, quant_mid, sample_count, max_index);
side_bytes = decode_bitplanes(payload_ptr + mid_bytes, quant_side, sample_count, max_index);
// Dequantize
dequantize_dwt_coefficients(quant_mid, dwt_mid, sample_count, sample_count, dwt_levels, max_index);
dequantize_dwt_coefficients(quant_side, dwt_side, sample_count, sample_count, dwt_levels, max_index);
// Dequantize with quantiser scaling
// Use quantiser_scale = 1.0f for baseline (must match encoder)
float quantiser_scale = 1.0f;
dequantize_dwt_coefficients(quant_mid, dwt_mid, sample_count, sample_count, dwt_levels, max_index, quantiser_scale);
dequantize_dwt_coefficients(quant_side, dwt_side, sample_count, sample_count, dwt_levels, max_index, quantiser_scale);
// Inverse DWT
dwt_haar_inverse_multilevel(dwt_mid, sample_count, dwt_levels);