tav: grain synthesis on the spec

This commit is contained in:
minjaesong
2025-10-08 23:47:54 +09:00
parent 17b5063ef0
commit 1a072f6a0c
4 changed files with 212 additions and 61 deletions

View File

@@ -4446,12 +4446,83 @@ class GraphicsJSR223Delegate(private val vm: VM) {
private val tavDebugFrameTarget = -1 // use negative number to disable the debug print
private var tavDebugCurrentFrameNumber = 0
// ==============================================================================
// Grain Synthesis Functions (must match encoder implementation)
// ==============================================================================
// Stateless RNG for grain synthesis (matches C encoder implementation)
private inline fun tavGrainSynthesisRNG(frame: UInt, band: UInt, x: UInt, y: UInt): UInt {
val key = frame * 0x9e3779b9u xor band * 0x7f4a7c15u xor (y shl 16) xor x
// rng_hash implementation
var hash = key
hash = hash xor (hash shr 16)
hash = hash * 0x7feb352du
hash = hash xor (hash shr 15)
hash = hash * 0x846ca68bu
hash = hash xor (hash shr 16)
return hash
}
// Generate triangular noise from uint32 RNG (returns value in range [-1.0, 1.0])
private inline fun tavGrainTriangularNoise(rngVal: UInt): Float {
// Get two uniform random values in [0, 1]
val u1 = (rngVal and 0xFFFFu).toFloat() / 65535.0f
val u2 = ((rngVal shr 16) and 0xFFFFu).toFloat() / 65535.0f
// Convert to range [-1, 1] and average for triangular distribution
return (u1 + u2) - 1.0f
}
// Remove grain synthesis from DWT coefficients (decoder subtracts noise)
// This must be called AFTER dequantization but BEFORE inverse DWT
private fun removeGrainSynthesisDecoder(coeffs: FloatArray, width: Int, height: Int,
decompLevels: Int, frameNum: Int, quantiser: Float,
subbands: List<DWTSubbandInfo>, qIndex: Int = 3, qYGlobal: Int = 0,
usePerceptualWeights: Boolean = false) {
// Only apply to Y channel, excluding LL band
// Noise amplitude = half of quantization step (scaled by perceptual weight if enabled)
// Process each subband (skip LL which is level 0)
for (subband in subbands) {
if (subband.level == 0) continue // Skip LL band
// Calculate perceptual weight for this subband if perceptual mode is enabled
/*val perceptualWeight = if (usePerceptualWeights) {
getPerceptualWeight(qIndex, qYGlobal, subband.level, subband.subbandType, false, decompLevels)
} else {
1.0f
}
// Noise amplitude for this subband
val noiseAmplitude = (quantiser * perceptualWeight) * 0.5f*/
val noiseAmplitude = quantiser.coerceAtMost(32f) * 0.5f
// Remove noise from each coefficient in this subband
for (i in 0 until subband.coeffCount) {
val idx = subband.coeffStart + i
if (idx < coeffs.size) {
// Calculate 2D position from linear index
val y = idx / width
val x = idx % width
// Generate same deterministic noise as encoder
val rngVal = tavGrainSynthesisRNG(frameNum.toUInt(), (subband.level + subband.subbandType * 31 + 16777619).toUInt(), x.toUInt(), y.toUInt())
val noise = tavGrainTriangularNoise(rngVal)
// Subtract noise from coefficient
coeffs[idx] -= noise * noiseAmplitude
}
}
}
}
private val TAV_QLUT = intArrayOf(1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,66,68,70,72,74,76,78,80,82,84,86,88,90,92,94,96,98,100,102,104,106,108,110,112,114,116,118,120,122,124,126,128,132,136,140,144,148,152,156,160,164,168,172,176,180,184,188,192,196,200,204,208,212,216,220,224,228,232,236,240,244,248,252,256,264,272,280,288,296,304,312,320,328,336,344,352,360,368,376,384,392,400,408,416,424,432,440,448,456,464,472,480,488,496,504,512,528,544,560,576,592,608,624,640,656,672,688,704,720,736,752,768,784,800,816,832,848,864,880,896,912,928,944,960,976,992,1008,1024,1056,1088,1120,1152,1184,1216,1248,1280,1312,1344,1376,1408,1440,1472,1504,1536,1568,1600,1632,1664,1696,1728,1760,1792,1824,1856,1888,1920,1952,1984,2016,2048,2112,2176,2240,2304,2368,2432,2496,2560,2624,2688,2752,2816,2880,2944,3008,3072,3136,3200,3264,3328,3392,3456,3520,3584,3648,3712,3776,3840,3904,3968,4032,4096)
// New tavDecode function that accepts compressed data and decompresses internally
fun tavDecodeCompressed(compressedDataPtr: Long, compressedSize: Int, currentRGBAddr: Long, prevRGBAddr: Long,
width: Int, height: Int, qIndex: Int, qYGlobal: Int, qCoGlobal: Int, qCgGlobal: Int, channelLayout: Int,
frameCount: Int, waveletFilter: Int = 1, decompLevels: Int = 6, isLossless: Boolean = false, tavVersion: Int = 1, filmGrainLevel: Int = 0): HashMap<String, Any> {
frameCount: Int, waveletFilter: Int = 1, decompLevels: Int = 6, isLossless: Boolean = false, tavVersion: Int = 1): HashMap<String, Any> {
// Read compressed data from VM memory into byte array
val compressedData = ByteArray(compressedSize)
@@ -4481,7 +4552,7 @@ class GraphicsJSR223Delegate(private val vm: VM) {
// Call the existing tavDecode function with decompressed data
tavDecode(decompressedBuffer.toLong(), currentRGBAddr, prevRGBAddr,
width, height, qIndex, qYGlobal, qCoGlobal, qCgGlobal, channelLayout,
frameCount, waveletFilter, decompLevels, isLossless, tavVersion, filmGrainLevel)
frameCount, waveletFilter, decompLevels, isLossless, tavVersion)
} finally {
// Clean up allocated buffer
@@ -4497,7 +4568,7 @@ class GraphicsJSR223Delegate(private val vm: VM) {
// Original tavDecode function for backward compatibility (now handles decompressed data)
fun tavDecode(blockDataPtr: Long, currentRGBAddr: Long, prevRGBAddr: Long,
width: Int, height: Int, qIndex: Int, qYGlobal: Int, qCoGlobal: Int, qCgGlobal: Int, channelLayout: Int,
frameCount: Int, waveletFilter: Int = 1, decompLevels: Int = 6, isLossless: Boolean = false, tavVersion: Int = 1, filmGrainLevel: Int = 0): HashMap<String, Any> {
frameCount: Int, waveletFilter: Int = 1, decompLevels: Int = 6, isLossless: Boolean = false, tavVersion: Int = 1): HashMap<String, Any> {
val dbgOut = HashMap<String, Any>()
@@ -4554,14 +4625,14 @@ class GraphicsJSR223Delegate(private val vm: VM) {
// Decode DWT coefficients directly to RGB buffer
readPtr = tavDecodeDWTIntraTileRGB(qIndex, qYGlobal, channelLayout, readPtr, tileX, tileY, currentRGBAddr,
width, height, qY, qCo, qCg,
waveletFilter, decompLevels, isLossless, tavVersion, isMonoblock, filmGrainLevel)
waveletFilter, decompLevels, isLossless, tavVersion, isMonoblock, frameCount)
dbgOut["frameMode"] = " "
}
0x02 -> { // TAV_MODE_DELTA
// Coefficient delta encoding for efficient P-frames
readPtr = tavDecodeDeltaTileRGB(readPtr, channelLayout, tileX, tileY, currentRGBAddr,
width, height, qY, qCo, qCg,
waveletFilter, decompLevels, isLossless, tavVersion, isMonoblock, filmGrainLevel)
waveletFilter, decompLevels, isLossless, tavVersion, isMonoblock, frameCount)
dbgOut["frameMode"] = " "
}
}
@@ -4577,7 +4648,7 @@ class GraphicsJSR223Delegate(private val vm: VM) {
private fun tavDecodeDWTIntraTileRGB(qIndex: Int, qYGlobal: Int, channelLayout: Int, readPtr: Long, tileX: Int, tileY: Int, currentRGBAddr: Long,
width: Int, height: Int, qY: Int, qCo: Int, qCg: Int,
waveletFilter: Int, decompLevels: Int, isLossless: Boolean, tavVersion: Int, isMonoblock: Boolean = false, filmGrainLevel: Int = 0): Long {
waveletFilter: Int, decompLevels: Int, isLossless: Boolean, tavVersion: Int, isMonoblock: Boolean = false, frameCount: Int): Long {
// Determine coefficient count based on mode
val coeffCount = if (isMonoblock) {
// Monoblock mode: entire frame
@@ -4678,15 +4749,20 @@ class GraphicsJSR223Delegate(private val vm: VM) {
dequantiseDWTSubbandsPerceptual(qIndex, qYGlobal, quantisedCo, coTile, subbands, qCo.toFloat(), true, decompLevels)
dequantiseDWTSubbandsPerceptual(qIndex, qYGlobal, quantisedCg, cgTile, subbands, qCg.toFloat(), true, decompLevels)
// Remove grain synthesis from Y channel (must happen after dequantization, before inverse DWT)
// Use perceptual weights since this is the perceptual quantization path
removeGrainSynthesisDecoder(yTile, tileWidth, tileHeight, decompLevels, frameCount, qY.toFloat(), subbands, qIndex, qYGlobal, true)
// Apply film grain filter if enabled
if (filmGrainLevel > 0) {
// commented; grain synthesis is now a part of the spec
/*if (filmGrainLevel > 0) {
val random = java.util.Random()
for (i in 0 until coeffCount) {
yTile[i] += (random.nextInt(filmGrainLevel * 2 + 1) - filmGrainLevel).toFloat()
// coTile[i] += (random.nextInt(filmGrainLevel * 2 + 1) - filmGrainLevel).toFloat()
// cgTile[i] += (random.nextInt(filmGrainLevel * 2 + 1) - filmGrainLevel).toFloat()
}
}
}*/
// Debug: Check coefficient values before inverse DWT
if (tavDebugCurrentFrameNumber == tavDebugFrameTarget) {
@@ -4744,15 +4820,22 @@ class GraphicsJSR223Delegate(private val vm: VM) {
cgTile[i] = quantisedCg[i] * qCg.toFloat()
}
// Remove grain synthesis from Y channel (must happen after dequantization, before inverse DWT)
val tileWidth = if (isMonoblock) width else TAV_PADDED_TILE_SIZE_X
val tileHeight = if (isMonoblock) height else TAV_PADDED_TILE_SIZE_Y
val subbands = calculateSubbandLayout(tileWidth, tileHeight, decompLevels)
removeGrainSynthesisDecoder(yTile, tileWidth, tileHeight, decompLevels, frameCount, qY.toFloat(), subbands)
// Apply film grain filter if enabled
if (filmGrainLevel > 0) {
// commented; grain synthesis is now a part of the spec
/*if (filmGrainLevel > 0) {
val random = java.util.Random()
for (i in 0 until coeffCount) {
yTile[i] += (random.nextInt(filmGrainLevel * 2 + 1) - filmGrainLevel).toFloat()
// coTile[i] += (random.nextInt(filmGrainLevel * 2 + 1) - filmGrainLevel).toFloat()
// cgTile[i] += (random.nextInt(filmGrainLevel * 2 + 1) - filmGrainLevel).toFloat()
}
}
}*/
// Debug: Uniform quantisation subband analysis for comparison
if (tavDebugCurrentFrameNumber == tavDebugFrameTarget) {
@@ -5160,48 +5243,6 @@ class GraphicsJSR223Delegate(private val vm: VM) {
}
}
// Delta-specific perceptual weight model for motion-optimized coefficient reconstruction
private fun getPerceptualWeightDelta(qualityLevel: Int, level: Int, subbandType: Int, isChroma: Boolean, maxLevels: Int): Float {
// Delta coefficients have different perceptual characteristics than full-picture coefficients:
// 1. Motion edges are more perceptually critical than static edges
// 2. Temporal masking allows more aggressive quantisation in high-motion areas
// 3. Smaller delta magnitudes make relative quantisation errors more visible
// 4. Frequency distribution is motion-dependent rather than spatial-dependent
return if (!isChroma) {
// LUMA DELTA CHANNEL: Emphasize motion coherence and edge preservation
when (subbandType) {
0 -> { // LL subband - DC motion changes, still important
// DC motion changes - preserve somewhat but allow coarser quantisation than full-picture
2f // Slightly coarser than full-picture
}
1 -> { // LH subband - horizontal motion edges
// Motion boundaries benefit from temporal masking - allow coarser quantisation
0.9f
}
2 -> { // HL subband - vertical motion edges
// Vertical motion boundaries - equal treatment with horizontal for deltas
1.2f
}
else -> { // HH subband - diagonal motion details
// Diagonal motion deltas can be quantised most aggressively
0.5f
}
}
} else {
// CHROMA DELTA CHANNELS: More aggressive quantisation allowed due to temporal masking
// Motion chroma changes are less perceptually critical than static chroma
val base = getPerceptualModelChromaBase(qualityLevel, level - 1)
when (subbandType) {
0 -> 1.3f // LL chroma deltas - more aggressive than full-picture chroma
1 -> kotlin.math.max(1.2f, kotlin.math.min(120.0f, base * 1.4f)) // LH chroma deltas
2 -> kotlin.math.max(1.4f, kotlin.math.min(140.0f, base * 1.6f)) // HL chroma deltas
else -> kotlin.math.max(1.6f, kotlin.math.min(160.0f, base * 1.8f)) // HH chroma deltas
}
}
}
private fun getPerceptualModelChromaBase(qualityLevel: Int, level: Int): Float {
// Simplified chroma base curve
return 1.0f - (1.0f / (0.5f * qualityLevel * qualityLevel + 1.0f)) * (level - 4.0f)
@@ -5209,7 +5250,7 @@ class GraphicsJSR223Delegate(private val vm: VM) {
private fun tavDecodeDeltaTileRGB(readPtr: Long, channelLayout: Int, tileX: Int, tileY: Int, currentRGBAddr: Long,
width: Int, height: Int, qY: Int, qCo: Int, qCg: Int,
waveletFilter: Int, decompLevels: Int, isLossless: Boolean, tavVersion: Int, isMonoblock: Boolean = false, filmGrainLevel: Int = 0): Long {
waveletFilter: Int, decompLevels: Int, isLossless: Boolean, tavVersion: Int, isMonoblock: Boolean = false, frameCount: Int = 0): Long {
val tileIdx = if (isMonoblock) {
0 // Single tile index for monoblock
@@ -5326,15 +5367,23 @@ class GraphicsJSR223Delegate(private val vm: VM) {
currentCg[i] = prevCg[i] + (deltaCg[i].toFloat() * qCg)
}
// Remove grain synthesis from Y channel (must happen after dequantization, before inverse DWT)
val tileWidth = if (isMonoblock) width else TAV_PADDED_TILE_SIZE_X
val tileHeight = if (isMonoblock) height else TAV_PADDED_TILE_SIZE_Y
val subbands = calculateSubbandLayout(tileWidth, tileHeight, decompLevels)
// Delta frames use uniform quantization for the deltas themselves, so no perceptual weights
removeGrainSynthesisDecoder(currentY, tileWidth, tileHeight, decompLevels, frameCount, qY.toFloat(), subbands)
// Apply film grain filter if enabled
if (filmGrainLevel > 0) {
// commented; grain synthesis is now a part of the spec
/*if (filmGrainLevel > 0) {
val random = java.util.Random()
for (i in 0 until coeffCount) {
currentY[i] += (random.nextInt(filmGrainLevel * 2 + 1) - filmGrainLevel).toFloat()
// currentCo[i] += (random.nextInt(filmGrainLevel * 2 + 1) - filmGrainLevel).toFloat()
// currentCg[i] += (random.nextInt(filmGrainLevel * 2 + 1) - filmGrainLevel).toFloat()
}
}
}*/
// Store current coefficients as previous for next frame
tavPreviousCoeffsY!![tileIdx] = currentY.clone()
@@ -5342,9 +5391,6 @@ class GraphicsJSR223Delegate(private val vm: VM) {
tavPreviousCoeffsCg!![tileIdx] = currentCg.clone()
// Apply inverse DWT
val tileWidth = if (isMonoblock) width else TAV_PADDED_TILE_SIZE_X
val tileHeight = if (isMonoblock) height else TAV_PADDED_TILE_SIZE_Y
if (isLossless) {
tavApplyDWTInverseMultiLevel(currentY, tileWidth, tileHeight, decompLevels, 0, TavSharpenLuma)
tavApplyDWTInverseMultiLevel(currentCo, tileWidth, tileHeight, decompLevels, 0, TavNullFilter)