wavelet deblocking using simulated overlapping tiles

This commit is contained in:
minjaesong
2025-09-16 10:03:17 +09:00
parent 54f335e3de
commit a5da200507
2 changed files with 459 additions and 62 deletions

View File

@@ -16,6 +16,11 @@ import kotlin.math.*
class GraphicsJSR223Delegate(private val vm: VM) {
// TAV Simulated overlapping tiles constants (must match encoder)
private val TILE_SIZE = 112
private val TILE_MARGIN = 32 // 32-pixel margin for 3 DWT levels (4 * 2^3 = 32px)
private val PADDED_TILE_SIZE = TILE_SIZE + 2 * TILE_MARGIN // 112 + 64 = 176px
// Reusable working arrays to reduce allocation overhead
private val idct8TempBuffer = FloatArray(64)
private val idct16TempBuffer = FloatArray(256) // For 16x16 IDCT
@@ -3978,62 +3983,78 @@ class GraphicsJSR223Delegate(private val vm: VM) {
println("TAV decode error: ${e.message}")
}
// Apply deblocking filter if enabled to reduce DWT quantization artifacts
// if (enableDeblocking) {
// tavDeblockingFilter(currentRGBAddr, width, height)
// tavAdaptiveDeblockingFilter(currentRGBAddr, width, height)
// }
}
private fun decodeDWTIntraTileRGB(readPtr: Long, tileX: Int, tileY: Int, currentRGBAddr: Long,
width: Int, height: Int, qY: Int, qCo: Int, qCg: Int, rcf: Float,
waveletFilter: Int, decompLevels: Int, isLossless: Boolean, tavVersion: Int): Long {
val tileSize = 112
val coeffCount = tileSize * tileSize
// Now reading padded coefficient tiles (176x176) instead of core tiles (112x112)
val paddedSize = PADDED_TILE_SIZE
val paddedCoeffCount = paddedSize * paddedSize
var ptr = readPtr
// Read quantized DWT coefficients for Y, Co, Cg channels
val quantizedY = ShortArray(coeffCount)
val quantizedCo = ShortArray(coeffCount)
val quantizedCg = ShortArray(coeffCount)
// Read quantized DWT coefficients for padded tile Y, Co, Cg channels (176x176)
val quantizedY = ShortArray(paddedCoeffCount)
val quantizedCo = ShortArray(paddedCoeffCount)
val quantizedCg = ShortArray(paddedCoeffCount)
// Read Y coefficients
for (i in 0 until coeffCount) {
// Read Y coefficients (176x176)
for (i in 0 until paddedCoeffCount) {
quantizedY[i] = vm.peekShort(ptr)
ptr += 2
}
// Read Co coefficients
for (i in 0 until coeffCount) {
// Read Co coefficients (176x176)
for (i in 0 until paddedCoeffCount) {
quantizedCo[i] = vm.peekShort(ptr)
ptr += 2
}
// Read Cg coefficients
for (i in 0 until coeffCount) {
// Read Cg coefficients (176x176)
for (i in 0 until paddedCoeffCount) {
quantizedCg[i] = vm.peekShort(ptr)
ptr += 2
}
// Dequantize and apply inverse DWT
val yTile = FloatArray(coeffCount)
val coTile = FloatArray(coeffCount)
val cgTile = FloatArray(coeffCount)
// Dequantize padded coefficient tiles (176x176)
val yPaddedTile = FloatArray(paddedCoeffCount)
val coPaddedTile = FloatArray(paddedCoeffCount)
val cgPaddedTile = FloatArray(paddedCoeffCount)
for (i in 0 until coeffCount) {
yTile[i] = quantizedY[i] * qY * rcf
coTile[i] = quantizedCo[i] * qCo * rcf
cgTile[i] = quantizedCg[i] * qCg * rcf
for (i in 0 until paddedCoeffCount) {
yPaddedTile[i] = quantizedY[i] * qY * rcf
coPaddedTile[i] = quantizedCo[i] * qCo * rcf
cgPaddedTile[i] = quantizedCg[i] * qCg * rcf
}
// Apply inverse DWT using specified filter with decomposition levels
// Apply inverse DWT on full padded tiles (176x176)
if (isLossless) {
applyDWTInverseMultiLevel(yTile, tileSize, tileSize, decompLevels, 0)
applyDWTInverseMultiLevel(coTile, tileSize, tileSize, decompLevels, 0)
applyDWTInverseMultiLevel(cgTile, tileSize, tileSize, decompLevels, 0)
applyDWTInverseMultiLevel(yPaddedTile, paddedSize, paddedSize, decompLevels, 0)
applyDWTInverseMultiLevel(coPaddedTile, paddedSize, paddedSize, decompLevels, 0)
applyDWTInverseMultiLevel(cgPaddedTile, paddedSize, paddedSize, decompLevels, 0)
} else {
applyDWTInverseMultiLevel(yTile, tileSize, tileSize, decompLevels, waveletFilter)
applyDWTInverseMultiLevel(coTile, tileSize, tileSize, decompLevels, waveletFilter)
applyDWTInverseMultiLevel(cgTile, tileSize, tileSize, decompLevels, waveletFilter)
applyDWTInverseMultiLevel(yPaddedTile, paddedSize, paddedSize, decompLevels, waveletFilter)
applyDWTInverseMultiLevel(coPaddedTile, paddedSize, paddedSize, decompLevels, waveletFilter)
applyDWTInverseMultiLevel(cgPaddedTile, paddedSize, paddedSize, decompLevels, waveletFilter)
}
// Extract core 112x112 pixels from reconstructed padded tiles (176x176)
val yTile = FloatArray(TILE_SIZE * TILE_SIZE)
val coTile = FloatArray(TILE_SIZE * TILE_SIZE)
val cgTile = FloatArray(TILE_SIZE * TILE_SIZE)
for (y in 0 until TILE_SIZE) {
for (x in 0 until TILE_SIZE) {
val coreIdx = y * TILE_SIZE + x
val paddedIdx = (y + TILE_MARGIN) * paddedSize + (x + TILE_MARGIN)
yTile[coreIdx] = yPaddedTile[paddedIdx]
coTile[coreIdx] = coPaddedTile[paddedIdx]
cgTile[coreIdx] = cgPaddedTile[paddedIdx]
}
}
// Convert to RGB based on TAV version (YCoCg-R for v1, ICtCp for v2)
@@ -4326,6 +4347,14 @@ class GraphicsJSR223Delegate(private val vm: VM) {
// Lifting scheme implementation for 9/7 irreversible filter
}
private fun generateWindowFunction(window: FloatArray, size: Int) {
// Raised cosine (Hann) window for smooth blending
for (i in 0 until size) {
val t = i.toFloat() / (size - 1)
window[i] = 0.5f * (1.0f - kotlin.math.cos(PI * t))
}
}
private fun applyDWTInverseMultiLevel(data: FloatArray, width: Int, height: Int, levels: Int, filterType: Int) {
// Multi-level inverse DWT - reconstruct from smallest to largest (reverse of encoder)
val size = width // Full tile size (112 for TAV)
@@ -4602,12 +4631,302 @@ class GraphicsJSR223Delegate(private val vm: VM) {
if (half + idx < length) {
data[i] = temp[half + idx]
} else {
data[i] = 0.0f // Boundary case
// Symmetric extension: mirror the last available high-pass coefficient
val lastHighIdx = (length / 2) - 1
if (lastHighIdx >= 0 && half + lastHighIdx < length) {
data[i] = temp[half + lastHighIdx]
} else {
data[i] = 0.0f
}
}
}
}
}
private fun tavAdaptiveDeblockingFilter(rgbAddr: Long, width: Int, height: Int) {
val tileSize = 112
val tilesX = (width + tileSize - 1) / tileSize
val tilesY = (height + tileSize - 1) / tileSize
// Process vertical seams (between horizontally adjacent tiles)
for (tileY in 0 until tilesY) {
for (tileX in 0 until tilesX - 1) {
val seamX = (tileX + 1) * tileSize // Actual boundary between tiles
deblockVerticalSeamStrong(rgbAddr, width, height, seamX, tileY * tileSize, tileSize)
}
}
// Process horizontal seams (between vertically adjacent tiles)
for (tileY in 0 until tilesY - 1) {
for (tileX in 0 until tilesX) {
val seamY = (tileY + 1) * tileSize // Actual boundary between tiles
deblockHorizontalSeamStrong(rgbAddr, width, height, tileX * tileSize, seamY, tileSize)
}
}
}
private fun deblockVerticalSeamStrong(rgbAddr: Long, width: Int, height: Int, seamX: Int, startY: Int, tileHeight: Int) {
if (seamX >= width) return
val endY = minOf(startY + tileHeight, height)
for (y in startY until endY) {
if (y >= height) break
// Check for discontinuity across the seam
val leftX = seamX - 1
val rightX = seamX
if (leftX >= 0 && rightX < width) {
val leftOffset = (y * width + leftX) * 3L
val rightOffset = (y * width + rightX) * 3L
val leftR = vm.peek(rgbAddr + leftOffset).toInt() and 0xFF
val leftG = vm.peek(rgbAddr + leftOffset + 1).toInt() and 0xFF
val leftB = vm.peek(rgbAddr + leftOffset + 2).toInt() and 0xFF
val rightR = vm.peek(rgbAddr + rightOffset).toInt() and 0xFF
val rightG = vm.peek(rgbAddr + rightOffset + 1).toInt() and 0xFF
val rightB = vm.peek(rgbAddr + rightOffset + 2).toInt() and 0xFF
// Calculate discontinuity strength
val diffR = abs(leftR - rightR)
val diffG = abs(leftG - rightG)
val diffB = abs(leftB - rightB)
val maxDiff = maxOf(diffR, diffG, diffB)
// Only apply deblocking if there's a significant discontinuity
if (maxDiff in 2 until 120) {
// Adaptive filter radius: wider for smooth gradients, narrower for sharp edges
val filterRadius = when {
maxDiff <= 15 -> 6 // Very smooth gradients: wide filter (13 pixels)
maxDiff <= 30 -> 4 // Moderate gradients: medium filter (9 pixels)
maxDiff <= 60 -> 3 // Sharp transitions: narrow filter (7 pixels)
else -> 2 // Very sharp edges: minimal filter (5 pixels)
}
for (dx in -filterRadius..filterRadius) {
val x = seamX + dx
if (x in 0 until width) {
val offset = (y * width + x) * 3L
val currentR = vm.peek(rgbAddr + offset).toInt() and 0xFF
val currentG = vm.peek(rgbAddr + offset + 1).toInt() and 0xFF
val currentB = vm.peek(rgbAddr + offset + 2).toInt() and 0xFF
var sumR = 0.0f
var sumG = 0.0f
var sumB = 0.0f
var weightSum = 0.0f
// Bilateral filtering with spatial and intensity weights
for (sx in maxOf(0, x-filterRadius)..minOf(width-1, x+filterRadius)) {
val sOffset = (y * width + sx) * 3L
val sR = vm.peek(rgbAddr + sOffset).toInt() and 0xFF
val sG = vm.peek(rgbAddr + sOffset + 1).toInt() and 0xFF
val sB = vm.peek(rgbAddr + sOffset + 2).toInt() and 0xFF
// Spatial weight (distance from current pixel)
val spatialWeight = 1.0f / (1.0f + abs(sx - x))
// Intensity weight (color similarity)
val colorDiff = sqrt(((sR - currentR) * (sR - currentR) +
(sG - currentG) * (sG - currentG) +
(sB - currentB) * (sB - currentB)).toFloat())
val intensityWeight = exp(-colorDiff / 30.0f)
val totalWeight = spatialWeight * intensityWeight
sumR += sR * totalWeight
sumG += sG * totalWeight
sumB += sB * totalWeight
weightSum += totalWeight
}
if (weightSum > 0) {
val filteredR = (sumR / weightSum).toInt()
val filteredG = (sumG / weightSum).toInt()
val filteredB = (sumB / weightSum).toInt()
// Concentrate blur heavily at the seam boundary
val distance = abs(dx).toFloat()
val blendWeight = when {
distance == 0.0f -> 0.95f // Maximum blur at exact seam
distance == 1.0f -> 0.8f // Strong blur adjacent to seam
distance == 2.0f -> 0.5f // Medium blur 2 pixels away
else -> exp(-distance * distance / 1.5f) * 0.3f // Gentle falloff beyond
}
val finalR = (currentR * (1 - blendWeight) + filteredR * blendWeight).toInt().coerceIn(0, 255)
val finalG = (currentG * (1 - blendWeight) + filteredG * blendWeight).toInt().coerceIn(0, 255)
val finalB = (currentB * (1 - blendWeight) + filteredB * blendWeight).toInt().coerceIn(0, 255)
vm.poke(rgbAddr + offset, finalR.toByte())
vm.poke(rgbAddr + offset + 1, finalG.toByte())
vm.poke(rgbAddr + offset + 2, finalB.toByte())
}
}
}
}
}
}
}
private fun deblockHorizontalSeamStrong(rgbAddr: Long, width: Int, height: Int, startX: Int, seamY: Int, tileWidth: Int) {
if (seamY >= height) return
val endX = minOf(startX + tileWidth, width)
for (x in startX until endX) {
if (x >= width) break
// Check for discontinuity across the seam
val topY = seamY - 1
val bottomY = seamY
if (topY >= 0 && bottomY < height) {
val topOffset = (topY * width + x) * 3L
val bottomOffset = (bottomY * width + x) * 3L
val topR = vm.peek(rgbAddr + topOffset).toInt() and 0xFF
val topG = vm.peek(rgbAddr + topOffset + 1).toInt() and 0xFF
val topB = vm.peek(rgbAddr + topOffset + 2).toInt() and 0xFF
val bottomR = vm.peek(rgbAddr + bottomOffset).toInt() and 0xFF
val bottomG = vm.peek(rgbAddr + bottomOffset + 1).toInt() and 0xFF
val bottomB = vm.peek(rgbAddr + bottomOffset + 2).toInt() and 0xFF
// Calculate discontinuity strength
val diffR = abs(topR - bottomR)
val diffG = abs(topG - bottomG)
val diffB = abs(topB - bottomB)
val maxDiff = maxOf(diffR, diffG, diffB)
// Only apply deblocking if there's a significant discontinuity
if (maxDiff in 2 until 120) {
// Adaptive filter radius: wider for smooth gradients, narrower for sharp edges
val filterRadius = when {
maxDiff <= 15 -> 6 // Very smooth gradients: wide filter (13 pixels)
maxDiff <= 30 -> 4 // Moderate gradients: medium filter (9 pixels)
maxDiff <= 60 -> 3 // Sharp transitions: narrow filter (7 pixels)
else -> 2 // Very sharp edges: minimal filter (5 pixels)
}
for (dy in -filterRadius..filterRadius) {
val y = seamY + dy
if (y in 0 until height) {
val offset = (y * width + x) * 3L
val currentR = vm.peek(rgbAddr + offset).toInt() and 0xFF
val currentG = vm.peek(rgbAddr + offset + 1).toInt() and 0xFF
val currentB = vm.peek(rgbAddr + offset + 2).toInt() and 0xFF
var sumR = 0.0f
var sumG = 0.0f
var sumB = 0.0f
var weightSum = 0.0f
// Bilateral filtering with spatial and intensity weights
for (sy in maxOf(0, y-filterRadius)..minOf(height-1, y+filterRadius)) {
val sOffset = (sy * width + x) * 3L
val sR = vm.peek(rgbAddr + sOffset).toInt() and 0xFF
val sG = vm.peek(rgbAddr + sOffset + 1).toInt() and 0xFF
val sB = vm.peek(rgbAddr + sOffset + 2).toInt() and 0xFF
// Spatial weight (distance from current pixel)
val spatialWeight = 1.0f / (1.0f + abs(sy - y))
// Intensity weight (color similarity)
val colorDiff = sqrt(((sR - currentR) * (sR - currentR) +
(sG - currentG) * (sG - currentG) +
(sB - currentB) * (sB - currentB)).toFloat())
val intensityWeight = exp(-colorDiff / 30.0f)
val totalWeight = spatialWeight * intensityWeight
sumR += sR * totalWeight
sumG += sG * totalWeight
sumB += sB * totalWeight
weightSum += totalWeight
}
if (weightSum > 0) {
val filteredR = (sumR / weightSum).toInt()
val filteredG = (sumG / weightSum).toInt()
val filteredB = (sumB / weightSum).toInt()
// Concentrate blur heavily at the seam boundary
val distance = abs(dy).toFloat()
val blendWeight = when {
distance == 0.0f -> 0.95f // Maximum blur at exact seam
distance == 1.0f -> 0.8f // Strong blur adjacent to seam
distance == 2.0f -> 0.5f // Medium blur 2 pixels away
else -> exp(-distance * distance / 1.5f) * 0.3f // Gentle falloff beyond
}
val finalR = (currentR * (1 - blendWeight) + filteredR * blendWeight).toInt().coerceIn(0, 255)
val finalG = (currentG * (1 - blendWeight) + filteredG * blendWeight).toInt().coerceIn(0, 255)
val finalB = (currentB * (1 - blendWeight) + filteredB * blendWeight).toInt().coerceIn(0, 255)
vm.poke(rgbAddr + offset, finalR.toByte())
vm.poke(rgbAddr + offset + 1, finalG.toByte())
vm.poke(rgbAddr + offset + 2, finalB.toByte())
}
}
}
}
}
}
}
private fun analyzeTextureComplexity(rgbAddr: Long, width: Int, height: Int, centerX: Int, centerY: Int, isVerticalSeam: Boolean): Float {
val radius = 4
var totalVariance = 0.0f
var count = 0
// Calculate variance in a small window around the seam
for (dy in -radius..radius) {
for (dx in -radius..radius) {
val x = centerX + dx
val y = centerY + dy
if (x >= 0 && x < width && y >= 0 && y < height) {
val offset = (y * width + x) * 3L
val r = vm.peek(rgbAddr + offset).toInt() and 0xFF
val g = vm.peek(rgbAddr + offset + 1).toInt() and 0xFF
val b = vm.peek(rgbAddr + offset + 2).toInt() and 0xFF
val luma = 0.299f * r + 0.587f * g + 0.114f * b
// Compare with adjacent pixels to measure local variance
if (x > 0) {
val leftOffset = (y * width + (x-1)) * 3L
val leftR = vm.peek(rgbAddr + leftOffset).toInt() and 0xFF
val leftG = vm.peek(rgbAddr + leftOffset + 1).toInt() and 0xFF
val leftB = vm.peek(rgbAddr + leftOffset + 2).toInt() and 0xFF
val leftLuma = 0.299f * leftR + 0.587f * leftG + 0.114f * leftB
totalVariance += abs(luma - leftLuma)
count++
}
if (y > 0) {
val topOffset = ((y-1) * width + x) * 3L
val topR = vm.peek(rgbAddr + topOffset).toInt() and 0xFF
val topG = vm.peek(rgbAddr + topOffset + 1).toInt() and 0xFF
val topB = vm.peek(rgbAddr + topOffset + 2).toInt() and 0xFF
val topLuma = 0.299f * topR + 0.587f * topG + 0.114f * topB
totalVariance += abs(luma - topLuma)
count++
}
}
}
}
return if (count > 0) totalVariance / count else 0.0f
}
private fun bilinearInterpolate(
dataPtr: Long, width: Int, height: Int,
x: Float, y: Float