From c85b007ba951285dd78613d2243257bd0b477211 Mon Sep 17 00:00:00 2001 From: minjaesong Date: Tue, 4 Nov 2025 02:10:32 +0900 Subject: [PATCH] TAV decoder fix: limited RGB range --- .../torvald/tsvm/GraphicsJSR223Delegate.kt | 10 +- video_encoder/decoder_tav.c | 273 ++++++++++++++++-- video_encoder/encoder_tav.c | 10 +- 3 files changed, 251 insertions(+), 42 deletions(-) diff --git a/tsvm_core/src/net/torvald/tsvm/GraphicsJSR223Delegate.kt b/tsvm_core/src/net/torvald/tsvm/GraphicsJSR223Delegate.kt index 9eaa2c1..7ac3e67 100644 --- a/tsvm_core/src/net/torvald/tsvm/GraphicsJSR223Delegate.kt +++ b/tsvm_core/src/net/torvald/tsvm/GraphicsJSR223Delegate.kt @@ -4842,13 +4842,13 @@ class GraphicsJSR223Delegate(private val vm: VM) { private fun perceptual_model3_LH(level: Float): Float { val H4 = 1.2f - val Q = 2f // using fixed value for fixed curve; quantiser will scale it up anyway - val Q12 = Q * 12f + val K = 2f // using fixed value for fixed curve; quantiser will scale it up anyway + val K12 = K * 12f val x = level - val Lx = H4 - ((Q + 1f) / 15f) * (x - 4f) - val C3 = -1f / 45f * (Q12 + 92) - val G3x = (-x / 180f) * (Q12 + 5 * x * x - 60 * x + 252) - C3 + H4 + val Lx = H4 - ((K + 1f) / 15f) * (x - 4f) + val C3 = -1f / 45f * (K12 + 92) + val G3x = (-x / 180f) * (K12 + 5 * x * x - 60 * x + 252) - C3 + H4 return if (level >= 4) Lx else G3x } diff --git a/video_encoder/decoder_tav.c b/video_encoder/decoder_tav.c index aeb1001..5228eba 100644 --- a/video_encoder/decoder_tav.c +++ b/video_encoder/decoder_tav.c @@ -165,26 +165,35 @@ static int tav_derive_encoder_qindex(int q_index, int q_y_global) { else return 5; } -static float perceptual_model3_LH(int quality, float level) { +static float perceptual_model3_LH(float level) { const float H4 = 1.2f; - const float Lx = H4 - ((quality + 1.0f) / 15.0f) * (level - 4.0f); - const float Ld = (quality + 1.0f) / -15.0f; - const float C = H4 - 4.0f * Ld - ((-16.0f * (quality - 5.0f)) / 15.0f); - const float Gx = (Ld * level) - (((quality - 5.0f) * (level - 8.0f) * level) / 15.0f) + C; - return (level >= 4) ? Lx : Gx; + const float K = 2.0f; // CRITICAL: Fixed value for fixed curve; quantiser will scale it up anyway + const float K12 = K * 12.0f; + const float x = level; + + const float Lx = H4 - ((K + 1.0f) / 15.0f) * (x - 4.0f); + const float C3 = -1.0f / 45.0f * (K12 + 92.0f); + const float G3x = (-x / 180.0f) * (K12 + 5.0f * x * x - 60.0f * x + 252.0f) - C3 + H4; + + return (level >= 4.0f) ? Lx : G3x; } static float perceptual_model3_HL(int quality, float LH) { return LH * ANISOTROPY_MULT[quality] + ANISOTROPY_BIAS[quality]; } -static float perceptual_model3_HH(float LH, float HL) { - return (HL / LH) * 1.44f; +static float lerp(float x, float y, float a) { + return x * (1.0f - a) + y * a; } -static float perceptual_model3_LL(int quality, float level) { - const float n = perceptual_model3_LH(quality, level); - const float m = perceptual_model3_LH(quality, level - 1) / n; +static float perceptual_model3_HH(float LH, float HL, float level) { + const float Kx = (sqrtf(level) - 1.0f) * 0.5f + 0.5f; + return lerp(LH, HL, Kx); +} + +static float perceptual_model3_LL(float level) { + const float n = perceptual_model3_LH(level); + const float m = perceptual_model3_LH(level - 1.0f) / n; return n / m; } @@ -201,10 +210,10 @@ static float get_perceptual_weight(int q_index, int q_y_global, int level0, int if (!is_chroma) { // LUMA CHANNEL if (subband_type == 0) { - return perceptual_model3_LL(quality_level, level); + return perceptual_model3_LL(level); } - const float LH = perceptual_model3_LH(quality_level, level); + const float LH = perceptual_model3_LH(level); if (subband_type == 1) { return LH; } @@ -220,7 +229,7 @@ static float get_perceptual_weight(int q_index, int q_y_global, int level0, int float detailer = 1.0f; if (level >= 1.8f && level <= 2.2f) detailer = TWO_PIXEL_DETAILER; else if (level >= 2.8f && level <= 3.2f) detailer = FOUR_PIXEL_DETAILER; - return perceptual_model3_HH(LH, HL) * detailer; + return perceptual_model3_HH(LH, HL, level) * detailer; } } else { // CHROMA CHANNELS @@ -239,13 +248,18 @@ static float get_perceptual_weight(int q_index, int q_y_global, int level0, int static void dequantize_dwt_subbands_perceptual(int q_index, int q_y_global, const int16_t *quantized, float *dequantized, int width, int height, int decomp_levels, - float base_quantizer, int is_chroma) { + float base_quantizer, int is_chroma, int frame_num) { dwt_subband_info_t subbands[32]; // Max possible subbands const int subband_count = calculate_subband_layout(width, height, decomp_levels, subbands); const int coeff_count = width * height; memset(dequantized, 0, coeff_count * sizeof(float)); + int is_debug = (frame_num == 32); + if (frame_num == 32) { + fprintf(stderr, "DEBUG: dequantize called for frame %d, is_chroma=%d\n", frame_num, is_chroma); + } + // Apply perceptual weighting to each subband for (int s = 0; s < subband_count; s++) { const dwt_subband_info_t *subband = &subbands[s]; @@ -253,10 +267,128 @@ static void dequantize_dwt_subbands_perceptual(int q_index, int q_y_global, cons subband->subband_type, is_chroma, decomp_levels); const float effective_quantizer = base_quantizer * weight; + if (is_debug && !is_chroma) { + if (subband->subband_type == 0) { // LL band + fprintf(stderr, " Subband level %d (LL): weight=%.6f, base_q=%.1f, effective_q=%.1f, count=%d\n", + subband->level, weight, base_quantizer, effective_quantizer, subband->coeff_count); + + // Print first 5 quantized LL coefficients + fprintf(stderr, " First 5 quantized LL: "); + for (int k = 0; k < 5 && k < subband->coeff_count; k++) { + int idx = subband->coeff_start + k; + fprintf(stderr, "%d ", quantized[idx]); + } + fprintf(stderr, "\n"); + + // Find max quantized LL coefficient + int max_quant_ll = 0; + for (int k = 0; k < subband->coeff_count; k++) { + int idx = subband->coeff_start + k; + int abs_val = quantized[idx] < 0 ? -quantized[idx] : quantized[idx]; + if (abs_val > max_quant_ll) max_quant_ll = abs_val; + } + fprintf(stderr, " Max quantized LL coefficient: %d (dequantizes to %.1f)\n", + max_quant_ll, max_quant_ll * effective_quantizer); + } + } + for (int i = 0; i < subband->coeff_count; i++) { const int idx = subband->coeff_start + i; if (idx < coeff_count) { - dequantized[idx] = quantized[idx] * effective_quantizer; + // CRITICAL: Must ROUND to match EZBC encoder's roundf() behavior + // Without rounding, truncation limits brightness range (e.g., Y maxes at 227 instead of 255) + const float untruncated = quantized[idx] * effective_quantizer; + dequantized[idx] = roundf(untruncated); + } + } + } + + // Debug: Verify LL band was dequantized correctly + if (is_debug && !is_chroma) { + // Find LL band again to verify + for (int s = 0; s < subband_count; s++) { + const dwt_subband_info_t *subband = &subbands[s]; + if (subband->level == decomp_levels && subband->subband_type == 0) { + fprintf(stderr, " AFTER all subbands processed - First 5 dequantized LL: "); + for (int k = 0; k < 5 && k < subband->coeff_count; k++) { + int idx = subband->coeff_start + k; + fprintf(stderr, "%.1f ", dequantized[idx]); + } + fprintf(stderr, "\n"); + + // Find max dequantized LL + float max_dequant_ll = -999.0f; + for (int k = 0; k < subband->coeff_count; k++) { + int idx = subband->coeff_start + k; + float abs_val = dequantized[idx] < 0 ? -dequantized[idx] : dequantized[idx]; + if (abs_val > max_dequant_ll) max_dequant_ll = abs_val; + } + fprintf(stderr, " AFTER all subbands - Max dequantized LL: %.1f\n", max_dequant_ll); + break; + } + } + } +} + +//============================================================================= +// Grain Synthesis Removal (matches TSVM exactly) +//============================================================================= + +// Deterministic RNG for grain synthesis (matches encoder) +static inline uint32_t tav_grain_synthesis_rng(uint32_t frame, uint32_t band, uint32_t x, uint32_t y) { + uint32_t key = frame * 0x9e3779b9u ^ band * 0x7f4a7c15u ^ (y << 16) ^ x; + // rng_hash implementation + uint32_t hash = key; + hash = hash ^ (hash >> 16); + hash = hash * 0x7feb352du; + hash = hash ^ (hash >> 15); + hash = hash * 0x846ca68bu; + hash = hash ^ (hash >> 16); + return hash; +} + +// Generate triangular noise from uint32 RNG (returns value in range [-1.0, 1.0]) +static inline float tav_grain_triangular_noise(uint32_t rng_val) { + // Get two uniform random values in [0, 1] + float u1 = (rng_val & 0xFFFFu) / 65535.0f; + float u2 = ((rng_val >> 16) & 0xFFFFu) / 65535.0f; + + // Convert to range [-1, 1] and average for triangular distribution + return (u1 + u2) - 1.0f; +} + +// Remove grain synthesis from DWT coefficients (decoder subtracts noise) +// This must be called AFTER dequantization but BEFORE inverse DWT +static void remove_grain_synthesis_decoder(float *coeffs, int width, int height, + int decomp_levels, int frame_num, int q_y_global) { + dwt_subband_info_t subbands[32]; + const int subband_count = calculate_subband_layout(width, height, decomp_levels, subbands); + + // Noise amplitude (matches Kotlin: qYGlobal.coerceAtMost(32) * 0.5f) + const float noise_amplitude = (q_y_global < 32 ? q_y_global : 32) * 0.5f; + + // Process each subband (skip LL band which is level 0) + for (int s = 0; s < subband_count; s++) { + const dwt_subband_info_t *subband = &subbands[s]; + if (subband->level == 0) continue; // Skip LL band + + // Calculate band index for RNG (matches Kotlin: level + subbandType * 31 + 16777619) + uint32_t band = subband->level + subband->subband_type * 31 + 16777619; + + // Remove noise from each coefficient in this subband + for (int i = 0; i < subband->coeff_count; i++) { + const int idx = subband->coeff_start + i; + if (idx < width * height) { + // Calculate 2D position from linear index + int y = idx / width; + int x = idx % width; + + // Generate same deterministic noise as encoder + uint32_t rng_val = tav_grain_synthesis_rng(frame_num, band, x, y); + float noise = tav_grain_triangular_noise(rng_val); + + // Subtract noise from coefficient + coeffs[idx] -= noise * noise_amplitude; } } } @@ -754,6 +886,7 @@ static tav_decoder_t* tav_decoder_init(const char *input_file, const char *outpu "-video_size", video_size, "-framerate", framerate, "-i", "pipe:3", // Video from fd 3 + "-color_range", "2", // Note: Audio decoding not yet implemented, so we output video-only MKV "-c:v", "ffv1", // FFV1 codec "-level", "3", // FFV1 level 3 @@ -762,6 +895,8 @@ static tav_decoder_t* tav_decoder_init(const char *input_file, const char *outpu "-g", "1", // GOP size 1 (all I-frames) "-slices", "24", // 24 slices for threading "-slicecrc", "1", // CRC per slice + "-pixel_format", "rgb24", // make FFmpeg encode to RGB + "-color_range", "2", "-f", "matroska", // MKV container output_file, "-y", // Overwrite output @@ -933,12 +1068,20 @@ static int decode_i_or_p_frame(tav_decoder_t *decoder, uint8_t packet_type, uint postprocess_coefficients_twobit(ptr, coeff_count, quantized_y, quantized_co, quantized_cg); // Debug: Check first few coefficients - if (decoder->frame_count < 1) { + if (decoder->frame_count == 32) { fprintf(stderr, " First 10 quantized Y coeffs: "); for (int i = 0; i < 10 && i < coeff_count; i++) { fprintf(stderr, "%d ", quantized_y[i]); } fprintf(stderr, "\n"); + + // Check for any large quantized values that should produce bright pixels + int max_quant_y = 0; + for (int i = 0; i < coeff_count; i++) { + int abs_val = quantized_y[i] < 0 ? -quantized_y[i] : quantized_y[i]; + if (abs_val > max_quant_y) max_quant_y = abs_val; + } + fprintf(stderr, " Max quantized Y coefficient: %d\n", max_quant_y); } // Dequantize (perceptual for versions 5-8, uniform for 1-4) @@ -946,13 +1089,21 @@ static int decode_i_or_p_frame(tav_decoder_t *decoder, uint8_t packet_type, uint if (is_perceptual) { dequantize_dwt_subbands_perceptual(0, qy, quantized_y, decoder->dwt_buffer_y, decoder->header.width, decoder->header.height, - decoder->header.decomp_levels, qy, 0); + decoder->header.decomp_levels, qy, 0, decoder->frame_count); + + // Debug: Check if values survived the function call + if (decoder->frame_count == 32) { + fprintf(stderr, " RIGHT AFTER dequantize_Y returns: first 5 values: %.1f %.1f %.1f %.1f %.1f\n", + decoder->dwt_buffer_y[0], decoder->dwt_buffer_y[1], decoder->dwt_buffer_y[2], + decoder->dwt_buffer_y[3], decoder->dwt_buffer_y[4]); + } + dequantize_dwt_subbands_perceptual(0, qy, quantized_co, decoder->dwt_buffer_co, decoder->header.width, decoder->header.height, - decoder->header.decomp_levels, qco, 1); + decoder->header.decomp_levels, qco, 1, decoder->frame_count); dequantize_dwt_subbands_perceptual(0, qy, quantized_cg, decoder->dwt_buffer_cg, decoder->header.width, decoder->header.height, - decoder->header.decomp_levels, qcg, 1); + decoder->header.decomp_levels, qcg, 1, decoder->frame_count); } else { for (int i = 0; i < coeff_count; i++) { decoder->dwt_buffer_y[i] = quantized_y[i] * qy; @@ -961,20 +1112,50 @@ static int decode_i_or_p_frame(tav_decoder_t *decoder, uint8_t packet_type, uint } } - // Debug: Check dequantized values before IDWT - if (decoder->frame_count < 1) { - fprintf(stderr, " After dequant - First 10 Y DWT coeffs: "); - for (int i = 0; i < 10 && i < decoder->frame_size; i++) { - fprintf(stderr, "%.1f ", decoder->dwt_buffer_y[i]); - } - fprintf(stderr, "\n"); + // Debug: Check dequantized values using correct subband layout + if (decoder->frame_count == 32) { + dwt_subband_info_t subbands[32]; + const int subband_count = calculate_subband_layout(decoder->header.width, decoder->header.height, + decoder->header.decomp_levels, subbands); - // Count non-zero coefficients - int nonzero = 0; - for (int i = 0; i < decoder->frame_size; i++) { - if (decoder->dwt_buffer_y[i] != 0.0f) nonzero++; + // Find LL band (highest level, type 0) + for (int s = 0; s < subband_count; s++) { + if (subbands[s].level == decoder->header.decomp_levels && subbands[s].subband_type == 0) { + fprintf(stderr, " LL band: level=%d, start=%d, count=%d\n", + subbands[s].level, subbands[s].coeff_start, subbands[s].coeff_count); + fprintf(stderr, " Reading LL first 5 from dwt_buffer_y[0-4]: %.1f %.1f %.1f %.1f %.1f\n", + decoder->dwt_buffer_y[0], decoder->dwt_buffer_y[1], decoder->dwt_buffer_y[2], + decoder->dwt_buffer_y[3], decoder->dwt_buffer_y[4]); + + // Find max in CORRECT LL band + float max_ll = -999.0f; + for (int i = 0; i < subbands[s].coeff_count; i++) { + int idx = subbands[s].coeff_start + i; + if (decoder->dwt_buffer_y[idx] > max_ll) max_ll = decoder->dwt_buffer_y[idx]; + } + fprintf(stderr, " Max LL coefficient BEFORE grain removal: %.1f\n", max_ll); + break; + } } - fprintf(stderr, " Non-zero Y coefficients after dequant: %d / %d\n", nonzero, decoder->frame_size); + } + + // Remove grain synthesis from Y channel (must happen after dequantization, before inverse DWT) + remove_grain_synthesis_decoder(decoder->dwt_buffer_y, decoder->header.width, decoder->header.height, + decoder->header.decomp_levels, decoder->frame_count, decoder->header.quantiser_y); + + // Debug: Check LL band AFTER grain removal + if (decoder->frame_count == 32) { + int ll_width = decoder->header.width; + int ll_height = decoder->header.height; + for (int l = 0; l < decoder->header.decomp_levels; l++) { + ll_width = (ll_width + 1) / 2; + ll_height = (ll_height + 1) / 2; + } + float max_ll = -999.0f; + for (int i = 0; i < ll_width * ll_height; i++) { + if (decoder->dwt_buffer_y[i] > max_ll) max_ll = decoder->dwt_buffer_y[i]; + } + fprintf(stderr, " Max LL coefficient AFTER grain removal: %.1f\n", max_ll); } // Apply inverse DWT with correct non-power-of-2 dimension handling @@ -987,6 +1168,15 @@ static int decode_i_or_p_frame(tav_decoder_t *decoder, uint8_t packet_type, uint decoder->header.decomp_levels, decoder->header.wavelet_filter); // Debug: Check spatial domain values after IDWT + if (decoder->frame_count == 32) { + float max_y_spatial = -999.0f; + for (int i = 0; i < decoder->frame_size; i++) { + if (decoder->dwt_buffer_y[i] > max_y_spatial) max_y_spatial = decoder->dwt_buffer_y[i]; + } + fprintf(stderr, " Max Y in spatial domain AFTER IDWT: %.1f\n", max_y_spatial); + } + + // Debug: Check spatial domain values after IDWT (original debug) if (decoder->frame_count < 1) { fprintf(stderr, " After IDWT - First 10 Y values: "); for (int i = 0; i < 10 && i < decoder->frame_size; i++) { @@ -1008,6 +1198,9 @@ static int decode_i_or_p_frame(tav_decoder_t *decoder, uint8_t packet_type, uint // Convert YCoCg-R/ICtCp to RGB const int is_ictcp = (decoder->header.version % 2 == 0); + float max_y = -999, max_co = -999, max_cg = -999; + int max_r = 0, max_g = 0, max_b = 0; + for (int i = 0; i < decoder->frame_size; i++) { uint8_t r, g, b; if (is_ictcp) { @@ -1020,12 +1213,28 @@ static int decode_i_or_p_frame(tav_decoder_t *decoder, uint8_t packet_type, uint decoder->dwt_buffer_cg[i], &r, &g, &b); } + // Track max values for debugging + if (decoder->frame_count == 1000) { + if (decoder->dwt_buffer_y[i] > max_y) max_y = decoder->dwt_buffer_y[i]; + if (decoder->dwt_buffer_co[i] > max_co) max_co = decoder->dwt_buffer_co[i]; + if (decoder->dwt_buffer_cg[i] > max_cg) max_cg = decoder->dwt_buffer_cg[i]; + if (r > max_r) max_r = r; + if (g > max_g) max_g = g; + if (b > max_b) max_b = b; + } + // RGB byte order for FFmpeg rgb24 decoder->current_frame_rgb[i * 3 + 0] = r; decoder->current_frame_rgb[i * 3 + 1] = g; decoder->current_frame_rgb[i * 3 + 2] = b; } + if (decoder->frame_count == 1000) { + fprintf(stderr, "\n=== Frame 1000 Value Analysis ===\n"); + fprintf(stderr, "Max YCoCg values: Y=%.1f, Co=%.1f, Cg=%.1f\n", max_y, max_co, max_cg); + fprintf(stderr, "Max RGB values: R=%d, G=%d, B=%d\n", max_r, max_g, max_b); + } + // Debug: Check RGB output if (decoder->frame_count < 1) { fprintf(stderr, " First 5 pixels RGB: "); diff --git a/video_encoder/encoder_tav.c b/video_encoder/encoder_tav.c index d7abef7..f73078c 100644 --- a/video_encoder/encoder_tav.c +++ b/video_encoder/encoder_tav.c @@ -6411,13 +6411,13 @@ static void quantise_dwt_coefficients(float *coeffs, int16_t *quantised, int siz // https://www.desmos.com/calculator/mjlpwqm8ge static float perceptual_model3_LH(int quality, float level) { float H4 = 1.2f; - float Q = 2.f; // using fixed value for fixed curve; quantiser will scale it up anyway - float Q12 = Q * 12.f; + float K = 2.f; // using fixed value for fixed curve; quantiser will scale it up anyway + float K12 = K * 12.f; float x = level; - float Lx = H4 - ((Q + 1.f) / 15.f) * (x - 4.f); - float C3 = -1.f / 45.f * (Q12 + 92); - float G3x = (-x / 180.f) * (Q12 + 5*x*x - 60*x + 252) - C3 + H4; + float Lx = H4 - ((K + 1.f) / 15.f) * (x - 4.f); + float C3 = -1.f / 45.f * (K12 + 92); + float G3x = (-x / 180.f) * (K12 + 5*x*x - 60*x + 252) - C3 + H4; return (level >= 4) ? Lx : G3x; }