TAV update: CDF 5/3 for motion coder

This commit is contained in:
minjaesong
2025-11-23 18:16:12 +09:00
parent e928d2d3ec
commit 1c7ab17b1c
6 changed files with 174 additions and 95 deletions

View File

@@ -6297,65 +6297,25 @@ class GraphicsJSR223Delegate(private val vm: VM) {
if (length < 2) return
val temp = FloatArray(length)
val half = (length + 1) / 2 // Handle odd lengths properly
val half = (length + 1) / 2
// Split into low and high frequency components (matching encoder layout)
// Copy low-pass and high-pass subbands to temp
System.arraycopy(data, 0, temp, 0, length)
// Undo update step (low-pass)
for (i in 0 until half) {
temp[i] = data[i] // Low-pass coefficients (first half)
}
for (i in 0 until length / 2) {
if (half + i < length && half + i < data.size) {
temp[half + i] = data[half + i] // High-pass coefficients (second half)
}
val update = 0.25f * ((if (i > 0) temp[half + i - 1] else 0.0f) +
(if (i < half - 1) temp[half + i] else 0.0f))
temp[i] -= update
}
// 5/3 inverse lifting (undo forward steps in reverse order)
// Step 2: Undo update step (1/4 coefficient) - JPEG2000 symmetric extension
// Undo predict step (high-pass) and interleave samples
for (i in 0 until half) {
val leftIdx = half + i - 1
val centerIdx = half + i
// Symmetric extension for boundary handling
val left = when {
leftIdx >= 0 && leftIdx < length -> temp[leftIdx]
centerIdx < length && centerIdx + 1 < length -> temp[centerIdx + 1] // Mirror
centerIdx < length -> temp[centerIdx]
else -> 0.0f
}
val right = if (centerIdx < length) temp[centerIdx] else 0.0f
temp[i] -= 0.25f * (left + right)
}
// Step 1: Undo predict step (1/2 coefficient) - JPEG2000 symmetric extension
for (i in 0 until length / 2) {
if (half + i < length) {
val left = temp[i]
// Symmetric extension for right boundary
val right = if (i < half - 1) temp[i + 1] else if (half > 2) temp[half - 2] else temp[half - 1]
temp[half + i] += 0.5f * (left + right) // ADD to undo the subtraction in encoder
}
}
// Simple reconstruction (revert to working version)
for (i in 0 until length) {
if (i % 2 == 0) {
// Even positions: low-pass coefficients
data[i] = temp[i / 2]
} else {
// Odd positions: high-pass coefficients
val idx = i / 2
if (half + idx < length) {
data[i] = temp[half + idx]
} else {
// Symmetric extension: mirror the last available high-pass coefficient
val lastHighIdx = (length / 2) - 1
if (lastHighIdx >= 0 && half + lastHighIdx < length) {
data[i] = temp[half + lastHighIdx]
} else {
data[i] = 0.0f
}
}
data[2 * i] = temp[i] // Even samples (low-pass)
val idx = 2 * i + 1
if (idx < length) {
val pred = 0.5f * (temp[i] + (if (i < half - 1) temp[i + 1] else temp[i]))
data[idx] = temp[half + i] + pred // Odd samples (high-pass)
}
}
}
@@ -6514,7 +6474,8 @@ class GraphicsJSR223Delegate(private val vm: VM) {
spatialLevels: Int = 6,
temporalLevels: Int = 2,
entropyCoder: Int = 0,
bufferOffset: Long = 0
bufferOffset: Long = 0,
temporalMotionCoder: Int = 0
): Array<Any> {
val dbgOut = HashMap<String, Any>()
dbgOut["qY"] = qYGlobal
@@ -6634,9 +6595,9 @@ class GraphicsJSR223Delegate(private val vm: VM) {
}
// Step 6: Apply inverse 3D DWT using GOP dimensions (may be cropped)
tavApplyInverse3DDWT(gopY, gopWidth, gopHeight, gopSize, spatialLevels, temporalLevels, spatialFilter)
tavApplyInverse3DDWT(gopCo, gopWidth, gopHeight, gopSize, spatialLevels, temporalLevels, spatialFilter)
tavApplyInverse3DDWT(gopCg, gopWidth, gopHeight, gopSize, spatialLevels, temporalLevels, spatialFilter)
tavApplyInverse3DDWT(gopY, gopWidth, gopHeight, gopSize, spatialLevels, temporalLevels, spatialFilter, temporalMotionCoder)
tavApplyInverse3DDWT(gopCo, gopWidth, gopHeight, gopSize, spatialLevels, temporalLevels, spatialFilter, temporalMotionCoder)
tavApplyInverse3DDWT(gopCg, gopWidth, gopHeight, gopSize, spatialLevels, temporalLevels, spatialFilter, temporalMotionCoder)
// Step 8: Convert to RGB and composite to full frame
// With crop encoding, center the cropped frame and fill letterbox areas with black
@@ -6780,7 +6741,8 @@ class GraphicsJSR223Delegate(private val vm: VM) {
spatialLevels: Int = 6,
temporalLevels: Int = 3,
entropyCoder: Int = 0,
bufferOffset: Long = 0
bufferOffset: Long = 0,
temporalMotionCoder: Int = 0
) {
// Cancel any existing decode thread
asyncDecodeThread?.interrupt()
@@ -6798,7 +6760,7 @@ class GraphicsJSR223Delegate(private val vm: VM) {
width, height,
qIndex, qYGlobal, qCoGlobal, qCgGlobal,
channelLayout, spatialFilter, spatialLevels, temporalLevels,
entropyCoder, bufferOffset
entropyCoder, bufferOffset, temporalMotionCoder
)
asyncDecodeResult = result
asyncDecodeComplete.set(true)
@@ -6943,12 +6905,17 @@ class GraphicsJSR223Delegate(private val vm: VM) {
// =============================================================================
/**
* Inverse 1D temporal DWT (Haar) along time axis
* Reuses existing Haar inverse implementation
* Inverse 1D temporal DWT along time axis
* Supports both Haar and CDF 5/3 wavelets
*
* Thin dispatcher: it performs no transform arithmetic itself and only forwards
* [data] and [numFrames] to the inverse routine chosen by [temporalMotionCoder].
*
* @param data Coefficient buffer forwarded as-is to the selected inverse transform
* @param numFrames Length forwarded to the selected inverse transform; when it is
*        below 2 the function returns before calling either transform
* @param temporalMotionCoder 0=Haar; any other value selects CDF 5/3
*        (1 is the nominal value for CDF 5/3)
*/
private fun tavApplyTemporalDWTInverse1D(data: FloatArray, numFrames: Int) {
private fun tavApplyTemporalDWTInverse1D(data: FloatArray, numFrames: Int, temporalMotionCoder: Int = 0) {
// Fewer than two entries: no inverse transform is invoked
if (numFrames < 2) return
tavApplyDWTHaarInverse1D(data, numFrames)
// Selector 0 keeps the Haar path, which is also the default when the argument is omitted
if (temporalMotionCoder == 0) {
tavApplyDWTHaarInverse1D(data, numFrames)
} else {
// Not an exact match on 1: every non-zero selector is routed to the 5/3 inverse
tavApplyDWT53Inverse1D(data, numFrames)
}
}
/**
@@ -6962,6 +6929,7 @@ class GraphicsJSR223Delegate(private val vm: VM) {
* @param spatialLevels Spatial decomposition levels (typically 6)
* @param temporalLevels Temporal decomposition levels (typically 2)
* @param spatialFilter Spatial wavelet filter type (0=5/3, 1=9/7, 255=Haar)
* @param temporalMotionCoder Temporal wavelet type (0=Haar, 1=CDF 5/3)
*/
private fun tavApplyInverse3DDWT(
gopData: Array<FloatArray>,
@@ -6970,7 +6938,8 @@ class GraphicsJSR223Delegate(private val vm: VM) {
numFrames: Int,
spatialLevels: Int,
temporalLevels: Int,
spatialFilter: Int
spatialFilter: Int,
temporalMotionCoder: Int = 0
) {
// Step 1: Apply inverse 2D spatial DWT to each temporal subband (each frame)
// This is required even for single frames (I-frames) to convert from DWT coefficients to pixel space
@@ -7008,7 +6977,7 @@ class GraphicsJSR223Delegate(private val vm: VM) {
for (level in temporalLevels - 1 downTo 0) {
val levelFrames = temporalLengths[level]
if (levelFrames >= 2) {
tavApplyTemporalDWTInverse1D(temporalLine, levelFrames)
tavApplyTemporalDWTInverse1D(temporalLine, levelFrames, temporalMotionCoder)
}
}