fix: TAV C decoder outputting wrong brightness

This commit is contained in:
minjaesong
2025-11-11 13:28:11 +09:00
parent 901f6b52b4
commit bd530f803f

View File

@@ -376,7 +376,8 @@ static void remove_grain_synthesis_decoder(float *coeffs, int width, int height,
const int subband_count = calculate_subband_layout(width, height, decomp_levels, subbands);
// Noise amplitude (matches Kotlin: qYGlobal.coerceAtMost(32) * 0.8f)
const float noise_amplitude = (q_y_global < 32 ? q_y_global : 32) * 0.25f; // somehow noise amplitude works differently than Kotlin?
// FIX: Was 0.25f, should be 0.8f to match Kotlin decoder
const float noise_amplitude = (q_y_global < 32 ? q_y_global : 32) * 0.8f;
// Process each subband (skip LL band which is level 0)
for (int s = 0; s < subband_count; s++) {
@@ -1390,23 +1391,23 @@ static void apply_inverse_dwt_multilevel(float *data, int width, int height, int
// Get temporal subband level for a given frame index in a GOP.
//
// After a temporal DWT with N = temporal_levels levels, frames are organised as:
//   Frames 0 .. num_frames/(2^N) - 1  = tL...L (N low-passes, coarsest) -> level 0
//   Remaining frames are temporal high-pass subbands, from coarse to fine,
//   with the first decomposition's high-pass (second half of the GOP) -> level N
//
// Example, num_frames = 16, temporal_levels = 2:
//   frames 0..3 -> 0 (tLL), frames 4..7 -> 1 (tLH), frames 8..15 -> 2 (tH)
//
// Parameters:
//   frame_idx       index of the frame inside the GOP
//   num_frames      number of frames in the GOP
//   temporal_levels number of temporal decomposition levels (N)
// Returns the temporal subband level in the range 0..temporal_levels.
static int get_temporal_subband_level(int frame_idx, int num_frames, int temporal_levels) {
    // Match encoder logic exactly (encoder_tav.c:1487-1506)

    // No temporal decomposition: every frame belongs to the single (coarsest) level.
    // This also keeps the shift count below strictly positive.
    if (temporal_levels <= 0) {
        return 0;
    }

    // Check each level boundary from coarsest to finest.
    // level 0 boundary is num_frames >> N, level N-1 boundary is num_frames >> 1.
    for (int level = 0; level < temporal_levels; level++) {
        const int frames_at_this_level = num_frames >> (temporal_levels - level);
        if (frame_idx < frames_at_this_level) {
            return level;
        }
    }

    // Finest level (first decomposition's high-pass)
    return temporal_levels;
}
// Calculate temporal quantiser scale for a given temporal subband level
static float get_temporal_quantiser_scale(int temporal_level) {
// Uses exponential scaling: 2^(BETA × level^KAPPA)
@@ -1841,9 +1842,11 @@ static void ycocg_r_to_rgb(float y, float co, float cg, uint8_t *r, uint8_t *g,
float b_val = tmp - co / 2.0f;
float r_val = co + b_val;
*r = CLAMP((int)(r_val + 0.5f), 0, 255);
*g = CLAMP((int)(g_val + 0.5f), 0, 255);
*b = CLAMP((int)(b_val + 0.5f), 0, 255);
// FIX: Use truncation (not rounding) to match Kotlin decoder behavior
// Kotlin uses .toInt() which truncates toward zero (floor for positive values)
*r = CLAMP((int)(r_val), 0, 255);
*g = CLAMP((int)(g_val), 0, 255);
*b = CLAMP((int)(b_val), 0, 255);
}
// ICtCp to RGB conversion (for even TAV versions)
@@ -2407,32 +2410,9 @@ static int decode_i_or_p_frame(tav_decoder_t *decoder, uint8_t packet_type, uint
const int is_perceptual = (decoder->header.version >= 5 && decoder->header.version <= 8);
const int is_ezbc = (decoder->header.entropy_coder == 1);
// Debug: Print decoder state
static int state_debug_once = 1;
if (state_debug_once) {
fprintf(stderr, "[DECODER-STATE] version=%d, entropy_coder=%d, is_perceptual=%d, is_ezbc=%d\n",
decoder->header.version, decoder->header.entropy_coder, is_perceptual, is_ezbc);
state_debug_once = 0;
}
if (is_ezbc && is_perceptual) {
// EZBC mode with perceptual quantisation: coefficients are normalised
// Need to dequantise using perceptual weights (same as twobit-map mode)
// Debug: Print quantised LL values before dequantisation
static int debug_count = 0;
if (debug_count < 1) {
fprintf(stderr, "[EZBC-DECODER-DEBUG] Quantised LL coefficients (9x7):\n");
for (int y = 0; y < 7 && y < decoder->header.height; y++) {
for (int x = 0; x < 9 && x < decoder->header.width; x++) {
int idx = y * decoder->header.width + x;
fprintf(stderr, "%6d ", quantised_y[idx]);
}
fprintf(stderr, "\n");
}
debug_count++;
}
dequantise_dwt_subbands_perceptual(0, qy, quantised_y, decoder->dwt_buffer_y,
decoder->header.width, decoder->header.height,
decoder->header.decomp_levels, qy, 0, decoder->frame_count);
@@ -2442,18 +2422,6 @@ static int decode_i_or_p_frame(tav_decoder_t *decoder, uint8_t packet_type, uint
dequantise_dwt_subbands_perceptual(0, qy, quantised_cg, decoder->dwt_buffer_cg,
decoder->header.width, decoder->header.height,
decoder->header.decomp_levels, qcg, 1, decoder->frame_count);
// Debug: Print dequantised LL values
if (debug_count <= 1) {
fprintf(stderr, "[EZBC-DECODER-DEBUG] Dequantised LL coefficients (9x7):\n");
for (int y = 0; y < 7 && y < decoder->header.height; y++) {
for (int x = 0; x < 9 && x < decoder->header.width; x++) {
int idx = y * decoder->header.width + x;
fprintf(stderr, "%7.0f ", decoder->dwt_buffer_y[idx]);
}
fprintf(stderr, "\n");
}
}
} else if (is_perceptual) {
dequantise_dwt_subbands_perceptual(0, qy, quantised_y, decoder->dwt_buffer_y,
decoder->header.width, decoder->header.height,
@@ -2957,26 +2925,33 @@ int main(int argc, char *argv[]) {
const int temporal_level = get_temporal_subband_level(t, gop_size, temporal_levels);
const float temporal_scale = get_temporal_quantiser_scale(temporal_level);
const float base_q_y = roundf(decoder->header.quantiser_y * temporal_scale);
const float base_q_co = roundf(decoder->header.quantiser_co * temporal_scale);
const float base_q_cg = roundf(decoder->header.quantiser_cg * temporal_scale);
// FIX: Use QLUT to convert header quantiser indices to actual values
const float base_q_y = roundf(QLUT[decoder->header.quantiser_y] * temporal_scale);
const float base_q_co = roundf(QLUT[decoder->header.quantiser_co] * temporal_scale);
const float base_q_cg = roundf(QLUT[decoder->header.quantiser_cg] * temporal_scale);
dequantise_dwt_subbands_perceptual(0, decoder->header.quantiser_y,
dequantise_dwt_subbands_perceptual(0, QLUT[decoder->header.quantiser_y],
quantised_gop[t][0], gop_y[t],
decoder->header.width, decoder->header.height,
decoder->header.decomp_levels, base_q_y, 0, decoder->frame_count + t);
dequantise_dwt_subbands_perceptual(0, decoder->header.quantiser_y,
dequantise_dwt_subbands_perceptual(0, QLUT[decoder->header.quantiser_y],
quantised_gop[t][1], gop_co[t],
decoder->header.width, decoder->header.height,
decoder->header.decomp_levels, base_q_co, 1, decoder->frame_count + t);
dequantise_dwt_subbands_perceptual(0, decoder->header.quantiser_y,
dequantise_dwt_subbands_perceptual(0, QLUT[decoder->header.quantiser_y],
quantised_gop[t][2], gop_cg[t],
decoder->header.width, decoder->header.height,
decoder->header.decomp_levels, base_q_cg, 1, decoder->frame_count + t);
if (t == 0 && verbose) {
fprintf(stderr, "[GOP-EZBC] Frame 0: Quantised LL[0]=%d, Dequantised LL[0]=%.1f, base_q_y=%.1f\n",
quantised_gop[t][0][0], gop_y[t][0], base_q_y);
// Debug: Check multiple LL values
fprintf(stderr, "[GOP-EZBC] Frame 0 after dequant:\n");
fprintf(stderr, " Quantised: LL[0]=%d, LL[1]=%d, LL[2]=%d\n",
quantised_gop[t][0][0], quantised_gop[t][0][1], quantised_gop[t][0][2]);
fprintf(stderr, " Dequantised: LL[0]=%.1f, LL[1]=%.1f, LL[2]=%.1f\n",
gop_y[t][0], gop_y[t][1], gop_y[t][2]);
fprintf(stderr, " base_q_y=%.1f, temporal_level=%d, temporal_scale=%.3f\n",
base_q_y, temporal_level, temporal_scale);
}
} else if (!is_ezbc) {
// Normal mode: multiply by quantiser
@@ -2984,20 +2959,21 @@ int main(int argc, char *argv[]) {
const float temporal_scale = get_temporal_quantiser_scale(temporal_level);
// CRITICAL: Must ROUND temporal quantiser to match encoder's roundf() behavior
const float base_q_y = roundf(decoder->header.quantiser_y * temporal_scale);
const float base_q_co = roundf(decoder->header.quantiser_co * temporal_scale);
const float base_q_cg = roundf(decoder->header.quantiser_cg * temporal_scale);
// FIX: Use QLUT to convert header quantiser indices to actual values
const float base_q_y = roundf(QLUT[decoder->header.quantiser_y] * temporal_scale);
const float base_q_co = roundf(QLUT[decoder->header.quantiser_co] * temporal_scale);
const float base_q_cg = roundf(QLUT[decoder->header.quantiser_cg] * temporal_scale);
if (is_perceptual) {
dequantise_dwt_subbands_perceptual(0, decoder->header.quantiser_y,
dequantise_dwt_subbands_perceptual(0, QLUT[decoder->header.quantiser_y],
quantised_gop[t][0], gop_y[t],
decoder->header.width, decoder->header.height,
decoder->header.decomp_levels, base_q_y, 0, decoder->frame_count + t);
dequantise_dwt_subbands_perceptual(0, decoder->header.quantiser_y,
dequantise_dwt_subbands_perceptual(0, QLUT[decoder->header.quantiser_y],
quantised_gop[t][1], gop_co[t],
decoder->header.width, decoder->header.height,
decoder->header.decomp_levels, base_q_co, 1, decoder->frame_count + t);
dequantise_dwt_subbands_perceptual(0, decoder->header.quantiser_y,
dequantise_dwt_subbands_perceptual(0, QLUT[decoder->header.quantiser_y],
quantised_gop[t][2], gop_cg[t],
decoder->header.width, decoder->header.height,
decoder->header.decomp_levels, base_q_cg, 1, decoder->frame_count + t);
@@ -3021,19 +2997,29 @@ int main(int argc, char *argv[]) {
}
free(quantised_gop);
// Remove grain synthesis from Y channel for each GOP frame
// This must happen after dequantisation but before inverse DWT
// FIX: Disable grain removal for GOP frames to prevent frame-varying artifacts (blips)
// Grain removal in DWT coefficient space causes inconsistent results across GOP frames
// The Kotlin decoder may have a different implementation that avoids this issue
// TODO: Investigate correct grain removal for temporal DWT GOP frames
/*
for (int t = 0; t < gop_size; t++) {
remove_grain_synthesis_decoder(gop_y[t], decoder->header.width, decoder->header.height,
decoder->header.decomp_levels, decoder->frame_count + t,
decoder->header.quantiser_y);
}
*/
// Apply inverse 3D DWT (spatial + temporal)
apply_inverse_3d_dwt(gop_y, gop_co, gop_cg, decoder->header.width, decoder->header.height,
gop_size, decoder->header.decomp_levels, temporal_levels,
decoder->header.wavelet_filter);
// Debug: Check Y values after inverse DWT
if (verbose && decoder->frame_count == 0) {
fprintf(stderr, "[GOP-DEBUG] After inverse 3D DWT: Frame 0 Y[0]=%.1f, Y[1]=%.1f, Y[2]=%.1f\n",
gop_y[0][0], gop_y[0][1], gop_y[0][2]);
}
// Debug: Check spatial coefficients after inverse temporal DWT (before inverse spatial DWT)
// if (is_ezbc) {
// float max_y = 0.0f, min_y = 0.0f;