TAV: still bugfixing

This commit is contained in:
minjaesong
2025-10-16 00:03:58 +09:00
parent 7e248bc83d
commit ea72dec996
6 changed files with 697 additions and 23 deletions

View File

@@ -4203,6 +4203,184 @@ class GraphicsJSR223Delegate(private val vm: VM) {
}
}
/**
 * Reconstructs per-frame quantised coefficients from a unified GOP block (2-bit format).
 * Reverse of the encoder's preprocess_gop_unified().
 *
 * Layout: [Y_maps_all][Co_maps_all][Cg_maps_all][Y_other_vals][Co_other_vals][Cg_other_vals]
 *
 * Each map stores one 2-bit code per coefficient, packed starting at the least significant
 * bits, four codes per byte, with every frame's map padded to a whole number of bytes:
 * 00=0, 01=+1, 10=-1, 11=other. "Other" coefficients are read as little-endian signed 16-bit
 * values from the channel's value array, in the order their codes appear in the map.
 * Codes always start on an even bit position, so a code never straddles two bytes.
 *
 * Truncated input is tolerated: a code whose map byte lies past the end of the data decodes
 * as 0, and an "other" value whose two bytes are not both present decodes as 0.
 *
 * @param decompressedData Unified block data (after Zstd decompression)
 * @param numFrames Number of frames in GOP
 * @param numPixels Pixels per frame (width × height)
 * @param channelLayout Channel layout flags; bit 1 set = chroma absent, bit 2 set = luma absent
 * @return Array of [frame][channel][pixel] where channel: 0=Y, 1=Co, 2=Cg; absent channels stay all-zero
 */
private fun tavPostprocessGopUnified(
    decompressedData: ByteArray,
    numFrames: Int,
    numPixels: Int,
    channelLayout: Int
): Array<Array<ShortArray>> {
    // 2 bits per coefficient, rounded up to whole bytes per frame
    val mapBytesPerFrame = (numPixels * 2 + 7) / 8
    val mapBytesPerChannel = mapBytesPerFrame * numFrames

    // Presence flags are inverted: a cleared bit means the channel is stored.
    // Co and Cg share the single "chroma" flag.
    val hasLuma = (channelLayout and 0x04) == 0
    val hasChroma = (channelLayout and 0x02) == 0
    val channelPresent = booleanArrayOf(hasLuma, hasChroma, hasChroma)

    // Maps of the present channels are stored back to back at the start of the block
    val mapStarts = IntArray(3)
    var readPtr = 0
    for (ch in 0 until 3) {
        mapStarts[ch] = readPtr
        if (channelPresent[ch]) readPtr += mapBytesPerChannel
    }

    // Value arrays follow the maps; each channel's array holds one 16-bit entry per "other"
    // code, so the maps must be scanned once to find where each array starts
    val valueStarts = IntArray(3)
    for (ch in 0 until 3) {
        valueStarts[ch] = readPtr
        if (channelPresent[ch]) {
            val otherCount = tavGopCountOtherCodes(
                decompressedData, mapStarts[ch], mapBytesPerFrame, numFrames, numPixels
            )
            readPtr += otherCount * 2
        }
    }

    // Freshly allocated planes are zero-filled, which already covers code 00,
    // out-of-range reads and absent channels
    val output = Array(numFrames) { Array(3) { ShortArray(numPixels) } }
    for (ch in 0 until 3) {
        if (!channelPresent[ch]) continue
        var valueOffset = valueStarts[ch]
        for (frame in 0 until numFrames) {
            val frameMapStart = mapStarts[ch] + frame * mapBytesPerFrame
            val plane = output[frame][ch]
            for (i in 0 until numPixels) {
                when (tavGopReadMapCode(decompressedData, frameMapStart, i)) {
                    1 -> plane[i] = 1
                    2 -> plane[i] = -1
                    3 -> {
                        if (valueOffset + 1 < decompressedData.size) {
                            val lo = decompressedData[valueOffset].toInt() and 0xFF
                            val hi = decompressedData[valueOffset + 1].toInt()
                            plane[i] = ((hi shl 8) or lo).toShort()
                        }
                        // The slot is consumed even when it could not be read
                        valueOffset += 2
                    }
                }
            }
        }
    }
    return output
}

/** Reads the 2-bit code of coefficient [coeffIdx] from a frame's map; returns 0 when the byte is past the end of [data]. */
private fun tavGopReadMapCode(data: ByteArray, frameMapStart: Int, coeffIdx: Int): Int {
    val byteIdx = frameMapStart + (coeffIdx shr 2)
    if (byteIdx >= data.size) return 0
    val bitOffset = (coeffIdx and 3) shl 1
    return (data[byteIdx].toInt() shr bitOffset) and 0x03
}

/** Counts the "other" codes (binary 11) in one channel's maps across all frames of the GOP. */
private fun tavGopCountOtherCodes(
    data: ByteArray,
    mapStart: Int,
    mapBytesPerFrame: Int,
    numFrames: Int,
    numPixels: Int
): Int {
    var count = 0
    for (frame in 0 until numFrames) {
        val frameMapStart = mapStart + frame * mapBytesPerFrame
        for (i in 0 until numPixels) {
            if (tavGopReadMapCode(data, frameMapStart, i) == 3) count++
        }
    }
    return count
}
// TAV Simulated overlapping tiles constants (must match encoder)
private val TAV_TILE_SIZE_X = 640
private val TAV_TILE_SIZE_Y = 540
@@ -4348,6 +4526,53 @@ class GraphicsJSR223Delegate(private val vm: VM) {
else 6
}
// GOP temporal quantization helpers
/**
 * Determines the temporal subband level for a given frame in a GOP.
 * Returns 0 for tLL (temporal low-pass), 1+ for temporal high-pass levels.
 *
 * For 2-level Haar decomposition on 16 frames:
 * - Frames 0-3: Level 0 (tLL - lowest frequency)
 * - Frames 4-7: Level 1 (tH - mid frequency)
 * - Frames 8-15: Level 2 (tHH - highest frequency)
 *
 * The GOP is divided into subbands of `numFrames / 2^temporalLevels` frames. When the GOP is
 * shorter than `2^temporalLevels` that width would be 0, so it is clamped to 1 frame instead
 * of dividing by zero.
 * NOTE(review): the mapping used for such short GOPs is a best-effort choice; confirm it
 * matches what the encoder does for them.
 *
 * @param frameIdx Index of the frame inside the GOP
 * @param numFrames Number of frames in the GOP
 * @param temporalLevels Number of temporal decomposition levels; values <= 0 always yield level 0
 * @return 0 for the first subband, otherwise floor(log2(subband index)) + 1
 */
private fun getTemporalSubbandLevel(frameIdx: Int, numFrames: Int, temporalLevels: Int): Int {
    if (temporalLevels <= 0) return 0

    // numFrames / 2^temporalLevels, but never less than one frame per subband
    val framesPerSubband = (numFrames shr temporalLevels).coerceAtLeast(1)

    // Determine which temporal subband this frame belongs to
    val subbandIdx = frameIdx / framesPerSubband
    if (subbandIdx == 0) return 0

    // Level = position of the highest set bit of subbandIdx, counted from 1
    var level = 1
    var remaining = subbandIdx
    while (remaining > 1) {
        remaining = remaining shr 1
        level++
    }
    return level
}
/**
 * Returns the quantiser multiplier for a temporal subband level.
 * The multiplier grows exponentially with the level: base × 2^(beta × level).
 *
 * With beta = 0.8 and base = 1.0:
 * - Level 0 (tLL): 2^0.0 = 1.00
 * - Level 1 (tH):  2^0.8 = 1.74
 * - Level 2 (tHH): 2^1.6 = 3.03
 *
 * @param temporalLevel Temporal subband level (0 = tLL)
 * @return Scale factor to apply to the base quantiser
 */
private fun getTemporalQuantizerScale(temporalLevel: Int): Float {
    val baseScale = 1.0f
    val beta = 0.8f
    // The exponent is formed in single precision before being widened for Math.pow
    val exponent = (beta * temporalLevel).toDouble()
    val gain = Math.pow(2.0, exponent).toFloat()
    return baseScale * gain
}
// level is one-based index
private fun getPerceptualWeight(qIndex: Int, qYGlobal: Int, level0: Int, subbandType: Int, isChroma: Boolean, maxLevels: Int): Float {
// Psychovisual model based on DWT coefficient statistics and Human Visual System sensitivity
@@ -4644,7 +4869,7 @@ class GraphicsJSR223Delegate(private val vm: VM) {
// Coefficient delta encoding for efficient P-frames
readPtr = tavDecodeDeltaTileRGB(readPtr, channelLayout, tileX, tileY, currentRGBAddr,
width, height, qY, qCo, qCg,
waveletFilter, decompLevels, isLossless, tavVersion, isMonoblock, frameCount)
decompLevels, tavVersion, isMonoblock, frameCount)
dbgOut["frameMode"] = " "
}
}
@@ -5262,7 +5487,7 @@ class GraphicsJSR223Delegate(private val vm: VM) {
private fun tavDecodeDeltaTileRGB(readPtr: Long, channelLayout: Int, tileX: Int, tileY: Int, currentRGBAddr: Long,
width: Int, height: Int, qY: Int, qCo: Int, qCg: Int,
waveletFilter: Int, decompLevels: Int, isLossless: Boolean, tavVersion: Int, isMonoblock: Boolean = false, frameCount: Int = 0): Long {
decompLevels: Int, tavVersion: Int, isMonoblock: Boolean = false, frameCount: Int = 0): Long {
val tileIdx = if (isMonoblock) {
0 // Single tile index for monoblock
@@ -5403,15 +5628,9 @@ class GraphicsJSR223Delegate(private val vm: VM) {
tavPreviousCoeffsCg!![tileIdx] = currentCg.clone()
// Apply inverse DWT
if (isLossless) {
tavApplyDWTInverseMultiLevel(currentY, tileWidth, tileHeight, decompLevels, 0, TavSharpenLuma)
tavApplyDWTInverseMultiLevel(currentCo, tileWidth, tileHeight, decompLevels, 0, TavNullFilter)
tavApplyDWTInverseMultiLevel(currentCg, tileWidth, tileHeight, decompLevels, 0, TavNullFilter)
} else {
tavApplyDWTInverseMultiLevel(currentY, tileWidth, tileHeight, decompLevels, waveletFilter, TavSharpenLuma)
tavApplyDWTInverseMultiLevel(currentCo, tileWidth, tileHeight, decompLevels, waveletFilter, TavNullFilter)
tavApplyDWTInverseMultiLevel(currentCg, tileWidth, tileHeight, decompLevels, waveletFilter, TavNullFilter)
}
tavApplyDWTInverseMultiLevel(currentY, tileWidth, tileHeight, decompLevels, 255, TavSharpenLuma)
tavApplyDWTInverseMultiLevel(currentCo, tileWidth, tileHeight, decompLevels, 255, TavNullFilter)
tavApplyDWTInverseMultiLevel(currentCg, tileWidth, tileHeight, decompLevels, 255, TavNullFilter)
// Debug: Check coefficient values after inverse DWT
if (tavDebugCurrentFrameNumber == tavDebugFrameTarget) {
@@ -5882,6 +6101,194 @@ class GraphicsJSR223Delegate(private val vm: VM) {
}
}
/**
 * Shifts a frame in place by (+dx, +dy) to undo the encoder's forward translation.
 * Every destination pixel (x, y) takes the value found at (x + dx, y + dy); source
 * coordinates that fall outside the frame are clamped to the nearest edge pixel.
 *
 * @param frameData Frame samples in row-major order; overwritten with the shifted frame
 * @param width Frame width
 * @param height Frame height
 * @param dx Translation in X direction (pixels)
 * @param dy Translation in Y direction (pixels)
 */
private fun applyInverseTranslation(frameData: FloatArray, width: Int, height: Int, dx: Int, dy: Int) {
    // Build the result in a scratch buffer so source pixels are not overwritten while still needed
    val shifted = FloatArray(width * height)
    val lastCol = width - 1
    val lastRow = height - 1

    var dst = 0
    for (row in 0 until height) {
        // The source row is the same for the whole destination row
        val srcRowBase = (row + dy).coerceIn(0, lastRow) * width
        for (col in 0 until width) {
            val srcCol = (col + dx).coerceIn(0, lastCol)
            shifted[dst++] = frameData[srcRowBase + srcCol]
        }
    }

    // Copy back to the caller's array
    System.arraycopy(shifted, 0, frameData, 0, frameData.size)
}
/**
 * Main GOP unified decoder function.
 * Decodes a unified 3D DWT GOP block (temporal + spatial) and writes RGB24 frames to VM memory.
 *
 * Pipeline: Zstd decompression, unpacking of the 2-bit coefficient maps, dequantisation with
 * temporal scaling, inverse 3D DWT, inverse motion compensation, YCoCg-R to RGB conversion.
 *
 * Invalid arguments (non-positive [gopSize], negative [compressedSize], or arrays shorter than
 * the GOP) and decompression failures are reported on standard output and yield 0 before any
 * output buffer is touched.
 *
 * @param compressedDataPtr Pointer to compressed Zstd data
 * @param compressedSize Size of compressed data
 * @param gopSize Number of frames in GOP (1-16)
 * @param motionVectorsX X motion vectors in quarter-pixel units; entry 0 is ignored
 * @param motionVectorsY Y motion vectors in quarter-pixel units; entry 0 is ignored
 * @param outputRGBAddrs Array of output RGB buffer addresses, one per frame
 * @param width Frame width
 * @param height Frame height
 * @param qIndex Quality index
 * @param qYGlobal Global Y quantizer
 * @param qCoGlobal Global Co quantizer
 * @param qCgGlobal Global Cg quantizer
 * @param channelLayout Channel layout flags
 * @param spatialFilter Wavelet filter type
 * @param spatialLevels Number of spatial DWT levels (default 6)
 * @param temporalLevels Number of temporal DWT levels (default 2)
 * @return Number of frames decoded, or 0 on failure
 */
fun tavDecodeGopUnified(
    compressedDataPtr: Long,
    compressedSize: Int,
    gopSize: Int,
    motionVectorsX: IntArray,
    motionVectorsY: IntArray,
    outputRGBAddrs: LongArray,
    width: Int,
    height: Int,
    qIndex: Int,
    qYGlobal: Int,
    qCoGlobal: Int,
    qCgGlobal: Int,
    channelLayout: Int,
    spatialFilter: Int = 1,
    spatialLevels: Int = 6,
    temporalLevels: Int = 2
): Int {
    // Step 0: Reject inconsistent arguments up front so that a bad call fails cleanly
    // instead of throwing an index exception after part of the output has been written
    if (gopSize <= 0 || compressedSize < 0) {
        println("ERROR: Invalid GOP parameters: gopSize=$gopSize, compressedSize=$compressedSize")
        return 0
    }
    if (outputRGBAddrs.size < gopSize) {
        println("ERROR: GOP needs $gopSize output buffers but only ${outputRGBAddrs.size} were given")
        return 0
    }
    // Frame 0 is the reference and never reads a motion vector, so a single-frame GOP needs none
    if (gopSize > 1 && (motionVectorsX.size < gopSize || motionVectorsY.size < gopSize)) {
        println("ERROR: GOP needs $gopSize motion vectors but got ${motionVectorsX.size} (X) / ${motionVectorsY.size} (Y)")
        return 0
    }

    val numPixels = width * height

    // Step 1: Decompress unified GOP block
    val compressedData = ByteArray(compressedSize)
    UnsafeHelper.memcpyRaw(
        null,
        vm.usermem.ptr + compressedDataPtr,
        compressedData,
        UnsafeHelper.getArrayOffset(compressedData),
        compressedSize.toLong()
    )
    val decompressedData = try {
        ZstdInputStream(java.io.ByteArrayInputStream(compressedData)).use { zstd ->
            zstd.readBytes()
        }
    } catch (e: Exception) {
        println("ERROR: Zstd decompression failed: ${e.message}")
        return 0
    }

    // Step 2: Postprocess unified block to per-frame coefficients
    val quantizedCoeffs = tavPostprocessGopUnified(
        decompressedData,
        gopSize,
        numPixels,
        channelLayout
    )

    // Step 3: Allocate GOP buffers for float coefficients
    val gopY = Array(gopSize) { FloatArray(numPixels) }
    val gopCo = Array(gopSize) { FloatArray(numPixels) }
    val gopCg = Array(gopSize) { FloatArray(numPixels) }

    // Step 4: Calculate subband layout (needed for perceptual dequantization)
    val subbands = calculateSubbandLayout(width, height, spatialLevels)

    // Step 5: Dequantize with temporal-spatial scaling
    for (t in 0 until gopSize) {
        val temporalLevel = getTemporalSubbandLevel(t, gopSize, temporalLevels)
        val temporalScale = getTemporalQuantizerScale(temporalLevel)

        // Apply temporal scaling to base quantizers, kept within [1, 255]
        val baseQY = (qYGlobal * temporalScale).coerceIn(1.0f, 255.0f)
        val baseQCo = (qCoGlobal * temporalScale).coerceIn(1.0f, 255.0f)
        val baseQCg = (qCgGlobal * temporalScale).coerceIn(1.0f, 255.0f)

        // Use existing perceptual dequantization for spatial weighting
        dequantiseDWTSubbandsPerceptual(
            qIndex, qYGlobal,
            quantizedCoeffs[t][0], gopY[t],
            subbands, baseQY, false, spatialLevels // isChroma=false
        )
        dequantiseDWTSubbandsPerceptual(
            qIndex, qYGlobal,
            quantizedCoeffs[t][1], gopCo[t],
            subbands, baseQCo, true, spatialLevels // isChroma=true
        )
        dequantiseDWTSubbandsPerceptual(
            qIndex, qYGlobal,
            quantizedCoeffs[t][2], gopCg[t],
            subbands, baseQCg, true, spatialLevels // isChroma=true
        )
    }

    // Step 6: Apply inverse 3D DWT (spatial first, then temporal)
    tavApplyInverse3DDWT(gopY, width, height, gopSize, spatialLevels, temporalLevels, spatialFilter)
    tavApplyInverse3DDWT(gopCo, width, height, gopSize, spatialLevels, temporalLevels, spatialFilter)
    tavApplyInverse3DDWT(gopCg, width, height, gopSize, spatialLevels, temporalLevels, spatialFilter)

    // Step 7: Apply inverse motion compensation (shift frames back)
    // Note: Motion vectors are in quarter-pixel units; the integer division below
    // truncates toward zero, so only whole-pixel shifts are applied
    for (t in 1 until gopSize) { // Skip frame 0 (reference)
        val dx = motionVectorsX[t] / 4 // Convert to pixel units
        val dy = motionVectorsY[t] / 4
        if (dx != 0 || dy != 0) {
            applyInverseTranslation(gopY[t], width, height, dx, dy)
            applyInverseTranslation(gopCo[t], width, height, dx, dy)
            applyInverseTranslation(gopCg[t], width, height, dx, dy)
        }
    }

    // Step 8: Convert each frame to RGB and write to output buffers
    for (t in 0 until gopSize) {
        val rgbAddr = outputRGBAddrs[t]
        val frameY = gopY[t]
        val frameCo = gopCo[t]
        val frameCg = gopCg[t]

        for (i in 0 until numPixels) {
            val y = frameY[i]
            val co = frameCo[i]
            val cg = frameCg[i]

            // YCoCg-R to RGB conversion
            val tmp = y - (cg / 2.0f)
            val g = cg + tmp
            val b = tmp - (co / 2.0f)
            val r = b + co

            // Truncate to integer, then clamp to 0-255 range
            val rClamped = r.toInt().coerceIn(0, 255)
            val gClamped = g.toInt().coerceIn(0, 255)
            val bClamped = b.toInt().coerceIn(0, 255)

            // Write RGB24 format (3 bytes per pixel)
            val offset = rgbAddr + i * 3L
            vm.usermem[offset] = rClamped.toByte()
            vm.usermem[offset + 1] = gClamped.toByte()
            vm.usermem[offset + 2] = bClamped.toByte()
        }
    }

    return gopSize
}
// Biorthogonal 13/7 wavelet inverse 1D transform
// Synthesis filters: Low-pass (13 taps), High-pass (7 taps)
private fun tavApplyDWTBior137Inverse1D(data: FloatArray, length: Int) {
@@ -5994,4 +6401,78 @@ class GraphicsJSR223Delegate(private val vm: VM) {
}
}
// =============================================================================
// Temporal 3D DWT Functions (GOP Decoding)
// =============================================================================
/**
 * Inverse 1D temporal DWT along the time axis.
 * Delegates to the existing Haar inverse transform; lines of fewer than two
 * frames are left untouched.
 *
 * @param data Temporal coefficients of one spatial location, transformed in place
 * @param numFrames Number of leading entries of [data] to transform
 */
private fun tavApplyTemporalDWTInverse1D(data: FloatArray, numFrames: Int) {
    if (numFrames >= 2) {
        tavApplyDWTHaarInverse1D(data, numFrames)
    }
}
/**
 * Apply inverse 3D DWT to GOP data (spatial + temporal), in place.
 * Order: SPATIAL first (each frame), then TEMPORAL (across frames).
 *
 * The spatial pass runs for every frame, including a single-frame GOP; only the temporal
 * pass is skipped when there are fewer than two frames, because there is nothing to undo
 * along the time axis.
 *
 * @param gopData Array of frame buffers [frame][pixel]
 * @param width Frame width
 * @param height Frame height
 * @param numFrames Number of frames in GOP
 * @param spatialLevels Spatial decomposition levels (typically 6)
 * @param temporalLevels Temporal decomposition levels (typically 2)
 * @param spatialFilter Spatial wavelet filter type (0=5/3, 1=9/7, 255=Haar)
 */
private fun tavApplyInverse3DDWT(
    gopData: Array<FloatArray>,
    width: Int,
    height: Int,
    numFrames: Int,
    spatialLevels: Int,
    temporalLevels: Int,
    spatialFilter: Int
) {
    if (numFrames < 1) return

    // Step 1: Apply inverse 2D spatial DWT to each temporal subband (each frame)
    for (t in 0 until numFrames) {
        tavApplyDWTInverseMultiLevel(
            gopData[t], width, height,
            spatialLevels, spatialFilter,
            TavNullFilter // No sharpening for GOP frames
        )
    }

    // A single frame has no temporal decomposition to invert
    if (numFrames < 2) return

    // Step 2: Apply inverse temporal DWT to each spatial location
    val numPixels = width * height
    val temporalLine = FloatArray(numFrames)
    for (pixelIdx in 0 until numPixels) {
        // Extract temporal coefficients for this spatial location
        for (t in 0 until numFrames) {
            temporalLine[t] = gopData[t][pixelIdx]
        }

        // Undo the temporal levels from the coarsest (shortest span) to the finest (full GOP)
        for (level in temporalLevels - 1 downTo 0) {
            val levelFrames = numFrames shr level
            if (levelFrames >= 2) {
                tavApplyTemporalDWTInverse1D(temporalLine, levelFrames)
            }
        }

        // Write back reconstructed values
        for (t in 0 until numFrames) {
            gopData[t][pixelIdx] = temporalLine[t]
        }
    }
}
}