TAV: 3D DWT makes coherent picture at least

This commit is contained in:
minjaesong
2025-10-17 02:01:08 +09:00
parent 0cf1173dd6
commit 93622fc8ca
5 changed files with 117 additions and 94 deletions

View File

@@ -72,6 +72,7 @@ import kotlin.intArrayOf
import kotlin.let
import kotlin.longArrayOf
import kotlin.math.*
import kotlin.math.pow
import kotlin.repeat
import kotlin.text.format
import kotlin.text.lowercase
@@ -4538,30 +4539,18 @@ class GraphicsJSR223Delegate(private val vm: VM) {
* - Frames 8-15: Level 2 (tHH - highest frequency)
*/
private fun getTemporalSubbandLevel(frameIdx: Int, numFrames: Int, temporalLevels: Int): Int {
if (temporalLevels == 0) return 0
// Match encoder logic exactly (encoder_tav.c:1487-1501)
// After temporal DWT with 2 levels:
// Frames 0...num_frames/(2^2) = tLL (temporal low-low, coarsest, level 0)
// Frames in first half but after tLL = tLH (level 1)
// Remaining frames = tH from first level (level 2, finest)
val framesPerSubband = numFrames shr temporalLevels // numFrames / 2^temporalLevels
val framesPerLevel0 = numFrames shr temporalLevels // e.g., 16 >> 2 = 4, or 8 >> 2 = 2
// Safety check: ensure we have enough frames for the temporal levels
// Minimum frames needed = 2^temporalLevels
if (framesPerSubband == 0) {
// Not enough frames for this many temporal levels - treat all as base level
return 0
}
// Determine which temporal subband this frame belongs to
val subbandIdx = frameIdx / framesPerSubband
// Map subband index to level (0 = tLL, 1+ = temporal high-pass levels)
return if (subbandIdx == 0) 0 else {
// Find highest bit position in subbandIdx to determine level
var level = 0
var idx = subbandIdx
while (idx > 1) {
idx = idx shr 1
level++
}
level + 1
return when {
frameIdx < framesPerLevel0 -> 0 // Coarsest temporal level (tLL)
frameIdx < (numFrames shr 1) -> 1 // First level high-pass (tLH)
else -> 2 // Finest level high-pass (tH from level 1)
}
}
@@ -4575,9 +4564,10 @@ class GraphicsJSR223Delegate(private val vm: VM) {
* - Level 2 (tHH): 1.0 × 2^1.6 = 3.03
*/
private fun getTemporalQuantizerScale(temporalLevel: Int): Float {
val BETA = 0.8f
val TEMPORAL_BASE_SCALE = 1.0f
return TEMPORAL_BASE_SCALE * Math.pow(2.0, (BETA * temporalLevel).toDouble()).toFloat()
val BETA = 0.6f // Temporal scaling exponent (aggressive for temporal high-pass)
val KAPPA = 1.14f
val TEMPORAL_BASE_SCALE = 1.0f // Don't reduce tLL quantization (same as intra)
return TEMPORAL_BASE_SCALE * 2.0f.pow(BETA * temporalLevel.toFloat().pow(KAPPA))
}
// level is one-based index
@@ -6251,8 +6241,8 @@ class GraphicsJSR223Delegate(private val vm: VM) {
* @param compressedDataPtr Pointer to compressed Zstd data
* @param compressedSize Size of compressed data
* @param gopSize Number of frames in GOP (1-16)
* @param motionVectorsX X motion vectors in quarter-pixel units
* @param motionVectorsY Y motion vectors in quarter-pixel units
* @param motionVectorsX X motion vectors in 1/16-pixel units
* @param motionVectorsY Y motion vectors in 1/16-pixel units
* @param outputRGBAddrs Array of output RGB buffer addresses
* @param width Frame width
* @param height Frame height
@@ -6363,10 +6353,10 @@ class GraphicsJSR223Delegate(private val vm: VM) {
tavApplyInverse3DDWT(gopCg, width, height, gopSize, spatialLevels, temporalLevels, spatialFilter)
// Step 7: Apply inverse motion compensation (shift frames back)
// Note: Motion vectors are in quarter-pixel units
// Note: Motion vectors are in 1/16-pixel units, cumulative relative to frame 0
for (t in 1 until gopSize) { // Skip frame 0 (reference)
val dx = motionVectorsX[t] / 4 // Convert to pixel units
val dy = motionVectorsY[t] / 4
val dx = motionVectorsX[t] / 16 // Convert to pixel units
val dy = motionVectorsY[t] / 16
if (dx != 0 || dy != 0) {
applyInverseTranslation(gopY[t], width, height, dx, dy)
@@ -6486,36 +6476,30 @@ class GraphicsJSR223Delegate(private val vm: VM) {
// Haar wavelet inverse 1D transform
// The simplest wavelet: reverses averages and differences
// MUST match encoder's dwt_haar_inverse_1d exactly (encoder_tav.c:1265-1284)
private fun tavApplyDWTHaarInverse1D(data: FloatArray, length: Int) {
if (length < 2) return
val temp = FloatArray(length)
val half = (length + 1) / 2
// Split into low and high frequency components
for (i in 0 until half) {
temp[i] = data[i] // Low-pass coefficients (averages)
}
for (i in 0 until length / 2) {
if (half + i < length) {
temp[half + i] = data[half + i] // High-pass coefficients (differences)
}
}
// Haar inverse: reconstruct original samples from averages and differences
// Inverse Haar transform: reconstruct from averages and differences
// Read directly from data array (already has low-pass then high-pass layout)
for (i in 0 until half) {
if (2 * i + 1 < length) {
val avg = temp[i] // Average (low-pass)
val diff = if (half + i < length) temp[half + i] else 0.0f // Difference (high-pass)
// Reconstruct original adjacent pair
data[2 * i] = avg + diff // First sample: average + difference
data[2 * i + 1] = avg - diff // Second sample: average - difference
// Reconstruct adjacent pairs from average and difference
temp[2 * i] = data[i] + data[half + i] // average + difference
temp[2 * i + 1] = data[i] - data[half + i] // average - difference
} else {
// Handle odd length: last sample comes directly from low-pass
data[2 * i] = temp[i]
// Handle odd length: last sample comes from low-pass only
temp[2 * i] = data[i]
}
}
// Copy reconstructed data back
for (i in 0 until length) {
data[i] = temp[i]
}
}
// =============================================================================