TAV and TAD now shares same code for encoding and decoding

This commit is contained in:
minjaesong
2025-11-10 10:35:17 +09:00
parent 28e9a88f8d
commit 0e6f2162c8
5 changed files with 99 additions and 113 deletions

View File

@@ -13,6 +13,7 @@
#include <sys/wait.h>
#include <getopt.h>
#include <signal.h>
#include "decoder_tad.h" // Shared TAD decoder library
#define DECODER_VENDOR_STRING "Decoder-TAV 20251103 (ffv1+pcmu8)"
@@ -502,6 +503,48 @@ static void expand_mu_law(float *left, float *right, size_t count) {
}
}
//=============================================================================
// De-emphasis Filter (TAD)
//=============================================================================
static void calculate_deemphasis_coeffs(float *b0, float *b1, float *a1) {
// De-emphasis factor (must match encoder pre-emphasis alpha=0.5)
const float alpha = 0.5f;
*b0 = 1.0f;
*b1 = 0.0f; // No feedforward delay
*a1 = -alpha; // NEGATIVE because equation has minus sign: y = x - a1*prev_y
}
static void apply_deemphasis(float *left, float *right, size_t count) {
// Static state variables - persistent across chunks to prevent discontinuities
static float prev_x_l = 0.0f;
static float prev_y_l = 0.0f;
static float prev_x_r = 0.0f;
static float prev_y_r = 0.0f;
float b0, b1, a1;
calculate_deemphasis_coeffs(&b0, &b1, &a1);
// Left channel - use persistent state
for (size_t i = 0; i < count; i++) {
float x = left[i];
float y = b0 * x + b1 * prev_x_l - a1 * prev_y_l;
left[i] = y;
prev_x_l = x;
prev_y_l = y;
}
// Right channel - use persistent state
for (size_t i = 0; i < count; i++) {
float x = right[i];
float y = b0 * x + b1 * prev_x_r - a1 * prev_y_r;
right[i] = y;
prev_x_r = x;
prev_y_r = y;
}
}
static void pcm32f_to_pcm8(const float *fleft, const float *fright, uint8_t *left, uint8_t *right, size_t count, float dither_error[2][2]) {
const float b1 = 1.5f; // 1st feedback coefficient
const float b2 = -0.75f; // 2nd feedback coefficient
@@ -697,110 +740,9 @@ static void dequantize_dwt_coefficients(const int8_t *quantized, float *coeffs,
}
//=============================================================================
// Chunk Decoding
//=============================================================================
static int decode_chunk(const uint8_t *input, size_t input_size, uint8_t *pcmu8_stereo,
size_t *bytes_consumed, size_t *samples_decoded) {
const uint8_t *read_ptr = input;
// Read chunk header
uint16_t sample_count = *((const uint16_t*)read_ptr);
read_ptr += sizeof(uint16_t);
uint8_t max_index = *read_ptr;
read_ptr += sizeof(uint8_t);
uint32_t payload_size = *((const uint32_t*)read_ptr);
read_ptr += sizeof(uint32_t);
// Calculate DWT levels from sample count
int dwt_levels = calculate_dwt_levels(sample_count);
if (dwt_levels < 0) {
fprintf(stderr, "Error: Invalid sample count %u\n", sample_count);
return -1;
}
// Decompress if needed
const uint8_t *payload;
uint8_t *decompressed = NULL;
// Estimate decompressed size (generous upper bound)
size_t decompressed_size = sample_count * 4 * sizeof(int8_t);
decompressed = malloc(decompressed_size);
size_t actual_size = ZSTD_decompress(decompressed, decompressed_size, read_ptr, payload_size);
if (ZSTD_isError(actual_size)) {
fprintf(stderr, "Error: Zstd decompression failed: %s\n", ZSTD_getErrorName(actual_size));
free(decompressed);
return -1;
}
read_ptr += payload_size;
*bytes_consumed = read_ptr - input;
*samples_decoded = sample_count;
// Allocate working buffers
int8_t *quant_mid = malloc(sample_count * sizeof(int8_t));
int8_t *quant_side = malloc(sample_count * sizeof(int8_t));
float *dwt_mid = malloc(sample_count * sizeof(float));
float *dwt_side = malloc(sample_count * sizeof(float));
float *pcm32_left = malloc(sample_count * sizeof(float));
float *pcm32_right = malloc(sample_count * sizeof(float));
uint8_t *pcm8_left = malloc(sample_count * sizeof(uint8_t));
uint8_t *pcm8_right = malloc(sample_count * sizeof(uint8_t));
// Separate Mid/Side
memcpy(quant_mid, decompressed, sample_count);
memcpy(quant_side, decompressed + sample_count, sample_count);
// Debug: Check if we have non-zero coefficients
// static int debug_coeff_count = 0;
// if (debug_coeff_count < 3) {
// int nonzero_mid = 0, nonzero_side = 0;
// for (int i = 0; i < sample_count; i++) {
// if (quant_mid[i] != 0) nonzero_mid++;
// if (quant_side[i] != 0) nonzero_side++;
// }
// debug_coeff_count++;
// }
// Dequantize with quantiser scaling and spectral interpolation
// Use quantiser_scale = 1.0f for baseline (must match encoder)
float quantiser_scale = 1.0f;
dequantize_dwt_coefficients(quant_mid, dwt_mid, sample_count, sample_count, dwt_levels, max_index, quantiser_scale);
dequantize_dwt_coefficients(quant_side, dwt_side, sample_count, sample_count, dwt_levels, max_index, quantiser_scale);
// Inverse DWT
dwt_inverse_multilevel(dwt_mid, sample_count, dwt_levels);
dwt_inverse_multilevel(dwt_side, sample_count, dwt_levels);
float err[2][2] = {{0,0},{0,0}};
// M/S to L/R correlation
ms_correlate(dwt_mid, dwt_side, pcm32_left, pcm32_right, sample_count);
// expand dynamic range
expand_gamma(pcm32_left, pcm32_right, sample_count);
// dither to 8-bit
pcm32f_to_pcm8(pcm32_left, pcm32_right, pcm8_left, pcm8_right, sample_count, err);
// Interleave stereo output (PCMu8)
for (size_t i = 0; i < sample_count; i++) {
pcmu8_stereo[i * 2] = pcm8_left[i];
pcmu8_stereo[i * 2 + 1] = pcm8_right[i];
}
// Cleanup
free(quant_mid); free(quant_side); free(dwt_mid); free(dwt_side);
free(pcm32_left); free(pcm32_right); free(pcm8_left); free(pcm8_right);
if (decompressed) free(decompressed);
return 0;
}
// Chunk Decoding (TAD Audio)
// NOTE: TAD decoding now uses shared tad32_decode_chunk() from decoder_tad.h
// This ensures decoder_tav and decoder_tad use identical decoding logic
//=============================================================================
// Significance Map Postprocessing (matches TSVM exactly)
//=============================================================================
@@ -2075,7 +2017,7 @@ static int extract_audio_to_wav(const char *input_file, const char *wav_file, in
// Decode TAD
uint8_t *pcmu8_output = malloc(sample_count_chunk * 2);
size_t bytes_consumed, samples_decoded;
int decode_result = decode_chunk(tad_chunk, tad_chunk_size,
int decode_result = tad32_decode_chunk(tad_chunk, tad_chunk_size,
pcmu8_output, &bytes_consumed, &samples_decoded);
if (decode_result >= 0) {