TAD: more wip

This commit is contained in:
minjaesong
2025-10-26 02:30:51 +09:00
parent 52f25f7d04
commit 9fcb7fc95c
4 changed files with 563 additions and 37 deletions

View File

@@ -9,10 +9,15 @@
#include <zstd.h>
#include <getopt.h>
#define DECODER_VENDOR_STRING "Decoder-TAD 20251023"
#define DECODER_VENDOR_STRING "Decoder-TAD 20251026"
// TAD format constants (must match encoder)
#define TAD_COEFF_SCALAR 1024.0f
#undef TAD32_COEFF_SCALARS
// Coefficient scalars for each subband (CDF 9/7 with 9 decomposition levels)
// Index 0 = LL band, Index 1-9 = H bands (L9 to L1)
static const float TAD32_COEFF_SCALARS[] = {64.0f, 45.255f, 32.0f, 22.627f, 16.0f, 11.314f, 8.0f, 5.657f, 4.0f, 2.828f};
#define TAD_DEFAULT_CHUNK_SIZE 32768
#define TAD_MIN_CHUNK_SIZE 1024
#define TAD_SAMPLE_RATE 32000
@@ -33,7 +38,7 @@ static inline float FCLAMP(float x, float min, float max) {
// Calculate DWT levels from chunk size (must be power of 2, >= 1024)
static int calculate_dwt_levels(int chunk_size) {
if (chunk_size < TAD_MIN_CHUNK_SIZE) {
/*if (chunk_size < TAD_MIN_CHUNK_SIZE) {
fprintf(stderr, "Error: Chunk size %d is below minimum %d\n", chunk_size, TAD_MIN_CHUNK_SIZE);
return -1;
}
@@ -45,7 +50,8 @@ static int calculate_dwt_levels(int chunk_size) {
size >>= 1;
levels++;
}
return levels - 2;
return levels - 2;*/
return 9;
}
//=============================================================================
@@ -71,6 +77,91 @@ static void dwt_haar_inverse_1d(float *data, int length) {
free(temp);
}
// 9/7 inverse DWT (from TSVM Kotlin code)
static void dwt_97_inverse_1d(float *data, int length) {
if (length < 2) return;
float *temp = malloc(length * sizeof(float));
int half = (length + 1) / 2;
// Split into low and high frequency components (matching TSVM layout)
for (int i = 0; i < half; i++) {
temp[i] = data[i]; // Low-pass coefficients (first half)
}
for (int i = 0; i < length / 2; i++) {
if (half + i < length) {
temp[half + i] = data[half + i]; // High-pass coefficients (second half)
}
}
// 9/7 inverse lifting coefficients from TSVM
const float alpha = -1.586134342f;
const float beta = -0.052980118f;
const float gamma = 0.882911076f;
const float delta = 0.443506852f;
const float K = 1.230174105f;
// Step 1: Undo scaling
for (int i = 0; i < half; i++) {
temp[i] /= K; // Low-pass coefficients
}
for (int i = 0; i < length / 2; i++) {
if (half + i < length) {
temp[half + i] *= K; // High-pass coefficients
}
}
// Step 2: Undo δ update
for (int i = 0; i < half; i++) {
float d_curr = (half + i < length) ? temp[half + i] : 0.0f;
float d_prev = (i > 0 && half + i - 1 < length) ? temp[half + i - 1] : d_curr;
temp[i] -= delta * (d_curr + d_prev);
}
// Step 3: Undo γ predict
for (int i = 0; i < length / 2; i++) {
if (half + i < length) {
float s_curr = temp[i];
float s_next = (i + 1 < half) ? temp[i + 1] : s_curr;
temp[half + i] -= gamma * (s_curr + s_next);
}
}
// Step 4: Undo β update
for (int i = 0; i < half; i++) {
float d_curr = (half + i < length) ? temp[half + i] : 0.0f;
float d_prev = (i > 0 && half + i - 1 < length) ? temp[half + i - 1] : d_curr;
temp[i] -= beta * (d_curr + d_prev);
}
// Step 5: Undo α predict
for (int i = 0; i < length / 2; i++) {
if (half + i < length) {
float s_curr = temp[i];
float s_next = (i + 1 < half) ? temp[i + 1] : s_curr;
temp[half + i] -= alpha * (s_curr + s_next);
}
}
// Reconstruction - interleave low and high pass
for (int i = 0; i < length; i++) {
if (i % 2 == 0) {
// Even positions: low-pass coefficients
data[i] = temp[i / 2];
} else {
// Odd positions: high-pass coefficients
int idx = i / 2;
if (half + idx < length) {
data[i] = temp[half + idx];
} else {
data[i] = 0.0f;
}
}
}
free(temp);
}
// Inverse 1D transform of Four-point interpolating Deslauriers-Dubuc (DD-4)
static void dwt_dd4_inverse_1d(float *data, int length) {
if (length < 2) return;
@@ -141,7 +232,8 @@ static void dwt_haar_inverse_multilevel(float *data, int length, int levels) {
current_length *= 2; // MULTIPLY FIRST: 128→256, 256→512, ..., 16384→32768
if (current_length > length) current_length = length;
// dwt_haar_inverse_1d(data, current_length); // THEN apply inverse
dwt_dd4_inverse_1d(data, current_length); // THEN apply inverse
// dwt_dd4_inverse_1d(data, current_length); // THEN apply inverse
dwt_97_inverse_1d(data, current_length); // THEN apply inverse
}
}
@@ -159,23 +251,43 @@ static inline float tpdf1(void) {
return (frand01() - frand01());
}
static void ms_correlate(const float *mid, const float *side, uint8_t *left, uint8_t *right, size_t count, float dither_error[2][2]) {
static void ms_correlate(const float *mid, const float *side, float *left, float *right, size_t count) {
for (size_t i = 0; i < count; i++) {
// Decode M/S → L/R
float m = mid[i];
float s = side[i];
left[i] = FCLAMP((m + s) * 1.7321f, -1.0f, 1.0f);
right[i] = FCLAMP((m - s) * 1.7321f, -1.0f, 1.0f);
}
}
static float signum(float x) {
if (x > 0.0f) return 1.0f;
if (x < 0.0f) return -1.0f;
return 0.0f;
}
static void expand_gamma(float *left, float *right, size_t count) {
for (size_t i = 0; i < count; i++) {
// decode(y) = sign(y) * |y|^(1/γ) where γ=0.5
float x = left[i]; float a = fabsf(x);
left[i] = signum(x) * a * a;
float y = right[i]; float b = fabsf(y);
right[i] = signum(y) * b * b;
}
}
static void pcm32f_to_pcm8(const float *fleft, const float *fright, uint8_t *left, uint8_t *right, size_t count, float dither_error[2][2]) {
const float b1 = 1.5f; // 1st feedback coefficient
const float b2 = -0.75f; // 2nd feedback coefficient
const float scale = 127.5f;
const float bias = 128.0f;
for (size_t i = 0; i < count; i++) {
// Decode M/S → L/R
float m = mid[i];
float s = side[i];
float l = FCLAMP(m + s, -1.0f, 1.0f);
float r = FCLAMP(m - s, -1.0f, 1.0f);
// --- LEFT channel ---
float feedbackL = b1 * dither_error[0][0] + b2 * dither_error[0][1];
float ditherL = 0.5f * tpdf1(); // ±0.5 LSB TPDF
float shapedL = l + feedbackL + ditherL / scale;
float shapedL = fleft[i] + feedbackL + ditherL / scale;
shapedL = FCLAMP(shapedL, -1.0f, 1.0f);
int qL = (int)lrintf(shapedL * scale);
@@ -190,7 +302,7 @@ static void ms_correlate(const float *mid, const float *side, uint8_t *left, uin
// --- RIGHT channel ---
float feedbackR = b1 * dither_error[1][0] + b2 * dither_error[1][1];
float ditherR = 0.5f * tpdf1();
float shapedR = r + feedbackR + ditherR / scale;
float shapedR = fright[i] + feedbackR + ditherR / scale;
shapedR = FCLAMP(shapedR, -1.0f, 1.0f);
int qR = (int)lrintf(shapedR * scale);
@@ -228,13 +340,15 @@ static void get_quantization_weights(int quality, int dwt_levels, float *weights
/*15*/{0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f}
};
float quality_scale = 4.0f + FCLAMP((3 - quality) * 0.5f, 0.0f, 1000.0f);
float quality_scale = 1.0f * (1.0f + FCLAMP((5 - quality) * 0.5f, 0.0f, 1000.0f));
for (int i = 0; i < dwt_levels; i++) {
weights[i] = FCLAMP(base_weights[dwt_levels][i] * quality_scale, 1.0f, 1000.0f);
weights[i] = 1.0f;//base_weights[dwt_levels][i] * quality_scale;
}
}
#define QUANT_STEPS 8.0f // 64 -> [-64..64] -> 7 bits for LL
static void dequantize_dwt_coefficients(const int16_t *quantized, float *coeffs, size_t count, int quality, int chunk_size, int dwt_levels) {
float weights[16];
get_quantization_weights(quality, dwt_levels, weights);
@@ -263,7 +377,7 @@ static void dequantize_dwt_coefficients(const int16_t *quantized, float *coeffs,
if (weight_idx >= dwt_levels) weight_idx = dwt_levels - 1;
float weight = weights[weight_idx];
coeffs[i] = (float)quantized[i] * weight / TAD_COEFF_SCALAR;
coeffs[i] = ((float)quantized[i] * TAD32_COEFF_SCALARS[sideband]) / (QUANT_STEPS * weight);
}
free(sideband_starts);
@@ -352,6 +466,8 @@ static int decode_chunk(const uint8_t *input, size_t input_size, uint8_t *pcmu8_
int16_t *quant_side = malloc(sample_count * sizeof(int16_t));
float *dwt_mid = malloc(sample_count * sizeof(float));
float *dwt_side = malloc(sample_count * sizeof(float));
float *pcm32_left = malloc(sample_count * sizeof(float));
float *pcm32_right = malloc(sample_count * sizeof(float));
uint8_t *pcm8_left = malloc(sample_count * sizeof(uint8_t));
uint8_t *pcm8_right = malloc(sample_count * sizeof(uint8_t));
@@ -373,7 +489,13 @@ static int decode_chunk(const uint8_t *input, size_t input_size, uint8_t *pcmu8_
float err[2][2] = {{0,0},{0,0}};
// M/S to L/R correlation
ms_correlate(dwt_mid, dwt_side, pcm8_left, pcm8_right, sample_count, err);
ms_correlate(dwt_mid, dwt_side, pcm32_left, pcm32_right, sample_count);
// expand dynamic range
// expand_gamma(pcm32_left, pcm32_right, sample_count);
// dither to 8-bit
pcm32f_to_pcm8(pcm32_left, pcm32_right, pcm8_left, pcm8_right, sample_count, err);
// Interleave stereo output (PCMu8)
for (size_t i = 0; i < sample_count; i++) {
@@ -383,7 +505,7 @@ static int decode_chunk(const uint8_t *input, size_t input_size, uint8_t *pcmu8_
// Cleanup
free(quant_mid); free(quant_side); free(dwt_mid); free(dwt_side);
free(pcm8_left); free(pcm8_right);
free(pcm32_left); free(pcm32_right); free(pcm8_left); free(pcm8_right);
if (decompressed) free(decompressed);
return 0;