TAD: now processing entirely in float

This commit is contained in:
minjaesong
2025-10-24 05:31:38 +09:00
parent a9319fd812
commit 9dc71095a0
5 changed files with 139 additions and 319 deletions

View File

@@ -12,6 +12,7 @@
#define DECODER_VENDOR_STRING "Decoder-TAD 20251023"
// TAD format constants (must match encoder)
#define TAD_COEFF_SCALAR 1024.0f
#define TAD_DEFAULT_CHUNK_SIZE 32768
#define TAD_MIN_CHUNK_SIZE 1024
#define TAD_SAMPLE_RATE 32000
@@ -148,22 +149,58 @@ static void dwt_haar_inverse_multilevel(float *data, int length, int levels) {
// M/S Stereo Correlation (inverse of decorrelation)
//=============================================================================
static void ms_correlate(const int8_t *mid, const int8_t *side, uint8_t *left, uint8_t *right, size_t count) {
// Uniform random in [0, 1)
static inline float frand01(void) {
return (float)rand() / ((float)RAND_MAX + 1.0f);
}
// TPDF noise in [-1, +1)
static inline float tpdf1(void) {
return (frand01() - frand01());
}
static void ms_correlate(const float *mid, const float *side, uint8_t *left, uint8_t *right, size_t count, float dither_error[2][2]) {
const float b1 = 1.5f; // 1st feedback coefficient
const float b2 = -0.75f; // 2nd feedback coefficient
const float scale = 127.5f;
const float bias = 128.0f;
for (size_t i = 0; i < count; i++) {
// L = M + S, R = M - S
int32_t m = mid[i];
int32_t s = side[i];
int32_t l = m + s;
int32_t r = m - s;
// Decode M/S → L/R
float m = mid[i];
float s = side[i];
float l = FCLAMP(m + s, -1.0f, 1.0f);
float r = FCLAMP(m - s, -1.0f, 1.0f);
// Clamp to [-128, 127] then convert to unsigned [0, 255]
if (l < -128) l = -128;
if (l > 127) l = 127;
if (r < -128) r = -128;
if (r > 127) r = 127;
// --- LEFT channel ---
float feedbackL = b1 * dither_error[0][0] + b2 * dither_error[0][1];
float ditherL = 0.5f * tpdf1(); // ±0.5 LSB TPDF
float shapedL = l + feedbackL + ditherL / scale;
shapedL = FCLAMP(shapedL, -1.0f, 1.0f);
left[i] = (uint8_t)(l + 128);
right[i] = (uint8_t)(r + 128);
int qL = (int)lrintf(shapedL * scale);
if (qL < -128) qL = -128;
else if (qL > 127) qL = 127;
left[i] = (uint8_t)(qL + bias);
float qerrL = shapedL - (float)qL / scale;
dither_error[0][1] = dither_error[0][0]; // shift history
dither_error[0][0] = qerrL;
// --- RIGHT channel ---
float feedbackR = b1 * dither_error[1][0] + b2 * dither_error[1][1];
float ditherR = 0.5f * tpdf1();
float shapedR = r + feedbackR + ditherR / scale;
shapedR = FCLAMP(shapedR, -1.0f, 1.0f);
int qR = (int)lrintf(shapedR * scale);
if (qR < -128) qR = -128;
else if (qR > 127) qR = 127;
right[i] = (uint8_t)(qR + bias);
float qerrR = shapedR - (float)qR / scale;
dither_error[1][1] = dither_error[1][0];
dither_error[1][0] = qerrR;
}
}
@@ -188,11 +225,10 @@ static void get_quantization_weights(int quality, int dwt_levels, float *weights
/*12*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f},
/*13*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f},
/*14*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f},
/*15*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f},
/*16*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f}
/*15*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f}
};
float quality_scale = 1.0f + FCLAMP((3 - quality) * 0.5f, 0.0f, 1000.0f);
float quality_scale = 4.0f + FCLAMP((3 - quality) * 0.5f, 0.0f, 1000.0f);
for (int i = 0; i < dwt_levels; i++) {
weights[i] = FCLAMP(base_weights[dwt_levels][i] * quality_scale, 1.0f, 1000.0f);
@@ -227,7 +263,7 @@ static void dequantize_dwt_coefficients(const int16_t *quantized, float *coeffs,
if (weight_idx >= dwt_levels) weight_idx = dwt_levels - 1;
float weight = weights[weight_idx];
coeffs[i] = (float)quantized[i] * weight;
coeffs[i] = (float)quantized[i] * weight / TAD_COEFF_SCALAR;
}
free(sideband_starts);
@@ -237,29 +273,6 @@ static void dequantize_dwt_coefficients(const int16_t *quantized, float *coeffs,
// Significance Map Decoding
//=============================================================================
static size_t decode_sigmap_1bit(const uint8_t *input, int16_t *values, size_t count) {
size_t map_bytes = (count + 7) / 8;
const uint8_t *map = input;
const uint8_t *read_ptr = input + map_bytes;
uint32_t nonzero_count = *((const uint32_t*)read_ptr);
read_ptr += sizeof(uint32_t);
const int16_t *value_ptr = (const int16_t*)read_ptr;
uint32_t value_idx = 0;
// Reconstruct values
for (size_t i = 0; i < count; i++) {
if (map[i / 8] & (1 << (i % 8))) {
values[i] = value_ptr[value_idx++];
} else {
values[i] = 0;
}
}
return map_bytes + sizeof(uint32_t) + nonzero_count * sizeof(int16_t);
}
static size_t decode_sigmap_2bit(const uint8_t *input, int16_t *values, size_t count) {
size_t map_bytes = (count * 2 + 7) / 8;
const uint8_t *map = input;
@@ -291,48 +304,6 @@ static size_t decode_sigmap_2bit(const uint8_t *input, int16_t *values, size_t c
return map_bytes + other_idx * sizeof(int16_t);
}
static size_t decode_sigmap_rle(const uint8_t *input, int16_t *values, size_t count) {
const uint8_t *read_ptr = input;
uint32_t run_count = *((const uint32_t*)read_ptr);
read_ptr += sizeof(uint32_t);
size_t value_idx = 0;
for (uint32_t run = 0; run < run_count; run++) {
// Decode zero run length (varint)
uint32_t zero_run = 0;
int shift = 0;
uint8_t byte;
do {
byte = *read_ptr++;
zero_run |= ((uint32_t)(byte & 0x7F) << shift);
shift += 7;
} while (byte & 0x80);
// Fill zeros
for (uint32_t i = 0; i < zero_run && value_idx < count; i++) {
values[value_idx++] = 0;
}
// Read non-zero value
int16_t val = *((const int16_t*)read_ptr);
read_ptr += sizeof(int16_t);
if (value_idx < count && val != 0) {
values[value_idx++] = val;
}
}
// Fill remaining with zeros
while (value_idx < count) {
values[value_idx++] = 0;
}
return read_ptr - input;
}
//=============================================================================
// Chunk Decoding
//=============================================================================
@@ -381,8 +352,6 @@ static int decode_chunk(const uint8_t *input, size_t input_size, uint8_t *pcmu8_
int16_t *quant_side = malloc(sample_count * sizeof(int16_t));
float *dwt_mid = malloc(sample_count * sizeof(float));
float *dwt_side = malloc(sample_count * sizeof(float));
int8_t *pcm8_mid = malloc(sample_count * sizeof(int8_t));
int8_t *pcm8_side = malloc(sample_count * sizeof(int8_t));
uint8_t *pcm8_left = malloc(sample_count * sizeof(uint8_t));
uint8_t *pcm8_right = malloc(sample_count * sizeof(uint8_t));
@@ -401,23 +370,10 @@ static int decode_chunk(const uint8_t *input, size_t input_size, uint8_t *pcmu8_
dwt_haar_inverse_multilevel(dwt_mid, sample_count, dwt_levels);
dwt_haar_inverse_multilevel(dwt_side, sample_count, dwt_levels);
// Convert to signed PCM8
for (size_t i = 0; i < sample_count; i++) {
float m = dwt_mid[i];
float s = dwt_side[i];
// Clamp and round
if (m < -128.0f) m = -128.0f;
if (m > 127.0f) m = 127.0f;
if (s < -128.0f) s = -128.0f;
if (s > 127.0f) s = 127.0f;
pcm8_mid[i] = (int8_t)roundf(m);
pcm8_side[i] = (int8_t)roundf(s);
}
float err[2][2] = {{0,0},{0,0}};
// M/S to L/R correlation
ms_correlate(pcm8_mid, pcm8_side, pcm8_left, pcm8_right, sample_count);
ms_correlate(dwt_mid, dwt_side, pcm8_left, pcm8_right, sample_count, err);
// Interleave stereo output (PCMu8)
for (size_t i = 0; i < sample_count; i++) {
@@ -427,7 +383,7 @@ static int decode_chunk(const uint8_t *input, size_t input_size, uint8_t *pcmu8_
// Cleanup
free(quant_mid); free(quant_side); free(dwt_mid); free(dwt_side);
free(pcm8_mid); free(pcm8_side); free(pcm8_left); free(pcm8_right);
free(pcm8_left); free(pcm8_right);
if (decompressed) free(decompressed);
return 0;
@@ -442,7 +398,7 @@ static void print_usage(const char *prog_name) {
printf("Options:\n");
printf(" -i <file> Input TAD file\n");
printf(" -o <file> Output PCMu8 file (raw 8-bit unsigned stereo @ 32kHz)\n");
printf(" -q <0-5> Quality level used during encoding (default: 2)\n");
printf(" -q <0-5> Quality level used during encoding (default: 3)\n");
printf(" -v Verbose output\n");
printf(" -h, --help Show this help\n");
printf("\nVersion: %s\n", DECODER_VENDOR_STRING);
@@ -453,7 +409,7 @@ static void print_usage(const char *prog_name) {
int main(int argc, char *argv[]) {
char *input_file = NULL;
char *output_file = NULL;
int quality = 2; // Must match encoder quality
int quality = 3; // Must match encoder quality
int verbose = 0;
int opt;