TAV: preset implementation

This commit is contained in:
minjaesong
2025-11-24 17:40:45 +09:00
parent 6132012e74
commit 08bb33bf27
6 changed files with 152 additions and 66 deletions

View File

@@ -47,25 +47,19 @@ import kotlin.collections.component2
import kotlin.collections.component3
import kotlin.collections.component4
import kotlin.collections.copyOf
import kotlin.collections.count
import kotlin.collections.fill
import kotlin.collections.first
import kotlin.collections.forEach
import kotlin.collections.forEachIndexed
import kotlin.collections.indices
import kotlin.collections.isNotEmpty
import kotlin.collections.last
import kotlin.collections.listOf
import kotlin.collections.map
import kotlin.collections.maxOfOrNull
import kotlin.collections.mutableListOf
import kotlin.collections.mutableMapOf
import kotlin.collections.set
import kotlin.collections.sliceArray
import kotlin.collections.sorted
import kotlin.collections.sumOf
import kotlin.collections.toFloatArray
import kotlin.collections.toList
import kotlin.error
import kotlin.floatArrayOf
import kotlin.fromBits
@@ -5039,9 +5033,9 @@ class GraphicsJSR223Delegate(private val vm: VM) {
* - Level 1 (tH): 1.0 × 2^0.8 = 1.74
* - Level 2 (tHH): 1.0 × 2^1.6 = 3.03
*/
private fun getTemporalQuantizerScale(temporalLevel: Int): Float {
val BETA = 0.6f // Temporal scaling exponent (aggressive for temporal high-pass)
val KAPPA = 1.14f
private fun getTemporalQuantizerScale(encoderPreset: Int, temporalLevel: Int): Float {
val BETA = if (encoderPreset and 0x01 == 1) 0.0f else 0.6f // Temporal scaling exponent (aggressive for temporal high-pass)
val KAPPA = if (encoderPreset and 0x01 == 1) 1.0f else 1.14f
return 2.0f.pow(BETA * temporalLevel.toFloat().pow(KAPPA))
}
@@ -5177,8 +5171,13 @@ class GraphicsJSR223Delegate(private val vm: VM) {
// Remove grain synthesis from DWT coefficients (decoder subtracts noise)
// This must be called AFTER dequantization but BEFORE inverse DWT
private fun removeGrainSynthesisDecoder(coeffs: FloatArray, width: Int, height: Int,
frameNum: Int, subbands: List<DWTSubbandInfo>, qYGlobal: Int) {
private fun tavApplyGrainSynthesis(coeffs: FloatArray, width: Int, height: Int,
frameNum: Int, subbands: List<DWTSubbandInfo>, qYGlobal: Int, encoderPreset: Int = 0) {
// Anime preset: completely disable grain synthesis
if ((encoderPreset and 0x02) != 0) {
return // Skip grain synthesis entirely
}
// Only apply to Y channel, excluding LL band
// Noise amplitude = half of quantization step (scaled by perceptual weight if enabled)
@@ -5220,7 +5219,7 @@ class GraphicsJSR223Delegate(private val vm: VM) {
// New tavDecode function that accepts compressed data and decompresses internally
fun tavDecodeCompressed(compressedDataPtr: Long, compressedSize: Int, currentRGBAddr: Long, prevRGBAddr: Long,
width: Int, height: Int, qIndex: Int, qYGlobal: Int, qCoGlobal: Int, qCgGlobal: Int, channelLayout: Int,
frameCount: Int, waveletFilter: Int = 1, decompLevels: Int = 6, isLossless: Boolean = false, tavVersion: Int = 1, entropyCoder: Int = 0): HashMap<String, Any> {
frameCount: Int, waveletFilter: Int = 1, decompLevels: Int = 6, isLossless: Boolean = false, tavVersion: Int = 1, entropyCoder: Int = 0, encoderPreset: Int = 0): HashMap<String, Any> {
// Read compressed data from VM memory into byte array
val compressedData = ByteArray(compressedSize)
@@ -5250,7 +5249,7 @@ class GraphicsJSR223Delegate(private val vm: VM) {
// Call the existing tavDecode function with decompressed data
tavDecode(decompressedBuffer.toLong(), currentRGBAddr, prevRGBAddr,
width, height, qIndex, qYGlobal, qCoGlobal, qCgGlobal, channelLayout,
frameCount, waveletFilter, decompLevels, isLossless, tavVersion, entropyCoder)
frameCount, waveletFilter, decompLevels, isLossless, tavVersion, entropyCoder, encoderPreset)
} finally {
// Clean up allocated buffer
@@ -5266,7 +5265,7 @@ class GraphicsJSR223Delegate(private val vm: VM) {
// Original tavDecode function for backward compatibility (now handles decompressed data)
fun tavDecode(blockDataPtr: Long, currentRGBAddr: Long, prevRGBAddr: Long,
width: Int, height: Int, qIndex: Int, qYGlobal: Int, qCoGlobal: Int, qCgGlobal: Int, channelLayout: Int,
frameCount: Int, waveletFilter: Int = 1, decompLevels: Int = 6, isLossless: Boolean = false, tavVersion: Int = 1, entropyCoder: Int = 0): HashMap<String, Any> {
frameCount: Int, waveletFilter: Int = 1, decompLevels: Int = 6, isLossless: Boolean = false, tavVersion: Int = 1, entropyCoder: Int = 0, encoderPreset: Int = 0): HashMap<String, Any> {
val dbgOut = HashMap<String, Any>()
@@ -5328,14 +5327,14 @@ class GraphicsJSR223Delegate(private val vm: VM) {
0x01 -> { // TAV_MODE_INTRA
// Decode DWT coefficients directly to RGB buffer
readPtr = tavDecodeDWTIntraTileRGB(qIndex, qYGlobal, channelLayout, readPtr, tileX, tileY, currentRGBAddr,
width, height, qY, qCo, qCg, entropyCoder,
width, height, qY, qCo, qCg, entropyCoder, encoderPreset,
waveletFilter, decompLevels, isLossless, tavVersion, isMonoblock, frameCount)
dbgOut["frameMode"] = "I"
}
0x02 -> { // TAV_MODE_DELTA (with optional Haar wavelet)
// Coefficient delta encoding for efficient P-frames
readPtr = tavDecodeDeltaTileRGB(readPtr, channelLayout, tileX, tileY, currentRGBAddr,
width, height, qY, qCo, qCg, entropyCoder,
width, height, qY, qCo, qCg, entropyCoder, encoderPreset,
waveletFilter, decompLevels, tavVersion, isMonoblock, frameCount, haarLevel)
dbgOut["frameMode"] = " "
}
@@ -5351,7 +5350,7 @@ class GraphicsJSR223Delegate(private val vm: VM) {
}
private fun tavDecodeDWTIntraTileRGB(qIndex: Int, qYGlobal: Int, channelLayout: Int, readPtr: Long, tileX: Int, tileY: Int, currentRGBAddr: Long,
width: Int, height: Int, qY: Int, qCo: Int, qCg: Int, entropyCoder: Int,
width: Int, height: Int, qY: Int, qCo: Int, qCg: Int, entropyCoder: Int, encoderPreset: Int,
waveletFilter: Int, decompLevels: Int, isLossless: Boolean, tavVersion: Int, isMonoblock: Boolean = false, frameCount: Int): Long {
// Determine coefficient count based on mode
val coeffCount = if (isMonoblock) {
@@ -5451,7 +5450,7 @@ class GraphicsJSR223Delegate(private val vm: VM) {
// Remove grain synthesis from Y channel (must happen after dequantization, before inverse DWT)
// Use perceptual weights since this is the perceptual quantization path
removeGrainSynthesisDecoder(yTile, tileWidth, tileHeight, frameCount, subbands, qYGlobal)
tavApplyGrainSynthesis(yTile, tileWidth, tileHeight, frameCount, subbands, qYGlobal, encoderPreset)
// Apply film grain filter if enabled
// commented; grain synthesis is now a part of the spec
@@ -5476,7 +5475,7 @@ class GraphicsJSR223Delegate(private val vm: VM) {
val tileWidth = if (isMonoblock) width else TAV_PADDED_TILE_SIZE_X
val tileHeight = if (isMonoblock) height else TAV_PADDED_TILE_SIZE_Y
val subbands = calculateSubbandLayout(tileWidth, tileHeight, decompLevels)
removeGrainSynthesisDecoder(yTile, tileWidth, tileHeight, frameCount, subbands, qYGlobal)
tavApplyGrainSynthesis(yTile, tileWidth, tileHeight, frameCount, subbands, qYGlobal, encoderPreset)
// Apply film grain filter if enabled
// commented; grain synthesis is now a part of the spec
@@ -5774,7 +5773,7 @@ class GraphicsJSR223Delegate(private val vm: VM) {
}
private fun tavDecodeDeltaTileRGB(readPtr: Long, channelLayout: Int, tileX: Int, tileY: Int, currentRGBAddr: Long,
width: Int, height: Int, qY: Int, qCo: Int, qCg: Int, entropyCoder: Int,
width: Int, height: Int, qY: Int, qCo: Int, qCg: Int, entropyCoder: Int, encoderPreset: Int,
spatialFilter: Int, decompLevels: Int, tavVersion: Int, isMonoblock: Boolean = false, frameCount: Int = 0, haarLevel: Int = 0): Long {
val tileIdx = if (isMonoblock) {
@@ -5927,7 +5926,7 @@ class GraphicsJSR223Delegate(private val vm: VM) {
// Remove grain synthesis from Y channel (must happen after dequantization, before inverse DWT)
val subbands = calculateSubbandLayout(tileWidth, tileHeight, decompLevels)
// Delta frames use uniform quantization for the deltas themselves, so no perceptual weights
removeGrainSynthesisDecoder(currentY, tileWidth, tileHeight, frameCount, subbands, qY)
tavApplyGrainSynthesis(currentY, tileWidth, tileHeight, frameCount, subbands, qY, encoderPreset)
// Store current coefficients as previous for next frame
tavPreviousCoeffsY!![tileIdx] = currentY.clone()
@@ -6475,7 +6474,8 @@ class GraphicsJSR223Delegate(private val vm: VM) {
temporalLevels: Int = 2,
entropyCoder: Int = 0,
bufferOffset: Long = 0,
temporalMotionCoder: Int = 0
temporalMotionCoder: Int = 0,
encoderPreset: Int = 0
): Array<Any> {
val dbgOut = HashMap<String, Any>()
dbgOut["qY"] = qYGlobal
@@ -6547,9 +6547,9 @@ class GraphicsJSR223Delegate(private val vm: VM) {
// Step 5: Dequantize with temporal-spatial scaling
for (t in 0 until gopSize) {
val temporalLevel = getTemporalSubbandLevel(t, gopSize, temporalLevels)
val temporalScale = getTemporalQuantizerScale(temporalLevel)
val temporalScale = getTemporalQuantizerScale(encoderPreset, temporalLevel)
// CRITICAL FIX: Must ROUND temporal quantizer to match encoder's roundf() behavior
// CRITICAL FIX: Must ROUND temporal quantizer to match encoder's roundf() behaviour
// Encoder (encoder_tav.c:3189): temporal_base_quantiser = (int)roundf(temporal_quantiser)
// Without rounding, decoder uses float values (e.g., 1.516) while encoder used integers (e.g., 2)
// This causes ~24% under-reconstruction for odd baseQ values in temporal high-pass frames (Frame 5+)
@@ -6587,10 +6587,10 @@ class GraphicsJSR223Delegate(private val vm: VM) {
// This must happen after dequantization but before inverse DWT
// Use GOP dimensions (may be cropped)
for (t in 0 until gopSize) {
removeGrainSynthesisDecoder(
tavApplyGrainSynthesis(
gopY[t], gopWidth, gopHeight,
rngFrameTick.getAndAdd(1) + t,
subbands, qIndex
subbands, qIndex, encoderPreset
)
}
@@ -6818,7 +6818,8 @@ class GraphicsJSR223Delegate(private val vm: VM) {
temporalLevels: Int = 3,
entropyCoder: Int = 0,
bufferOffset: Long = 0,
temporalMotionCoder: Int = 0
temporalMotionCoder: Int = 0,
encoderPreset: Int = 0
) {
// Cancel any existing decode thread
asyncDecodeThread?.interrupt()
@@ -6836,7 +6837,7 @@ class GraphicsJSR223Delegate(private val vm: VM) {
width, height,
qIndex, qYGlobal, qCoGlobal, qCgGlobal,
channelLayout, spatialFilter, spatialLevels, temporalLevels,
entropyCoder, bufferOffset, temporalMotionCoder
entropyCoder, bufferOffset, temporalMotionCoder, encoderPreset
)
asyncDecodeResult = result
asyncDecodeComplete.set(true)