From 1a0b754cde398603675fe578de8b3abfcfdc6be7 Mon Sep 17 00:00:00 2001 From: minjaesong Date: Sun, 15 Mar 2026 13:26:05 +0900 Subject: [PATCH] font update, video sprite (TAV) --- buildapp/instructions.md | 2 - lib/TerrarumSansBitmap.jar | 4 +- .../spriteanimation/VideoSpriteAnimation.kt | 119 +++ src/net/torvald/terrarum/tav/AudioBankTav.kt | 36 + src/net/torvald/terrarum/tav/DwtUtil.kt | 276 +++++++ src/net/torvald/terrarum/tav/EzbcDecode.kt | 248 +++++++ src/net/torvald/terrarum/tav/TadDecode.kt | 192 +++++ src/net/torvald/terrarum/tav/TavDecoder.kt | 456 ++++++++++++ .../torvald/terrarum/tav/TavVideoDecode.kt | 685 ++++++++++++++++++ 9 files changed, 2014 insertions(+), 4 deletions(-) create mode 100644 src/net/torvald/spriteanimation/VideoSpriteAnimation.kt create mode 100644 src/net/torvald/terrarum/tav/AudioBankTav.kt create mode 100644 src/net/torvald/terrarum/tav/DwtUtil.kt create mode 100644 src/net/torvald/terrarum/tav/EzbcDecode.kt create mode 100644 src/net/torvald/terrarum/tav/TadDecode.kt create mode 100644 src/net/torvald/terrarum/tav/TavDecoder.kt create mode 100644 src/net/torvald/terrarum/tav/TavVideoDecode.kt diff --git a/buildapp/instructions.md b/buildapp/instructions.md index 3a0e817a9..bcc0a7704 100644 --- a/buildapp/instructions.md +++ b/buildapp/instructions.md @@ -30,8 +30,6 @@ This process assumes that the game does NOT use the Java 9+ modules and every si The Linux Aarch64 runtime must be prepared on the actual ARM Linux session. -Copy the runtimes to your workstation, rename the `bin/java` into `bin/Terrarum`, then `chmod -R +x` all of them. 
- ### Packaging Create an output directory if there is none (project root/buildapp/out) diff --git a/lib/TerrarumSansBitmap.jar b/lib/TerrarumSansBitmap.jar index 0da3d8e4b..94ed4f7c9 100644 --- a/lib/TerrarumSansBitmap.jar +++ b/lib/TerrarumSansBitmap.jar @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ea84912ae6fb0554273af23f6d9aba67e95999e3332aca101283b819b767e454 -size 2113816 +oid sha256:b99d2d0129bfc309fcd4a485c5f2f29eb926c62e460d8d8260f7dc084513dee4 +size 2181259 diff --git a/src/net/torvald/spriteanimation/VideoSpriteAnimation.kt b/src/net/torvald/spriteanimation/VideoSpriteAnimation.kt new file mode 100644 index 000000000..f0cbe5573 --- /dev/null +++ b/src/net/torvald/spriteanimation/VideoSpriteAnimation.kt @@ -0,0 +1,119 @@ +package net.torvald.spriteanimation + +import com.badlogic.gdx.graphics.Color +import com.badlogic.gdx.graphics.Texture +import com.badlogic.gdx.graphics.g2d.SpriteBatch +import com.jme3.math.FastMath +import net.torvald.terrarum.Second +import net.torvald.terrarum.gameactors.ActorWithBody +import net.torvald.terrarum.tav.AudioBankTav +import net.torvald.terrarum.tav.TavDecoder +import java.io.InputStream + +/** + * A SpriteAnimation that plays a TAV video file. + * + * Usage: + * val anim = VideoSpriteAnimation(actor, stream, looping = true) + * actor.sprite = anim + * anim.start() + * // Optionally route audio: + * anim.audioBank?.let { bank -> + * val track = App.audioMixer.getFreeTrackNoMatterWhat() + * track.currentTrack = bank + * track.play() + * } + */ +class VideoSpriteAnimation( + parentActor: ActorWithBody, + tavStream: InputStream, + val looping: Boolean = false +) : SpriteAnimation(parentActor) { + + val decoder = TavDecoder(tavStream, looping) + val audioBank: AudioBankTav? 
= if (decoder.hasAudio) AudioBankTav(decoder) else null + + val cellWidth: Int get() = decoder.videoWidth + val cellHeight: Int get() = decoder.videoHeight + + override val currentDelay: Second get() = 1f / decoder.fps.coerceAtLeast(1) + + private var currentTexture: Texture? = null + private var deltaAccumulator = 0f + private var started = false + + val isFinished: Boolean get() = decoder.isFinished.get() + + fun start() { + decoder.start() + started = true + } + + fun stop() { + decoder.stop() + started = false + } + + // update() is a no-op: frame timing is handled in render() via frameDelta + override fun update(delta: Float) {} + + override fun render( + frameDelta: Float, + batch: SpriteBatch, + posX: Float, + posY: Float, + scale: Float, + mode: Int, + forcedColourFilter: Color? + ) { + if (!started) return + + // Advance frame timing + deltaAccumulator += frameDelta + while (deltaAccumulator >= currentDelay) { + decoder.advanceFrame() + deltaAccumulator -= currentDelay + } + + val pixmap = decoder.getFramePixmap() ?: return + + // Dispose old texture and create new one from the current decoded Pixmap + currentTexture?.dispose() + currentTexture = Texture(pixmap).also { + it.setFilter(Texture.TextureFilter.Nearest, Texture.TextureFilter.Nearest) + } + + batch.color = forcedColourFilter ?: colourFilter + + val w = cellWidth + val h = cellHeight + + val tx = (parentActor.hitboxTranslateX) * scale + val txF = (parentActor.hitboxTranslateX + parentActor.baseHitboxW) * scale + val ty = (parentActor.hitboxTranslateY + (h - parentActor.baseHitboxH)) * scale + val tyF = (parentActor.hitboxTranslateY + parentActor.baseHitboxH) * scale + + val tex = currentTexture!! 
+ val x0 = FastMath.floor(posX).toFloat() + val y0 = FastMath.floor(posY).toFloat() + val fw = FastMath.floor(w * scale).toFloat() + val fh = FastMath.floor(h * scale).toFloat() + + if (flipHorizontal && flipVertical) { + batch.draw(tex, x0 + txF, y0 + tyF, -fw, -fh) + } else if (flipHorizontal && !flipVertical) { + batch.draw(tex, x0 + txF, y0 - ty, -fw, fh) + } else if (!flipHorizontal && flipVertical) { + batch.draw(tex, x0 - tx, y0 + tyF, fw, -fh) + } else { + batch.draw(tex, x0 - tx, y0 - ty, fw, fh) + } + } + + override fun dispose() { + stop() + decoder.dispose() + currentTexture?.dispose() + currentTexture = null + } +} diff --git a/src/net/torvald/terrarum/tav/AudioBankTav.kt b/src/net/torvald/terrarum/tav/AudioBankTav.kt new file mode 100644 index 000000000..72c99e518 --- /dev/null +++ b/src/net/torvald/terrarum/tav/AudioBankTav.kt @@ -0,0 +1,36 @@ +package net.torvald.terrarum.tav + +import net.torvald.terrarum.audio.AudioBank + +/** + * AudioBank adapter wrapping a TavDecoder's audio ring buffer. + * Reports 32000 Hz sampling rate; the audio pipeline resamples to 48000 Hz automatically. + * Lifecycle is managed by VideoSpriteAnimation — dispose() is a no-op here. + */ +class AudioBankTav( + private val decoder: TavDecoder, + override var songFinishedHook: (AudioBank) -> Unit = {} +) : AudioBank() { + + override val notCopyable = true + override val name = "tav-audio" + + /** TAD native sample rate; AudioProcessBuf resamples to 48000 Hz. 
*/ + override var samplingRate = 32000f + override var channels = 2 + + override var totalSizeInSamples: Long = + decoder.totalFrames * (32000L / decoder.fps.coerceAtLeast(1)) + + override fun readSamples(bufferL: FloatArray, bufferR: FloatArray): Int = + decoder.readAudioSamples(bufferL, bufferR) + + override fun currentPositionInSamples(): Long = decoder.audioReadPos.get() + + override fun reset() { /* reset is handled at decoder level by VideoSpriteAnimation */ } + + override fun makeCopy(): AudioBank = throw UnsupportedOperationException("AudioBankTav is not copyable") + + /** Lifecycle managed by VideoSpriteAnimation; do not dispose the decoder here. */ + override fun dispose() {} +} diff --git a/src/net/torvald/terrarum/tav/DwtUtil.kt b/src/net/torvald/terrarum/tav/DwtUtil.kt new file mode 100644 index 000000000..74166ac74 --- /dev/null +++ b/src/net/torvald/terrarum/tav/DwtUtil.kt @@ -0,0 +1,276 @@ +package net.torvald.terrarum.tav + +/** + * Shared DWT (Discrete Wavelet Transform) utility functions. + * Provides inverse CDF 9/7, CDF 5/3, and Haar transforms used by both + * video and audio decoders. + * + * Ported from GraphicsJSR223Delegate.kt and AudioAdapter.kt in the TSVM project. + */ +object DwtUtil { + + // CDF 9/7 lifting constants + private const val ALPHA = -1.586134342f + private const val BETA = -0.052980118f + private const val GAMMA = 0.882911076f + private const val DELTA = 0.443506852f + private const val K = 1.230174105f + + // ------------------------------------------------------------------------- + // 1D Transforms + // ------------------------------------------------------------------------- + + /** + * Single-level 1D CDF 9/7 inverse lifting transform. + * Layout: first half = low-pass coefficients, second half = high-pass. 
+ */ + fun inverse1D(data: FloatArray, length: Int) { + if (length < 2) return + + val temp = FloatArray(length) + val half = (length + 1) / 2 + + for (i in 0 until half) { + temp[i] = data[i] + } + for (i in 0 until length / 2) { + if (half + i < length) temp[half + i] = data[half + i] + } + + // Step 1: Undo scaling + for (i in 0 until half) temp[i] /= K + for (i in 0 until length / 2) { + if (half + i < length) temp[half + i] *= K + } + + // Step 2: Undo delta update + for (i in 0 until half) { + val dCurr = if (half + i < length) temp[half + i] else 0.0f + val dPrev = if (i > 0 && half + i - 1 < length) temp[half + i - 1] else dCurr + temp[i] -= DELTA * (dCurr + dPrev) + } + + // Step 3: Undo gamma predict + for (i in 0 until length / 2) { + if (half + i < length) { + val sCurr = temp[i] + val sNext = if (i + 1 < half) temp[i + 1] else sCurr + temp[half + i] -= GAMMA * (sCurr + sNext) + } + } + + // Step 4: Undo beta update + for (i in 0 until half) { + val dCurr = if (half + i < length) temp[half + i] else 0.0f + val dPrev = if (i > 0 && half + i - 1 < length) temp[half + i - 1] else dCurr + temp[i] -= BETA * (dCurr + dPrev) + } + + // Step 5: Undo alpha predict + for (i in 0 until length / 2) { + if (half + i < length) { + val sCurr = temp[i] + val sNext = if (i + 1 < half) temp[i + 1] else sCurr + temp[half + i] -= ALPHA * (sCurr + sNext) + } + } + + // Interleave reconstruction + for (i in 0 until length) { + if (i % 2 == 0) { + data[i] = temp[i / 2] + } else { + val idx = i / 2 + data[i] = if (half + idx < length) temp[half + idx] else 0.0f + } + } + } + + /** + * Multi-level 1D CDF 9/7 inverse transform. + * Uses exact forward-transform lengths in reverse to handle non-power-of-2 sizes. 
+ */ + fun inverseMultilevel1D(data: FloatArray, length: Int, levels: Int) { + val lengths = IntArray(levels + 1) + lengths[0] = length + for (i in 1..levels) lengths[i] = (lengths[i - 1] + 1) / 2 + + for (level in levels - 1 downTo 0) { + inverse1D(data, lengths[level]) + } + } + + /** + * Single-level 2D CDF 9/7 inverse transform. + * Column inverse first, then row inverse (matching encoder's row-then-column forward order). + */ + fun inverse2D(data: FloatArray, width: Int, height: Int, currentWidth: Int, currentHeight: Int) { + val maxSize = maxOf(width, height) + val tempBuf = FloatArray(maxSize) + + // Column inverse transform (vertical) + for (x in 0 until currentWidth) { + for (y in 0 until currentHeight) tempBuf[y] = data[y * width + x] + inverse1D(tempBuf, currentHeight) + for (y in 0 until currentHeight) data[y * width + x] = tempBuf[y] + } + + // Row inverse transform (horizontal) + for (y in 0 until currentHeight) { + for (x in 0 until currentWidth) tempBuf[x] = data[y * width + x] + inverse1D(tempBuf, currentWidth) + for (x in 0 until currentWidth) data[y * width + x] = tempBuf[x] + } + } + + /** + * Multi-level 2D CDF 9/7 inverse transform. + * Uses exact forward-transform dimension sequences. 
+ */ + fun inverseMultilevel2D(data: FloatArray, width: Int, height: Int, levels: Int, + filterType: Int = 1) { + val widths = IntArray(levels + 1) + val heights = IntArray(levels + 1) + widths[0] = width + heights[0] = height + for (i in 1..levels) { + widths[i] = (widths[i - 1] + 1) / 2 + heights[i] = (heights[i - 1] + 1) / 2 + } + + val maxSize = maxOf(width, height) + val tempBuf = FloatArray(maxSize) + + for (level in levels - 1 downTo 0) { + val cw = widths[level] + val ch = heights[level] + if (cw < 1 || ch < 1 || (cw == 1 && ch == 1)) continue + + // Column inverse + for (x in 0 until cw) { + for (y in 0 until ch) tempBuf[y] = data[y * width + x] + applyInverse1DByFilter(tempBuf, ch, filterType) + for (y in 0 until ch) data[y * width + x] = tempBuf[y] + } + + // Row inverse + for (y in 0 until ch) { + for (x in 0 until cw) tempBuf[x] = data[y * width + x] + applyInverse1DByFilter(tempBuf, cw, filterType) + for (x in 0 until cw) data[y * width + x] = tempBuf[x] + } + } + } + + private fun applyInverse1DByFilter(data: FloatArray, length: Int, filterType: Int) { + when (filterType) { + 0 -> dwt53Inverse1D(data, length) + 1 -> inverse1D(data, length) + 255 -> haarInverse1D(data, length) + else -> inverse1D(data, length) + } + } + + // ------------------------------------------------------------------------- + // Haar 1D Inverse + // ------------------------------------------------------------------------- + + fun haarInverse1D(data: FloatArray, length: Int) { + if (length < 2) return + + val temp = FloatArray(length) + val half = (length + 1) / 2 + + for (i in 0 until half) { + if (2 * i + 1 < length) { + temp[2 * i] = data[i] + data[half + i] + temp[2 * i + 1] = data[i] - data[half + i] + } else { + temp[2 * i] = data[i] + } + } + + for (i in 0 until length) data[i] = temp[i] + } + + // ------------------------------------------------------------------------- + // CDF 5/3 1D Inverse + // ------------------------------------------------------------------------- 
+ + fun dwt53Inverse1D(data: FloatArray, length: Int) { + if (length < 2) return + + val temp = FloatArray(length) + val half = (length + 1) / 2 + + System.arraycopy(data, 0, temp, 0, length) + + // Undo update step (low-pass) + for (i in 0 until half) { + val update = 0.25f * ((if (i > 0) temp[half + i - 1] else 0.0f) + + (if (i < half - 1) temp[half + i] else 0.0f)) + temp[i] -= update + } + + // Undo predict step and interleave + for (i in 0 until half) { + data[2 * i] = temp[i] + val idx = 2 * i + 1 + if (idx < length) { + val pred = 0.5f * (temp[i] + (if (i < half - 1) temp[i + 1] else temp[i])) + data[idx] = temp[half + i] + pred + } + } + } + + // ------------------------------------------------------------------------- + // Temporal inverse DWT (used for GOP decode) + // ------------------------------------------------------------------------- + + /** + * Apply inverse temporal 1D DWT using Haar or CDF 5/3. + * @param temporalMotionCoder 0=Haar, 1=CDF 5/3 + */ + fun temporalInverse1D(data: FloatArray, numFrames: Int, temporalMotionCoder: Int = 0) { + if (numFrames < 2) return + if (temporalMotionCoder == 0) haarInverse1D(data, numFrames) + else dwt53Inverse1D(data, numFrames) + } + + /** + * Apply inverse 3D DWT to GOP data (spatial inverse first, then temporal inverse). 
+ */ + fun inverseMultilevel3D( + gopData: Array, + width: Int, height: Int, numFrames: Int, + spatialLevels: Int, temporalLevels: Int, + spatialFilter: Int = 1, temporalMotionCoder: Int = 0 + ) { + // Step 1: Inverse spatial 2D DWT on each temporal frame + for (t in 0 until numFrames) { + inverseMultilevel2D(gopData[t], width, height, spatialLevels, spatialFilter) + } + + if (numFrames < 2) return + + // Step 2: Inverse temporal DWT at each spatial location + val temporalLengths = IntArray(temporalLevels + 1) + temporalLengths[0] = numFrames + for (i in 1..temporalLevels) temporalLengths[i] = (temporalLengths[i - 1] + 1) / 2 + + val temporalLine = FloatArray(numFrames) + for (y in 0 until height) { + for (x in 0 until width) { + val pixelIdx = y * width + x + for (t in 0 until numFrames) temporalLine[t] = gopData[t][pixelIdx] + + for (level in temporalLevels - 1 downTo 0) { + val levelFrames = temporalLengths[level] + if (levelFrames >= 2) temporalInverse1D(temporalLine, levelFrames, temporalMotionCoder) + } + + for (t in 0 until numFrames) gopData[t][pixelIdx] = temporalLine[t] + } + } + } +} diff --git a/src/net/torvald/terrarum/tav/EzbcDecode.kt b/src/net/torvald/terrarum/tav/EzbcDecode.kt new file mode 100644 index 000000000..084173a48 --- /dev/null +++ b/src/net/torvald/terrarum/tav/EzbcDecode.kt @@ -0,0 +1,248 @@ +package net.torvald.terrarum.tav + +/** + * EZBC (Embedded Zero Block Coding) entropy decoder. + * Provides both 2D (video) and 1D (audio) variants. + * + * Ported from GraphicsJSR223Delegate.kt and AudioAdapter.kt in the TSVM project. 
+ */ +object EzbcDecode { + + // ------------------------------------------------------------------------- + // Shared bitstream reader + // ------------------------------------------------------------------------- + + private class BitstreamReader(private val data: ByteArray, private val startOffset: Int, private val size: Int) { + private var bytePos = startOffset + private var bitPos = 0 + private val endPos = startOffset + size + + fun readBit(): Int { + if (bytePos >= endPos) return 0 + val bit = (data[bytePos].toInt() shr bitPos) and 1 + bitPos++ + if (bitPos == 8) { bitPos = 0; bytePos++ } + return bit + } + + fun readBits(numBits: Int): Int { + var value = 0 + for (i in 0 until numBits) value = value or (readBit() shl i) + return value + } + + fun bytesConsumed(): Int = (bytePos - startOffset) + if (bitPos > 0) 1 else 0 + } + + // ------------------------------------------------------------------------- + // 2D EZBC decode (video coefficients, ShortArray output) + // ------------------------------------------------------------------------- + + /** + * Decode a single EZBC channel (2D variant for video). + * Header: 8-bit MSB bitplane, 16-bit width, 16-bit height. 
+ */ + fun decode2DChannel(ezbcData: ByteArray, offset: Int, size: Int, outputCoeffs: ShortArray) { + val bs = BitstreamReader(ezbcData, offset, size) + + val msbBitplane = bs.readBits(8) + val width = bs.readBits(16) + val height = bs.readBits(16) + + if (width * height != outputCoeffs.size) { + System.err.println("[EZBC-2D] Dimension mismatch: ${width}x${height} != ${outputCoeffs.size}") + return + } + + outputCoeffs.fill(0) + + val significant = BooleanArray(outputCoeffs.size) + + data class Block(val x: Int, val y: Int, val w: Int, val h: Int) + + var insignificantQueue = ArrayList() + var nextInsignificant = ArrayList() + var significantQueue = ArrayList() + var nextSignificant = ArrayList() + + insignificantQueue.add(Block(0, 0, width, height)) + + fun processSignificantBlockRecursive(block: Block, bitplane: Int, threshold: Int) { + if (block.w == 1 && block.h == 1) { + val idx = block.y * width + block.x + val signBit = bs.readBit() + outputCoeffs[idx] = (if (signBit == 1) -threshold else threshold).toShort() + significant[idx] = true + nextSignificant.add(block) + return + } + + var midX = block.w / 2; if (midX == 0) midX = 1 + var midY = block.h / 2; if (midY == 0) midY = 1 + + // Top-left + val tl = Block(block.x, block.y, midX, midY) + if (bs.readBit() == 1) processSignificantBlockRecursive(tl, bitplane, threshold) + else nextInsignificant.add(tl) + + // Top-right + if (block.w > midX) { + val tr = Block(block.x + midX, block.y, block.w - midX, midY) + if (bs.readBit() == 1) processSignificantBlockRecursive(tr, bitplane, threshold) + else nextInsignificant.add(tr) + } + + // Bottom-left + if (block.h > midY) { + val bl = Block(block.x, block.y + midY, midX, block.h - midY) + if (bs.readBit() == 1) processSignificantBlockRecursive(bl, bitplane, threshold) + else nextInsignificant.add(bl) + } + + // Bottom-right + if (block.w > midX && block.h > midY) { + val br = Block(block.x + midX, block.y + midY, block.w - midX, block.h - midY) + if (bs.readBit() == 
1) processSignificantBlockRecursive(br, bitplane, threshold) + else nextInsignificant.add(br) + } + } + + for (bitplane in msbBitplane downTo 0) { + val threshold = 1 shl bitplane + + for (block in insignificantQueue) { + if (bs.readBit() == 0) nextInsignificant.add(block) + else processSignificantBlockRecursive(block, bitplane, threshold) + } + + for (block in significantQueue) { + val idx = block.y * width + block.x + if (bs.readBit() == 1) { + val bitValue = 1 shl bitplane + if (outputCoeffs[idx] < 0) outputCoeffs[idx] = (outputCoeffs[idx] - bitValue).toShort() + else outputCoeffs[idx] = (outputCoeffs[idx] + bitValue).toShort() + } + nextSignificant.add(block) + } + + insignificantQueue = nextInsignificant; nextInsignificant = ArrayList() + significantQueue = nextSignificant; nextSignificant = ArrayList() + } + } + + /** + * Decode all channels from an EZBC block. + * Format: [size_y(4)][ezbc_y][size_co(4)][ezbc_co][size_cg(4)][ezbc_cg]... + */ + fun decode2D( + compressedData: ByteArray, offset: Int, + channelLayout: Int, + outputY: ShortArray?, outputCo: ShortArray?, outputCg: ShortArray?, outputAlpha: ShortArray? 
+ ) { + val hasY = (channelLayout and 4) == 0 + val hasCoCg = (channelLayout and 2) == 0 + val hasAlpha = (channelLayout and 1) != 0 + + var ptr = offset + + fun readSize(): Int { + val b0 = compressedData[ptr ].toInt() and 0xFF + val b1 = compressedData[ptr+1].toInt() and 0xFF + val b2 = compressedData[ptr+2].toInt() and 0xFF + val b3 = compressedData[ptr+3].toInt() and 0xFF + return b0 or (b1 shl 8) or (b2 shl 16) or (b3 shl 24) + } + + if (hasY && outputY != null) { + val sz = readSize(); ptr += 4 + decode2DChannel(compressedData, ptr, sz, outputY); ptr += sz + } + if (hasCoCg && outputCo != null) { + val sz = readSize(); ptr += 4 + decode2DChannel(compressedData, ptr, sz, outputCo); ptr += sz + } + if (hasCoCg && outputCg != null) { + val sz = readSize(); ptr += 4 + decode2DChannel(compressedData, ptr, sz, outputCg); ptr += sz + } + if (hasAlpha && outputAlpha != null) { + val sz = readSize(); ptr += 4 + decode2DChannel(compressedData, ptr, sz, outputAlpha); ptr += sz + } + } + + // ------------------------------------------------------------------------- + // 1D EZBC decode (audio coefficients, ByteArray output) + // ------------------------------------------------------------------------- + + /** + * Decode a single EZBC channel (1D variant for TAD audio). + * Header: 8-bit MSB bitplane, 16-bit coefficient count. 
+ * @return number of bytes consumed from [input] + */ + fun decode1DChannel(input: ByteArray, inputOffset: Int, inputSize: Int, coeffs: ByteArray): Int { + val bs = BitstreamReader(input, inputOffset, inputSize) + + val msbBitplane = bs.readBits(8) + val count = bs.readBits(16) + + coeffs.fill(0) + + data class Block(val start: Int, val length: Int) + + val states = BooleanArray(count) // significant flags + + var insignificantQueue = ArrayList() + var nextInsignificant = ArrayList() + var significantQueue = ArrayList() + var nextSignificant = ArrayList() + + insignificantQueue.add(Block(0, count)) + + fun processSignificantBlockRecursive(block: Block, bitplane: Int) { + if (block.length == 1) { + val idx = block.start + val signBit = bs.readBit() + val absVal = 1 shl bitplane + coeffs[idx] = (if (signBit != 0) -absVal else absVal).toByte() + states[idx] = true + nextSignificant.add(block) + return + } + + val mid = maxOf(1, block.length / 2) + + val left = Block(block.start, mid) + if (bs.readBit() != 0) processSignificantBlockRecursive(left, bitplane) + else nextInsignificant.add(left) + + if (block.length > mid) { + val right = Block(block.start + mid, block.length - mid) + if (bs.readBit() != 0) processSignificantBlockRecursive(right, bitplane) + else nextInsignificant.add(right) + } + } + + for (bitplane in msbBitplane downTo 0) { + for (block in insignificantQueue) { + if (bs.readBit() == 0) nextInsignificant.add(block) + else processSignificantBlockRecursive(block, bitplane) + } + + for (block in significantQueue) { + val idx = block.start + if (bs.readBit() != 0) { + val sign = if (coeffs[idx] < 0) -1 else 1 + val absVal = kotlin.math.abs(coeffs[idx].toInt()) + coeffs[idx] = (sign * (absVal or (1 shl bitplane))).toByte() + } + nextSignificant.add(block) + } + + insignificantQueue = nextInsignificant; nextInsignificant = ArrayList() + significantQueue = nextSignificant; nextSignificant = ArrayList() + } + + return bs.bytesConsumed() + } +} diff --git 
a/src/net/torvald/terrarum/tav/TadDecode.kt b/src/net/torvald/terrarum/tav/TadDecode.kt new file mode 100644 index 000000000..2e399d9cc --- /dev/null +++ b/src/net/torvald/terrarum/tav/TadDecode.kt @@ -0,0 +1,192 @@ +package net.torvald.terrarum.tav + +import io.airlift.compress.zstd.ZstdInputStream +import java.io.ByteArrayInputStream + +/** + * TAD (TSVM Advanced Audio) decoder. + * Decodes TAD chunks to Float32 stereo PCM at 32000 Hz. + * + * Ported from AudioAdapter.kt in the TSVM project. + */ +object TadDecode { + + // Coefficient scalars per subband (LL + 9 H bands, index 0=LL, 1-9=H bands L9..L1) + private val COEFF_SCALARS = floatArrayOf( + 64.0f, 45.255f, 32.0f, 22.627f, 16.0f, 11.314f, 8.0f, 5.657f, 4.0f, 2.828f + ) + + // Base quantiser weight table: [channel 0=Mid][channel 1=Side] + private val BASE_QUANTISER_WEIGHTS = arrayOf( + floatArrayOf(4.0f, 2.0f, 1.8f, 1.6f, 1.4f, 1.2f, 1.0f, 1.0f, 1.3f, 2.0f), // Mid + floatArrayOf(6.0f, 5.0f, 2.6f, 2.4f, 1.8f, 1.3f, 1.0f, 1.0f, 1.6f, 3.2f) // Side + ) + + private const val LAMBDA_FIXED = 6.0f + private const val DWT_LEVELS = 9 + + /** + * Cross-chunk persistent state for the TAD de-emphasis IIR filter. + */ + class TadDecoderState { + var prevYL: Float = 0.0f + var prevYR: Float = 0.0f + } + + // ------------------------------------------------------------------------- + // Full TAD chunk decode + // ------------------------------------------------------------------------- + + /** + * Decode a single TAD chunk payload. + * Returns Pair(leftSamples, rightSamples) as Float32 in [-1, 1]. 
+ * + * @param payload Zstd-compressed TAD chunk payload + * @param sampleCount samples per channel + * @param maxIndex max quantiser index + * @param state persistent de-emphasis state (mutated in-place) + */ + fun decodeTadChunk( + payload: ByteArray, + sampleCount: Int, + maxIndex: Int, + state: TadDecoderState + ): Pair { + // Step 1: Zstd decompress + val decompressed = ZstdInputStream(ByteArrayInputStream(payload)).use { it.readBytes() } + + // Step 2: EZBC 1D decode Mid and Side channels + val quantMid = ByteArray(sampleCount) + val quantSide = ByteArray(sampleCount) + + val midBytesConsumed = EzbcDecode.decode1DChannel(decompressed, 0, decompressed.size, quantMid) + EzbcDecode.decode1DChannel( + decompressed, midBytesConsumed, decompressed.size - midBytesConsumed, quantSide + ) + + // Step 3 & 4: Lambda decompanding + dequantise + val dwtMid = FloatArray(sampleCount) + val dwtSide = FloatArray(sampleCount) + dequantiseCoeffs(0, quantMid, dwtMid, sampleCount, maxIndex) + dequantiseCoeffs(1, quantSide, dwtSide, sampleCount, maxIndex) + + // Step 5: Inverse CDF 9/7 DWT (9 levels) + DwtUtil.inverseMultilevel1D(dwtMid, sampleCount, DWT_LEVELS) + DwtUtil.inverseMultilevel1D(dwtSide, sampleCount, DWT_LEVELS) + + // Step 6: M/S to L/R + val left = FloatArray(sampleCount) + val right = FloatArray(sampleCount) + msToLR(dwtMid, dwtSide, left, right, sampleCount) + + // Step 7: Gamma expansion + gammaExpand(left, right, sampleCount) + + // Step 8: De-emphasis IIR (persistent state) + deemphasis(left, right, sampleCount, state) + + return Pair(left, right) + } + + // ------------------------------------------------------------------------- + // PCM fallback decoders + // ------------------------------------------------------------------------- + + /** Decode Zstd-compressed interleaved PCMu8 stereo. Returns Float32 L/R. 
*/ + fun decodePcm8(payload: ByteArray): Pair { + val decompressed = ZstdInputStream(ByteArrayInputStream(payload)).use { it.readBytes() } + val sampleCount = decompressed.size / 2 + val left = FloatArray(sampleCount) + val right = FloatArray(sampleCount) + for (i in 0 until sampleCount) { + val l = (decompressed[i * 2 ].toInt() and 0xFF) - 128 + val r = (decompressed[i * 2 + 1].toInt() and 0xFF) - 128 + left[i] = l / 128.0f + right[i] = r / 128.0f + } + return Pair(left, right) + } + + /** Decode Zstd-compressed interleaved PCM16-LE stereo. Returns Float32 L/R. */ + fun decodePcm16(payload: ByteArray): Pair { + val decompressed = ZstdInputStream(ByteArrayInputStream(payload)).use { it.readBytes() } + val sampleCount = decompressed.size / 4 + val left = FloatArray(sampleCount) + val right = FloatArray(sampleCount) + for (i in 0 until sampleCount) { + val lLo = decompressed[i * 4 ].toInt() and 0xFF + val lHi = decompressed[i * 4 + 1].toInt() + val rLo = decompressed[i * 4 + 2].toInt() and 0xFF + val rHi = decompressed[i * 4 + 3].toInt() + left[i] = ((lHi shl 8) or lLo).toShort() / 32768.0f + right[i] = ((rHi shl 8) or rLo).toShort() / 32768.0f + } + return Pair(left, right) + } + + // ------------------------------------------------------------------------- + // Internal pipeline stages + // ------------------------------------------------------------------------- + + private fun lambdaDecompand(quantVal: Byte, maxIndex: Int): Float { + if (quantVal == 0.toByte()) return 0.0f + val sign = if (quantVal < 0) -1 else 1 + var absIndex = kotlin.math.abs(quantVal.toInt()).coerceAtMost(maxIndex) + val normalisedCdf = absIndex.toFloat() / maxIndex + val cdf = 0.5f + normalisedCdf * 0.5f + var absVal = -(1.0f / LAMBDA_FIXED) * kotlin.math.ln(2.0f * (1.0f - cdf)) + absVal = absVal.coerceIn(0.0f, 1.0f) + return sign * absVal + } + + private fun dequantiseCoeffs( + channel: Int, quantised: ByteArray, coeffs: FloatArray, + count: Int, maxIndex: Int + ) { + val firstBandSize = 
count shr DWT_LEVELS + val sidebandStarts = IntArray(DWT_LEVELS + 2) + sidebandStarts[0] = 0 + sidebandStarts[1] = firstBandSize + for (i in 2..DWT_LEVELS + 1) { + sidebandStarts[i] = sidebandStarts[i - 1] + (firstBandSize shl (i - 2)) + } + + for (i in 0 until count) { + var sideband = DWT_LEVELS + for (s in 0..DWT_LEVELS) { + if (i < sidebandStarts[s + 1]) { sideband = s; break } + } + val normalisedVal = lambdaDecompand(quantised[i], maxIndex) + val weight = BASE_QUANTISER_WEIGHTS[channel][sideband] + coeffs[i] = normalisedVal * COEFF_SCALARS[sideband] * weight + } + } + + private fun msToLR(mid: FloatArray, side: FloatArray, left: FloatArray, right: FloatArray, count: Int) { + for (i in 0 until count) { + left[i] = (mid[i] + side[i]).coerceIn(-1.0f, 1.0f) + right[i] = (mid[i] - side[i]).coerceIn(-1.0f, 1.0f) + } + } + + private fun gammaExpand(left: FloatArray, right: FloatArray, count: Int) { + for (i in 0 until count) { + val x = left[i]; val a = kotlin.math.abs(x) + left[i] = if (x >= 0) a * a else -(a * a) + + val y = right[i]; val b = kotlin.math.abs(y) + right[i] = if (y >= 0) b * b else -(b * b) + } + } + + // De-emphasis: y[n] = x[n] + 0.5 * y[n-1] (state persists across chunks) + private fun deemphasis(left: FloatArray, right: FloatArray, count: Int, state: TadDecoderState) { + for (i in 0 until count) { + val yL = left[i] + 0.5f * state.prevYL + state.prevYL = yL; left[i] = yL + + val yR = right[i] + 0.5f * state.prevYR + state.prevYR = yR; right[i] = yR + } + } +} diff --git a/src/net/torvald/terrarum/tav/TavDecoder.kt b/src/net/torvald/terrarum/tav/TavDecoder.kt new file mode 100644 index 000000000..24cba660a --- /dev/null +++ b/src/net/torvald/terrarum/tav/TavDecoder.kt @@ -0,0 +1,456 @@ +package net.torvald.terrarum.tav + +import com.badlogic.gdx.graphics.Pixmap +import io.airlift.compress.zstd.ZstdInputStream +import java.io.ByteArrayInputStream +import java.io.InputStream +import java.util.concurrent.atomic.AtomicBoolean +import 
java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.AtomicLong + +/** + * TAV file demuxer and frame/audio coordinator. + * Owns the InputStream, demuxes packets, and decodes video+audio on a background thread. + * Provides lock-free SPSC ring buffers for thread-safe GL and audio mixer consumption. + */ +class TavDecoder( + private val stream: InputStream, + val looping: Boolean = false +) { + + companion object { + private const val FRAME_RING_SIZE = 32 + private const val AUDIO_RING_SIZE = 65536 + private const val BACK_PRESSURE_SLEEP_MS = 2L + + private val TAV_MAGIC = byteArrayOf(0x1F, 'T'.code.toByte(), 'S'.code.toByte(), 'V'.code.toByte(), + 'M'.code.toByte(), 'T'.code.toByte(), 'A'.code.toByte(), 'V'.code.toByte()) + } + + // ------------------------------------------------------------------------- + // Header + // ------------------------------------------------------------------------- + + lateinit var header: TavVideoDecode.TavHeader + private set + + val videoWidth: Int get() = header.width + val videoHeight: Int get() = header.height + val fps: Int get() = header.fps + val totalFrames: Long get() = header.totalFrames + val hasAudio: Boolean get() = header.hasAudio + val isPerceptual: Boolean get() = header.isPerceptual + val isMonoblock: Boolean get() = header.isMonoblock + + // ------------------------------------------------------------------------- + // Ring buffers + // ------------------------------------------------------------------------- + + // Video: pre-allocated Pixmap ring buffer + private lateinit var frameRing: Array + val frameReadIdx = AtomicInteger(0) + val frameWriteIdx = AtomicInteger(0) + + // Audio: circular Float32 ring + private lateinit var audioRingL: FloatArray + private lateinit var audioRingR: FloatArray + val audioReadPos = AtomicLong(0L) + val audioWritePos = AtomicLong(0L) + + // ------------------------------------------------------------------------- + // Thread state + // 
-------------------------------------------------------------------------

    private var decodeThread: Thread? = null
    val shouldStop = AtomicBoolean(false)
    val isFinished = AtomicBoolean(false)

    // -------------------------------------------------------------------------
    // Codec state (owned by the decode thread)
    // -------------------------------------------------------------------------

    // Previous frame's wavelet coefficients; P-frames (DELTA) are decoded
    // relative to these. Reset to null when the stream loops back.
    private var prevCoeffsY: FloatArray? = null
    private var prevCoeffsCo: FloatArray? = null
    private var prevCoeffsCg: FloatArray? = null

    // De-emphasis filter state for TAD audio; persists across chunks.
    private val tadState = TadDecode.TadDecoderState()

    private var frameCounter = 0
    private var gopFrameCounter = 0 // for grain synthesis RNG continuity

    // -------------------------------------------------------------------------
    // Looping support: remember stream position after header
    // -------------------------------------------------------------------------

    private var streamBuffer: ByteArray? = null // null when not a resettable stream
    private var headerSize = 0

    // -------------------------------------------------------------------------
    // Lifecycle
    // -------------------------------------------------------------------------

    init {
        // Buffer the whole stream in memory so that looping is a simple
        // pointer rewind instead of a stream reset.
        val rawBytes = stream.readBytes()
        streamBuffer = rawBytes
        headerSize = parseHeader(rawBytes)
    }

    /**
     * Parse the fixed TAV file header at the start of [bytes] and populate [header].
     * All multi-byte fields are little-endian.
     *
     * @return byte offset of the first packet (i.e. the header size)
     * @throws IllegalArgumentException if the file magic does not match
     */
    private fun parseHeader(bytes: ByteArray): Int {
        // Verify magic
        for (i in 0..7) {
            require(bytes[i] == TAV_MAGIC[i]) { "Not a TAV file (magic mismatch)" }
        }

        var ptr = 8
        val version = bytes[ptr++].toInt() and 0xFF
        val width = ((bytes[ptr].toInt() and 0xFF) or ((bytes[ptr+1].toInt() and 0xFF) shl 8)).also { ptr += 2 }
        val height = ((bytes[ptr].toInt() and 0xFF) or ((bytes[ptr+1].toInt() and 0xFF) shl 8)).also { ptr += 2 }
        val fps = bytes[ptr++].toInt() and 0xFF
        // totalFrames is stored as an unsigned 32-bit value; widen to Long
        val totalFrames = (
             (bytes[ptr  ].toLong() and 0xFF) or
            ((bytes[ptr+1].toLong() and 0xFF) shl 8) or
            ((bytes[ptr+2].toLong() and 0xFF) shl 16) or
            ((bytes[ptr+3].toLong() and 0xFF) shl 24)
        ).also { ptr += 4 }
        val waveletFilter = bytes[ptr++].toInt() and 0xFF
        val decompLevels = bytes[ptr++].toInt() and 0xFF
        val qIndexY = bytes[ptr++].toInt() and 0xFF
        val qIndexCo = bytes[ptr++].toInt() and 0xFF
        val qIndexCg = bytes[ptr++].toInt() and 0xFF
        val extraFlags = bytes[ptr++].toInt() and 0xFF
        val videoFlags = bytes[ptr++].toInt() and 0xFF
        val encoderQuality = bytes[ptr++].toInt() and 0xFF
        val channelLayout = bytes[ptr++].toInt() and 0xFF
        val entropyCoder = bytes[ptr++].toInt() and 0xFF
        val encoderPreset = bytes[ptr++].toInt() and 0xFF
        ptr += 2 // reserved + device orientation (ignored) + file role

        header = TavVideoDecode.TavHeader(
            version = version, width = width, height = height,
            fps = fps, totalFrames = totalFrames,
            waveletFilter = waveletFilter, decompLevels = decompLevels,
            qIndexY = qIndexY, qIndexCo = qIndexCo, qIndexCg = qIndexCg,
            extraFlags = extraFlags, videoFlags = videoFlags,
            encoderQuality = encoderQuality, channelLayout = channelLayout,
            entropyCoder = entropyCoder, encoderPreset = encoderPreset
        )

        return ptr // byte offset to first packet
    }

    /** Allocate the video Pixmap ring and the stereo audio rings. Must run before the decode thread starts. */
    private fun allocateBuffers() {
        frameRing = Array(FRAME_RING_SIZE) {
            Pixmap(videoWidth, videoHeight, Pixmap.Format.RGBA8888)
        }
        audioRingL = FloatArray(AUDIO_RING_SIZE)
        audioRingR = FloatArray(AUDIO_RING_SIZE)
    }

    /** Allocate buffers and spawn the background decode thread. */
    fun start() {
        allocateBuffers()
        shouldStop.set(false)
        isFinished.set(false)

        decodeThread = Thread(::decodeLoop, "tav-decode").also {
            it.isDaemon = true
            it.start()
        }
    }

    /** Signal the decode thread to stop and wait (up to 2 s) for it to exit. */
    fun stop() {
        shouldStop.set(true)
        // FIX: wake the decode thread out of any back-pressure sleep so shutdown
        // is prompt — decodeLoop already catches InterruptedException and exits
        // cleanly, but nothing ever interrupted it, so a join() timeout could
        // leave the thread writing into Pixmaps that dispose() is about to free.
        decodeThread?.interrupt()
        decodeThread?.join(2000)
        decodeThread = null
    }

    /** Stop the decode thread and release all Pixmaps. */
    fun dispose() {
        stop()
        if (::frameRing.isInitialized) {
            for (px in frameRing) px.dispose()
        }
    }

    // -------------------------------------------------------------------------
    // Decode loop (background thread)
    // 
-------------------------------------------------------------------------

    /**
     * Main demux/decode loop, run on the background "tav-decode" thread.
     * Walks the in-memory packet stream, decoding video into [frameRing] and
     * audio into the audio rings, sleeping briefly whenever a ring is full.
     * On end-of-stream either rewinds (when [looping]) or sets [isFinished].
     */
    private fun decodeLoop() {
        val bytes = streamBuffer ?: return
        var ptr = headerSize

        try {
            while (!shouldStop.get()) {
                if (ptr >= bytes.size) {
                    if (looping) {
                        // Rewind to the first packet and reset inter-frame codec state
                        ptr = headerSize
                        prevCoeffsY = null; prevCoeffsCo = null; prevCoeffsCg = null
                        tadState.prevYL = 0f; tadState.prevYR = 0f
                        continue
                    } else {
                        isFinished.set(true)
                        break
                    }
                }

                val packetType = bytes[ptr++].toInt() and 0xFF

                when (packetType) {
                    // --- Special fixed-size packets (no payload size) ---
                    0xFF -> { /* sync, no-op */ }
                    0xFE -> { /* NTSC sync, no-op */ }
                    0x00 -> { /* no-op */ }
                    0xF0 -> { /* loop point start, no-op */ }
                    0xF1 -> { /* loop point end, no-op */ }
                    0xFC -> {
                        // GOP sync: 1 extra byte (frame count)
                        if (ptr < bytes.size) ptr++ // skip frame count byte
                    }
                    0xFD -> {
                        // Timecode: 8-byte uint64 nanosecond timestamp
                        ptr += 8
                    }

                    // --- Video packets (0x10 = I-frame, 0x11 = P-frame) ---
                    0x10, 0x11 -> {
                        // FIX: guard the 4-byte size read against a truncated file
                        if (ptr + 4 > bytes.size) { isFinished.set(true); break }
                        val payloadSize = readInt32LE(bytes, ptr); ptr += 4
                        if (ptr + payloadSize > bytes.size) { isFinished.set(true); break }
                        val payload = bytes.copyOfRange(ptr, ptr + payloadSize); ptr += payloadSize

                        val blockData = if (!header.noZstd)
                            ZstdInputStream(ByteArrayInputStream(payload)).use { it.readBytes() }
                        else payload

                        // Back-pressure: wait until there's space in the ring
                        while (frameRingFull() && !shouldStop.get()) Thread.sleep(BACK_PRESSURE_SLEEP_MS)
                        if (shouldStop.get()) break

                        val result = TavVideoDecode.decodeFrame(
                            blockData, header,
                            prevCoeffsY, prevCoeffsCo, prevCoeffsCg,
                            frameCounter
                        )
                        prevCoeffsY = result.coeffsY
                        prevCoeffsCo = result.coeffsCo
                        prevCoeffsCg = result.coeffsCg

                        val rgba = result.rgba
                        if (rgba != null) {
                            writeFrameToRing(rgba)
                        } else {
                            // SKIP frame: duplicate the previous ring entry
                            duplicateLastFrame()
                        }
                        frameCounter++
                    }

                    0x12 -> {
                        // GOP Unified: uint8 gopSize, uint32 payloadSize, payload
                        // FIX: guard the 5 header bytes against a truncated file
                        if (ptr + 5 > bytes.size) { isFinished.set(true); break }
                        val gopSize = bytes[ptr++].toInt() and 0xFF
                        val payloadSize = readInt32LE(bytes, ptr); ptr += 4
                        if (ptr + payloadSize > bytes.size) { isFinished.set(true); break }
                        val payload = bytes.copyOfRange(ptr, ptr + payloadSize); ptr += payloadSize

                        val frames = TavVideoDecode.decodeGop(payload, header, gopSize, frameCounter = gopFrameCounter)
                        gopFrameCounter += gopSize

                        for (rgba in frames) {
                            while (frameRingFull() && !shouldStop.get()) Thread.sleep(BACK_PRESSURE_SLEEP_MS)
                            if (shouldStop.get()) break
                            writeFrameToRing(rgba)
                            frameCounter++
                        }
                    }

                    // --- Audio packets: raw PCM (0x21 = 8-bit, 0x22 = 16-bit) ---
                    0x21, 0x22 -> {
                        // FIX: these two cases previously had neither the bounds
                        // checks nor the back-pressure wait that 0x24 has, so a
                        // truncated file threw out of the loop and a fast decoder
                        // could overrun the audio ring. Made consistent with 0x24.
                        if (ptr + 4 > bytes.size) { isFinished.set(true); break }
                        val payloadSize = readInt32LE(bytes, ptr); ptr += 4
                        if (ptr + payloadSize > bytes.size) { isFinished.set(true); break }
                        val payload = bytes.copyOfRange(ptr, ptr + payloadSize); ptr += payloadSize

                        val pcm = if (packetType == 0x21) TadDecode.decodePcm8(payload)
                                  else TadDecode.decodePcm16(payload)

                        while (audioRingFull(pcm.first.size) && !shouldStop.get()) Thread.sleep(BACK_PRESSURE_SLEEP_MS)
                        if (shouldStop.get()) break

                        writeAudioToRing(pcm.first, pcm.second, pcm.first.size)
                    }

                    0x24 -> {
                        // TAD packet structure:
                        //   uint16 sampleCount, uint32 outerSize (=compressedSize+7)
                        //   TAD chunk header: uint16 sampleCount, uint8 maxIndex, uint32 compressedSize
                        //   * Zstd payload
                        // FIX: guard the 13 fixed header bytes against a truncated file
                        if (ptr + 13 > bytes.size) { isFinished.set(true); break }
                        val sampleCount = readInt16LE(bytes, ptr); ptr += 2
                        val outerSize = readInt32LE(bytes, ptr); ptr += 4 // = compSize + 7

                        val chunkSamples = readInt16LE(bytes, ptr); ptr += 2
                        val maxIndex = bytes[ptr++].toInt() and 0xFF
                        val compSize = readInt32LE(bytes, ptr); ptr += 4

                        if (ptr + compSize > bytes.size) { isFinished.set(true); break }
                        val payload = bytes.copyOfRange(ptr, ptr + compSize); ptr += compSize

                        while (audioRingFull(chunkSamples) && !shouldStop.get()) Thread.sleep(BACK_PRESSURE_SLEEP_MS)
                        if (shouldStop.get()) break

                        try {
                            val pcm = TadDecode.decodeTadChunk(payload, chunkSamples, maxIndex, tadState)
                            writeAudioToRing(pcm.first, pcm.second, chunkSamples)
                        } catch (e: Exception) {
                            // Silently drop corrupted audio packets
                        }
                    }

                    // --- Extended header and metadata: read and skip ---
                    0xEF -> {
                        // TAV extended header: uint16 num_kvp, then key-value pairs
                        if (ptr + 2 > bytes.size) { isFinished.set(true); break }
                        val numKvp = readInt16LE(bytes, ptr); ptr += 2
                        repeat(numKvp) {
                            if (ptr + 5 <= bytes.size) {
                                ptr += 4 // key[4]
                                val valueType = bytes[ptr++].toInt() and 0xFF
                                val valueSize = when (valueType) {
                                    0x00 -> 2; 0x01 -> 3; 0x02 -> 4; 0x03 -> 6; 0x04 -> 8
                                    0x10 -> { val len = readInt16LE(bytes, ptr); ptr += 2; len }
                                    else -> 0
                                }
                                ptr += valueSize
                            }
                        }
                    }

                    in 0xE0..0xEE -> {
                        // Standard metadata: uint32 size, * payload
                        if (ptr + 4 > bytes.size) { isFinished.set(true); break }
                        val payloadSize = readInt32LE(bytes, ptr); ptr += 4
                        ptr += payloadSize
                    }

                    else -> {
                        // Unknown packet with payload: uint32 size + payload
                        if (ptr + 4 <= bytes.size) {
                            val payloadSize = readInt32LE(bytes, ptr); ptr += 4
                            ptr += payloadSize
                        } else {
                            isFinished.set(true); break
                        }
                    }
                }
            }
        } catch (e: InterruptedException) {
            // Thread interrupted (by stop()), exit cleanly
        } catch (e: Exception) {
            System.err.println("[TavDecoder] Decode error: ${e.message}")
        }

        if (!shouldStop.get()) isFinished.set(true)
    }

    // -------------------------------------------------------------------------
    // Frame ring operations
    // -------------------------------------------------------------------------

    /** True when writing one more frame would collide with the reader (capacity is SIZE-1). */
    private fun frameRingFull(): Boolean {
        val write = frameWriteIdx.get()
        val read = frameReadIdx.get()
        return ((write + 1) % FRAME_RING_SIZE) == (read % FRAME_RING_SIZE)
    }

    /** Copy [rgba] into the next ring Pixmap, then publish by bumping the write index. */
    private fun writeFrameToRing(rgba: ByteArray) {
        val idx = frameWriteIdx.get() % FRAME_RING_SIZE
        val px = frameRing[idx]
        val buf = px.pixels
        buf.position(0)
        buf.put(rgba)
        buf.position(0)
        frameWriteIdx.incrementAndGet()
    }

    /** Repeat the most recently written frame (used for SKIP frames). */
    private fun duplicateLastFrame() {
        val writeIdx = frameWriteIdx.get()
        if (writeIdx == frameReadIdx.get()) return // ring empty, nothing to duplicate
        val srcIdx = ((writeIdx - 1 + FRAME_RING_SIZE) % FRAME_RING_SIZE)
        val dstIdx = writeIdx % FRAME_RING_SIZE
        val src = frameRing[srcIdx]
        val dst = frameRing[dstIdx]
        src.pixels.position(0)
        dst.pixels.position(0)
        dst.pixels.put(src.pixels)
        src.pixels.position(0)
        dst.pixels.position(0)
        frameWriteIdx.incrementAndGet()
    }

    /** Returns the current decoded Pixmap without advancing, or null if no frame available. */
    fun getFramePixmap(): Pixmap? {
        val read = frameReadIdx.get()
        val write = frameWriteIdx.get()
        if (read == write) return null
        return frameRing[read % FRAME_RING_SIZE]
    }

    /** Advance to the next decoded frame. */
    fun advanceFrame() {
        val read = frameReadIdx.get()
        val write = frameWriteIdx.get()
        if (read != write) frameReadIdx.incrementAndGet()
    }

    // -------------------------------------------------------------------------
    // Audio ring operations
    // -------------------------------------------------------------------------

    /** True when fewer than [needed] sample slots are free in the audio ring. */
    private fun audioRingFull(needed: Int): Boolean {
        val avail = AUDIO_RING_SIZE - (audioWritePos.get() - audioReadPos.get()).toInt()
        return avail < needed
    }

    /** Append [count] stereo samples; caller must have checked [audioRingFull] first. */
    private fun writeAudioToRing(left: FloatArray, right: FloatArray, count: Int) {
        var writePos = audioWritePos.get()
        for (i in 0 until count) {
            val slot = (writePos % AUDIO_RING_SIZE).toInt()
            audioRingL[slot] = left[i]
            audioRingR[slot] = right[i]
            writePos++
        }
        audioWritePos.set(writePos)
    }

    /**
     * Read audio samples from the ring buffer into the caller's buffers.
+ * @return number of samples actually read + */ + fun readAudioSamples(bufL: FloatArray, bufR: FloatArray): Int { + val available = (audioWritePos.get() - audioReadPos.get()).toInt().coerceAtMost(bufL.size) + if (available <= 0) return 0 + var readPos = audioReadPos.get() + for (i in 0 until available) { + val slot = (readPos % AUDIO_RING_SIZE).toInt() + bufL[i] = audioRingL[slot] + bufR[i] = audioRingR[slot] + readPos++ + } + audioReadPos.set(readPos) + return available + } + + // ------------------------------------------------------------------------- + // Helpers + // ------------------------------------------------------------------------- + + private fun readInt32LE(data: ByteArray, offset: Int): Int { + val b0 = data[offset ].toInt() and 0xFF + val b1 = data[offset+1].toInt() and 0xFF + val b2 = data[offset+2].toInt() and 0xFF + val b3 = data[offset+3].toInt() and 0xFF + return b0 or (b1 shl 8) or (b2 shl 16) or (b3 shl 24) + } + + private fun readInt16LE(data: ByteArray, offset: Int): Int { + val b0 = data[offset ].toInt() and 0xFF + val b1 = data[offset+1].toInt() and 0xFF + return b0 or (b1 shl 8) + } +} diff --git a/src/net/torvald/terrarum/tav/TavVideoDecode.kt b/src/net/torvald/terrarum/tav/TavVideoDecode.kt new file mode 100644 index 000000000..9b36115f7 --- /dev/null +++ b/src/net/torvald/terrarum/tav/TavVideoDecode.kt @@ -0,0 +1,685 @@ +package net.torvald.terrarum.tav + +import io.airlift.compress.zstd.ZstdInputStream +import java.io.ByteArrayInputStream +import kotlin.math.roundToInt + +/** + * TAV video frame decoder (stateless pipeline functions). + * Handles I-frames (0x10), P-frames (0x11) and GOP Unified (0x12) packets. + * Supports YCoCg-R colour space only (odd version numbers). + * + * Ported from GraphicsJSR223Delegate.kt in the TSVM project. 
 */
object TavVideoDecode {

    // Exponential quantiser lookup table (index → value).
    // 256 entries: 64 values with step 1 (1..64), then six groups of 32 values
    // whose step doubles each group (2, 4, 8, 16, 32, 64), ending at 4096.
    private val QLUT = intArrayOf(
        // step 1: 1..64
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
        33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
        49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
        // step 2: 66..128
        66, 68, 70, 72, 74, 76, 78, 80, 82, 84, 86, 88, 90, 92, 94, 96,
        98, 100, 102, 104, 106, 108, 110, 112, 114, 116, 118, 120, 122, 124, 126, 128,
        // step 4: 132..256
        132, 136, 140, 144, 148, 152, 156, 160, 164, 168, 172, 176, 180, 184, 188, 192,
        196, 200, 204, 208, 212, 216, 220, 224, 228, 232, 236, 240, 244, 248, 252, 256,
        // step 8: 264..512
        264, 272, 280, 288, 296, 304, 312, 320, 328, 336, 344, 352, 360, 368, 376, 384,
        392, 400, 408, 416, 424, 432, 440, 448, 456, 464, 472, 480, 488, 496, 504, 512,
        // step 16: 528..1024
        528, 544, 560, 576, 592, 608, 624, 640, 656, 672, 688, 704, 720, 736, 752, 768,
        784, 800, 816, 832, 848, 864, 880, 896, 912, 928, 944, 960, 976, 992, 1008, 1024,
        // step 32: 1056..2048
        1056, 1088, 1120, 1152, 1184, 1216, 1248, 1280, 1312, 1344, 1376, 1408, 1440, 1472, 1504, 1536,
        1568, 1600, 1632, 1664, 1696, 1728, 1760, 1792, 1824, 1856, 1888, 1920, 1952, 1984, 2016, 2048,
        // step 64: 2112..4096
        2112, 2176, 2240, 2304, 2368, 2432, 2496, 2560, 2624, 2688, 2752, 2816, 2880, 2944, 3008, 3072,
        3136, 3200, 3264, 3328, 3392, 3456, 3520, 3584, 3648, 3712, 3776, 3840, 3904, 3968, 4032, 4096
    )

    // Anisotropy tables, indexed by the 0..6 quality level produced by
    // deriveEncoderQIndex. MULT/BIAS apply to luma HL weights, the *_CHROMA
    // variants to chroma HL/HH weights.
    private val ANISOTROPY_MULT = floatArrayOf(5.1f, 3.8f, 2.7f, 2.0f, 1.5f, 1.2f, 1.0f)
    private val ANISOTROPY_BIAS = floatArrayOf(0.4f, 0.3f, 0.2f, 0.1f, 0.0f, 0.0f, 0.0f)
    private val ANISOTROPY_MULT_CHROMA = floatArrayOf(7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f)
    private val ANISOTROPY_BIAS_CHROMA = floatArrayOf(1.0f, 0.8f, 0.6f, 0.4f, 0.2f, 0.0f, 0.0f)

    // -------------------------------------------------------------------------
    // Frame data class
    // -------------------------------------------------------------------------

    /** Decoded frame info returned to the caller.
*/ + data class TavHeader( + val version: Int, + val width: Int, + val height: Int, + val fps: Int, + val totalFrames: Long, + val waveletFilter: Int, + val decompLevels: Int, + val qIndexY: Int, + val qIndexCo: Int, + val qIndexCg: Int, + val extraFlags: Int, + val videoFlags: Int, + val encoderQuality: Int, + val channelLayout: Int, + val entropyCoder: Int, + val encoderPreset: Int + ) { + val hasAudio: Boolean get() = (extraFlags and 0x01) != 0 + val isLooping: Boolean get() = (extraFlags and 0x04) != 0 + val isInterlaced:Boolean get() = (videoFlags and 0x01) != 0 + val isLossless: Boolean get() = (videoFlags and 0x04) != 0 + val noZstd: Boolean get() = (videoFlags and 0x10) != 0 + val hasNoVideo: Boolean get() = (videoFlags and 0x80) != 0 + /** Monoblock mode: version 3-6 */ + val isMonoblock: Boolean get() = version in 3..6 + + val qY: Int get() = QLUT.getOrElse(qIndexY - 1) { 1 } + val qCo: Int get() = QLUT.getOrElse(qIndexCo - 1) { 1 } + val qCg: Int get() = QLUT.getOrElse(qIndexCg - 1) { 1 } + val isPerceptual: Boolean get() = (version % 8) in 5..8 + /** Temporal motion coder: 0=Haar (version<=8), 1=CDF5/3 (version>8) */ + val temporalMotionCoder: Int get() = if (version > 8) 1 else 0 + } + + // ------------------------------------------------------------------------- + // Subband layout + // ------------------------------------------------------------------------- + + data class SubbandInfo(val level: Int, val subbandType: Int, val coeffStart: Int, val coeffCount: Int) + + fun calculateSubbandLayout(width: Int, height: Int, decompLevels: Int): List { + val subbands = mutableListOf() + val llWidth = width shr decompLevels + val llHeight = height shr decompLevels + subbands.add(SubbandInfo(decompLevels, 0, 0, llWidth * llHeight)) + var offset = llWidth * llHeight + + for (level in decompLevels downTo 1) { + val lw = width shr (decompLevels - level + 1) + val lh = height shr (decompLevels - level + 1) + val sz = lw * lh + + subbands.add(SubbandInfo(level, 1, 
offset, sz)); offset += sz + subbands.add(SubbandInfo(level, 2, offset, sz)); offset += sz + subbands.add(SubbandInfo(level, 3, offset, sz)); offset += sz + } + return subbands + } + + // ------------------------------------------------------------------------- + // Perceptual weight calculation + // ------------------------------------------------------------------------- + + fun getPerceptualWeight(qIndex: Int, qYGlobal: Int, level0: Int, subbandType: Int, isChroma: Boolean, maxLevels: Int): Float { + val level = 1.0f + ((level0 - 1.0f) / (maxLevels - 1.0f)) * 5.0f + val qualityLevel = deriveEncoderQIndex(qIndex, qYGlobal) + + if (!isChroma) { + if (subbandType == 0) return perceptualLL(level) + val lh = perceptualLH(level) + if (subbandType == 1) return lh + val hl = perceptualHL(qualityLevel, lh) + val fineDetail = if (level in 1.8f..2.2f) 0.92f else if (level in 2.8f..3.2f) 0.88f else 1.0f + if (subbandType == 2) return hl * fineDetail + return perceptualHH(lh, hl, level) * fineDetail + } else { + val base = perceptualChromaBasecurve(qualityLevel, level - 1) + return when (subbandType) { + 0 -> 1.0f + 1 -> base.coerceAtLeast(1.0f) + 2 -> (base * ANISOTROPY_MULT_CHROMA[qualityLevel]).coerceAtLeast(1.0f) + else -> (base * ANISOTROPY_MULT_CHROMA[qualityLevel] + ANISOTROPY_BIAS_CHROMA[qualityLevel]).coerceAtLeast(1.0f) + } + } + } + + private fun deriveEncoderQIndex(qIndex: Int, qYGlobal: Int): Int { + if (qIndex > 0) return qIndex - 1 + return when { + qYGlobal >= 79 -> 0 + qYGlobal >= 47 -> 1 + qYGlobal >= 23 -> 2 + qYGlobal >= 11 -> 3 + qYGlobal >= 5 -> 4 + qYGlobal >= 2 -> 5 + else -> 6 + } + } + + private fun perceptualLH(level: Float): Float { + val H4 = 1.2f; val K = 2f; val K12 = K * 12f; val x = level + val Lx = H4 - ((K + 1f) / 15f) * (x - 4f) + val C3 = -1f / 45f * (K12 + 92) + val G3x = (-x / 180f) * (K12 + 5 * x * x - 60 * x + 252) - C3 + H4 + return if (level >= 4f) Lx else G3x + } + + private fun perceptualHL(quality: Int, lh: Float): Float = + lh * 
ANISOTROPY_MULT[quality] + ANISOTROPY_BIAS[quality] + + private fun perceptualHH(lh: Float, hl: Float, level: Float): Float { + val kx = (kotlin.math.sqrt(level.toDouble()).toFloat() - 1f) * 0.5f + 0.5f + return lh * (1f - kx) + hl * kx + } + + private fun perceptualLL(level: Float): Float { + val n = perceptualLH(level) + val m = perceptualLH(level - 1) / n + return n / m + } + + private fun perceptualChromaBasecurve(qualityLevel: Int, level: Float): Float = + 1.0f - (1.0f / (0.5f * qualityLevel * qualityLevel + 1.0f)) * (level - 4f) + + // ------------------------------------------------------------------------- + // Dequantisation + // ------------------------------------------------------------------------- + + fun dequantisePerceptual( + quantised: ShortArray, dequantised: FloatArray, + subbands: List, + baseQuantiser: Float, isChroma: Boolean, + qIndex: Int, qYGlobal: Int, decompLevels: Int + ) { + val weights = FloatArray(quantised.size) { 1.0f } + for (sb in subbands) { + val w = getPerceptualWeight(qIndex, qYGlobal, sb.level, sb.subbandType, isChroma, decompLevels) + for (i in 0 until sb.coeffCount) { + val idx = sb.coeffStart + i + if (idx < weights.size) weights[idx] = w + } + } + for (i in quantised.indices) { + if (i < dequantised.size) dequantised[i] = quantised[i] * baseQuantiser * weights[i] + } + } + + fun dequantiseUniform(quantised: ShortArray, dequantised: FloatArray, baseQuantiser: Float) { + for (i in quantised.indices) { + if (i < dequantised.size) dequantised[i] = quantised[i] * baseQuantiser + } + } + + // ------------------------------------------------------------------------- + // Grain synthesis + // ------------------------------------------------------------------------- + + fun grainSynthesis(coeffs: FloatArray, width: Int, height: Int, + frameNum: Int, subbands: List, qYGlobal: Int, encoderPreset: Int) { + if ((encoderPreset and 0x02) != 0) return // Anime preset: disable grain + + val noiseAmplitude = qYGlobal.coerceAtMost(32) * 
0.8f + + for (sb in subbands) { + if (sb.level == 0) continue // Skip LL band + + for (i in 0 until sb.coeffCount) { + val idx = sb.coeffStart + i + if (idx >= coeffs.size) continue + val y = idx / width + val x = idx % width + val rngVal = grainRng(frameNum.toUInt(), (sb.level + sb.subbandType * 31 + 16777619).toUInt(), x.toUInt(), y.toUInt()) + val noise = grainTriangularNoise(rngVal) + coeffs[idx] -= noise * noiseAmplitude + } + } + } + + private fun grainRng(frame: UInt, band: UInt, x: UInt, y: UInt): UInt { + val key = frame * 0x9e3779b9u xor band * 0x7f4a7c15u xor (y shl 16) xor x + var hash = key + hash = hash xor (hash shr 16) + hash = hash * 0x7feb352du + hash = hash xor (hash shr 15) + hash = hash * 0x846ca68bu + hash = hash xor (hash shr 16) + return hash + } + + private fun grainTriangularNoise(rngVal: UInt): Float { + val u1 = (rngVal and 0xFFFFu).toFloat() / 65535.0f + val u2 = ((rngVal shr 16) and 0xFFFFu).toFloat() / 65535.0f + return (u1 + u2) - 1.0f + } + + // ------------------------------------------------------------------------- + // YCoCg-R to RGB + // ------------------------------------------------------------------------- + + /** + * Convert YCoCg-R float arrays to RGBA8888 byte array. + * Each pixel = 4 bytes (R, G, B, A=255). 
+ */ + fun ycocgrToRgba(y: FloatArray, co: FloatArray, cg: FloatArray, + width: Int, height: Int, + channelLayout: Int = 0): ByteArray { + val out = ByteArray(width * height * 4) + for (i in 0 until width * height) { + val yv = y[i]; val cov = co[i]; val cgv = cg[i] + val tmp = yv - cgv / 2.0f + val g = cgv + tmp + val b = tmp - cov / 2.0f + val r = cov + b + + out[i * 4 ] = r.roundToInt().coerceIn(0, 255).toByte() + out[i * 4 + 1] = g.roundToInt().coerceIn(0, 255).toByte() + out[i * 4 + 2] = b.roundToInt().coerceIn(0, 255).toByte() + out[i * 4 + 3] = 0xFF.toByte() + } + return out + } + + // ------------------------------------------------------------------------- + // Temporal quantiser scale helpers (for GOP decode) + // ------------------------------------------------------------------------- + + private fun getTemporalSubbandLevel(frameIdx: Int, numFrames: Int, temporalLevels: Int): Int { + val framesPerLevel0 = numFrames shr temporalLevels + return when { + frameIdx < framesPerLevel0 -> 0 + frameIdx < (numFrames shr 1) -> 1 + else -> 2 + } + } + + private fun getTemporalQuantiserScale(encoderPreset: Int, temporalLevel: Int): Float { + val beta = if (encoderPreset and 0x01 == 1) 0.0f else 0.6f + val kappa = if (encoderPreset and 0x01 == 1) 1.0f else 1.14f + return Math.pow(2.0, (beta * Math.pow(temporalLevel.toDouble(), kappa.toDouble()))).toFloat() + } + + // ------------------------------------------------------------------------- + // Coefficients from block data (significance-map or EZBC) + // ------------------------------------------------------------------------- + + private fun readInt32LE(data: ByteArray, offset: Int): Int { + val b0 = data[offset ].toInt() and 0xFF + val b1 = data[offset+1].toInt() and 0xFF + val b2 = data[offset+2].toInt() and 0xFF + val b3 = data[offset+3].toInt() and 0xFF + return b0 or (b1 shl 8) or (b2 shl 16) or (b3 shl 24) + } + + /** + * Decode quantised coefficients from block data (single frame). 
+ * Supports EZBC (entropyCoder=1) and 2-bit significance map (entropyCoder=0). + */ + private fun extractCoefficients( + blockData: ByteArray, offset: Int, + coeffCount: Int, channelLayout: Int, entropyCoder: Int, + qY: ShortArray, qCo: ShortArray, qCg: ShortArray + ) { + if (entropyCoder == 1) { + EzbcDecode.decode2D(blockData, offset, channelLayout, qY, qCo, qCg, null) + } else { + extractCoeffsSigMap(blockData, offset, coeffCount, channelLayout, qY, qCo, qCg) + } + } + + /** 2-bit significance map decoder (legacy format). */ + private fun extractCoeffsSigMap( + data: ByteArray, offset: Int, coeffCount: Int, channelLayout: Int, + outY: ShortArray, outCo: ShortArray, outCg: ShortArray + ) { + val hasY = (channelLayout and 4) == 0 + val hasCoCg = (channelLayout and 2) == 0 + + val mapBytes = (coeffCount * 2 + 7) / 8 + + val yMapStart = if (hasY) { offset } else -1 + val coMapStart = if (hasCoCg) { offset + (if (hasY) mapBytes else 0) } else -1 + val cgMapStart = if (hasCoCg) { coMapStart + mapBytes } else -1 + + var yOthers = 0; var coOthers = 0; var cgOthers = 0 + + fun countOthers(mapStart: Int): Int { + var cnt = 0 + for (i in 0 until coeffCount) { + val bitPos = i * 2 + val byteIdx = bitPos / 8; val bitOffset = bitPos % 8 + val byteVal = data[mapStart + byteIdx].toInt() and 0xFF + var code = (byteVal shr bitOffset) and 0x03 + if (bitOffset == 7 && byteIdx + 1 < mapBytes) { + val nb = data[mapStart + byteIdx + 1].toInt() and 0xFF + code = (code and 0x01) or ((nb and 0x01) shl 1) + } + if (code == 3) cnt++ + } + return cnt + } + + if (hasY) yOthers = countOthers(yMapStart) + if (hasCoCg) { coOthers = countOthers(coMapStart); cgOthers = countOthers(cgMapStart) } + + val numChannels = if (hasY && hasCoCg) 3 else if (hasY) 1 else 2 + var valueOffset = offset + mapBytes * numChannels + + val yValStart = if (hasY) { val s = valueOffset; valueOffset += yOthers * 2; s } else -1 + val coValStart = if (hasCoCg) { val s = valueOffset; valueOffset += coOthers * 2; s } else 
-1 + val cgValStart = if (hasCoCg) { val s = valueOffset; valueOffset += cgOthers * 2; s } else -1 + + fun decodeChannel(mapStart: Int, valStart: Int, out: ShortArray) { + var vIdx = 0 + for (i in 0 until coeffCount) { + val bitPos = i * 2; val byteIdx = bitPos / 8; val bitOffset = bitPos % 8 + val byteVal = data[mapStart + byteIdx].toInt() and 0xFF + var code = (byteVal shr bitOffset) and 0x03 + if (bitOffset == 7 && byteIdx + 1 < mapBytes) { + code = (code and 0x01) or ((data[mapStart + byteIdx + 1].toInt() and 0x01) shl 1) + } + out[i] = when (code) { + 0 -> 0 + 1 -> 1 + 2 -> (-1).toShort() + 3 -> { + val vp = valStart + vIdx * 2; vIdx++ + val lo = data[vp ].toInt() and 0xFF + val hi = data[vp+1].toInt() + ((hi shl 8) or lo).toShort() + } + else -> 0 + } + } + } + + if (hasY) decodeChannel(yMapStart, yValStart, outY) + if (hasCoCg) { decodeChannel(coMapStart, coValStart, outCo); decodeChannel(cgMapStart, cgValStart, outCg) } + } + + // ------------------------------------------------------------------------- + // I-frame / P-frame decode (monoblock only) + // ------------------------------------------------------------------------- + + /** + * Decode an I-frame or P-frame packet payload (already Zstd-decompressed). + * Returns RGBA8888 pixels and the new float coefficients for P-frame reference. 
+ * + * @param blockData decompressed block data (after ZstdInputStream) + * @param header parsed TAV header + * @param prevCoeffsY/Co/Cg previous frame coefficients for P-frame delta (null for I-frame) + * @param frameNum frame counter for grain synthesis RNG + * @return Triple(rgbaPixels, newCoeffsY, newCo, newCg) — newCoeffs are null for GOP frames + */ + fun decodeFrame( + blockData: ByteArray, + header: TavHeader, + prevCoeffsY: FloatArray?, + prevCoeffsCo: FloatArray?, + prevCoeffsCg: FloatArray?, + frameNum: Int + ): FrameDecodeResult { + val width = header.width + val height = header.height + val coeffCount = width * height + + var ptr = 0 + + // Read tile header (4 bytes) + val modeRaw = blockData[ptr++].toInt() and 0xFF + val qYOverride = blockData[ptr++].toInt() and 0xFF + val qCoOverride = blockData[ptr++].toInt() and 0xFF + val qCgOverride = blockData[ptr++].toInt() and 0xFF + + val baseMode = modeRaw and 0x0F + val haarNibble = modeRaw shr 4 + val haarLevel = if (baseMode == 0x02 && haarNibble > 0) haarNibble + 1 else 0 + + val qY = if (qYOverride != 0) QLUT[qYOverride - 1] else header.qY + val qCo = if (qCoOverride != 0) QLUT[qCoOverride - 1] else header.qCo + val qCg = if (qCgOverride != 0) QLUT[qCgOverride - 1] else header.qCg + + val quantY = ShortArray(coeffCount) + val quantCo = ShortArray(coeffCount) + val quantCg = ShortArray(coeffCount) + + val floatY = FloatArray(coeffCount) + val floatCo = FloatArray(coeffCount) + val floatCg = FloatArray(coeffCount) + + val subbands = calculateSubbandLayout(width, height, header.decompLevels) + + when (baseMode) { + 0x00 -> { // SKIP - caller should copy previous frame + return FrameDecodeResult(null, prevCoeffsY, prevCoeffsCo, prevCoeffsCg, frameMode = 'S') + } + 0x01 -> { // INTRA + extractCoefficients(blockData, ptr, coeffCount, header.channelLayout, header.entropyCoder, + quantY, quantCo, quantCg) + + if (header.isPerceptual) { + dequantisePerceptual(quantY, floatY, subbands, qY.toFloat(), false, 
header.encoderQuality, header.qY, header.decompLevels) + dequantisePerceptual(quantCo, floatCo, subbands, qCo.toFloat(), true, header.encoderQuality, header.qY, header.decompLevels) + dequantisePerceptual(quantCg, floatCg, subbands, qCg.toFloat(), true, header.encoderQuality, header.qY, header.decompLevels) + } else { + dequantiseUniform(quantY, floatY, qY.toFloat()) + dequantiseUniform(quantCo, floatCo, qCo.toFloat()) + dequantiseUniform(quantCg, floatCg, qCg.toFloat()) + } + + grainSynthesis(floatY, width, height, frameNum, subbands, header.qY, header.encoderPreset) + + DwtUtil.inverseMultilevel2D(floatY, width, height, header.decompLevels, header.waveletFilter) + DwtUtil.inverseMultilevel2D(floatCo, width, height, header.decompLevels, header.waveletFilter) + DwtUtil.inverseMultilevel2D(floatCg, width, height, header.decompLevels, header.waveletFilter) + } + 0x02 -> { // DELTA + extractCoefficients(blockData, ptr, coeffCount, header.channelLayout, header.entropyCoder, + quantY, quantCo, quantCg) + + val deltaY = FloatArray(coeffCount) { quantY[it].toFloat() * qY } + val deltaCo = FloatArray(coeffCount) { quantCo[it].toFloat() * qCo } + val deltaCg = FloatArray(coeffCount) { quantCg[it].toFloat() * qCg } + + if (haarLevel > 0) { + DwtUtil.inverseMultilevel2D(deltaY, width, height, haarLevel, 255) + DwtUtil.inverseMultilevel2D(deltaCo, width, height, haarLevel, 255) + DwtUtil.inverseMultilevel2D(deltaCg, width, height, haarLevel, 255) + } + + val pY = prevCoeffsY ?: FloatArray(coeffCount) + val pCo = prevCoeffsCo ?: FloatArray(coeffCount) + val pCg = prevCoeffsCg ?: FloatArray(coeffCount) + + for (i in 0 until coeffCount) { + floatY[i] = pY[i] + deltaY[i] + floatCo[i] = pCo[i] + deltaCo[i] + floatCg[i] = pCg[i] + deltaCg[i] + } + + grainSynthesis(floatY, width, height, frameNum, subbands, header.qY, header.encoderPreset) + + DwtUtil.inverseMultilevel2D(floatY, width, height, header.decompLevels, header.waveletFilter) + DwtUtil.inverseMultilevel2D(floatCo, width, 
        height, header.decompLevels, header.waveletFilter)
                DwtUtil.inverseMultilevel2D(floatCg, width, height, header.decompLevels, header.waveletFilter)
            }
        }

        // Recombine the three coefficient planes into packed RGBA8888 bytes.
        val rgba = ycocgrToRgba(floatY, floatCo, floatCg, width, height, header.channelLayout)
        // Planes are cloned so the returned coefficients stay valid even if the
        // decoder reuses its scratch buffers for the next frame.
        return FrameDecodeResult(rgba, floatY.clone(), floatCo.clone(), floatCg.clone())
    }

    /**
     * Output of a single-frame decode: the finished RGBA pixels plus the
     * dequantised YCoCg-R coefficient planes.
     *
     * NOTE(review): array-typed properties in a Kotlin data class get
     * reference-equality in the generated equals()/hashCode(); fine as a plain
     * carrier, but do not rely on structural comparison of two results.
     */
    data class FrameDecodeResult(
        val rgba: ByteArray?,          // null on SKIP frames
        val coeffsY: FloatArray?,      // dequantised Y plane
        val coeffsCo: FloatArray?,     // dequantised Co plane
        val coeffsCg: FloatArray?,     // dequantised Cg plane
        val frameMode: Char = ' '      // frame-type tag; ' ' = unspecified
    )

    // -------------------------------------------------------------------------
    // GOP Unified decode (0x12 packet)
    // -------------------------------------------------------------------------

    /**
     * Decode a GOP Unified packet.
     *
     * Pipeline: Zstd-decompress -> entropy-decode per-frame quantised
     * coefficients -> dequantise (with per-temporal-level scaling) ->
     * grain synthesis (luma only) -> inverse 3D DWT -> YCoCg-R to RGBA.
     *
     * @param gopPayload Zstd-compressed unified block data
     * @param header TAV header
     * @param gopSize number of frames in this GOP
     * @param temporalLevels temporal decomposition levels of the 3D DWT
     * @param frameCounter global frame counter at start of GOP (for grain synthesis RNG)
     * @return list of RGBA8888 byte arrays, one per frame
     */
    fun decodeGop(
        gopPayload: ByteArray,
        header: TavHeader,
        gopSize: Int,
        temporalLevels: Int = 2,
        frameCounter: Int = 0
    ): List<ByteArray> {
        val width = header.width
        val height = header.height
        val pixels = width * height

        // Decompress
        val decompressed = ZstdInputStream(ByteArrayInputStream(gopPayload)).use { it.readBytes() }

        // Extract per-frame quantised coefficients
        val quantisedCoeffs = decodeGopUnifiedBlock(decompressed, gopSize, pixels, header)

        val gopWidth = header.width
        val gopHeight = header.height

        val gopY = Array(gopSize) { FloatArray(pixels) }
        val gopCo = Array(gopSize) { FloatArray(pixels) }
        val gopCg = Array(gopSize) { FloatArray(pixels) }

        val subbands = calculateSubbandLayout(gopWidth, gopHeight, header.decompLevels)

        // Dequantise with temporal scaling: frames in coarser temporal subbands
        // were quantised with a scaled step, so rebuild the step per frame.
        for (t in 0 until gopSize) {
            val temporalLevel = getTemporalSubbandLevel(t, gopSize, temporalLevels)
            val temporalScale = getTemporalQuantiserScale(header.encoderPreset, temporalLevel)
            // Quantiser step clamped to 1..4096 — presumably mirrors the encoder's clamp; confirm there.
            val baseQY = kotlin.math.round(header.qY * temporalScale).toFloat().coerceIn(1.0f, 4096.0f)
            val baseQCo = kotlin.math.round(header.qCo * temporalScale).toFloat().coerceIn(1.0f, 4096.0f)
            val baseQCg = kotlin.math.round(header.qCg * temporalScale).toFloat().coerceIn(1.0f, 4096.0f)

            if (header.isPerceptual) {
                // Perceptual mode: per-subband dequantisation handled by the helper.
                dequantisePerceptual(quantisedCoeffs[t][0], gopY[t], subbands, baseQY, false, header.encoderQuality, header.qY, header.decompLevels)
                dequantisePerceptual(quantisedCoeffs[t][1], gopCo[t], subbands, baseQCo, true, header.encoderQuality, header.qY, header.decompLevels)
                dequantisePerceptual(quantisedCoeffs[t][2], gopCg[t], subbands, baseQCg, true, header.encoderQuality, header.qY, header.decompLevels)
            } else {
                // Uniform dequantisation: coefficient * quantiser step.
                for (i in 0 until pixels) {
                    gopY[t][i] = quantisedCoeffs[t][0][i] * baseQY
                    gopCo[t][i] = quantisedCoeffs[t][1][i] * baseQCo
                    gopCg[t][i] = quantisedCoeffs[t][2][i] * baseQCg
                }
            }
        }

        // Grain synthesis on each GOP frame (applied to the Y plane only);
        // seeded by the global frame index so the noise is stable across re-decodes.
        for (t in 0 until gopSize) {
            grainSynthesis(gopY[t], gopWidth, gopHeight, frameCounter + t, subbands, header.qY, header.encoderPreset)
        }

        // Inverse 3D DWT (spatial + temporal), per channel
        DwtUtil.inverseMultilevel3D(gopY, gopWidth, gopHeight, gopSize, header.decompLevels, temporalLevels, header.waveletFilter, header.temporalMotionCoder)
        DwtUtil.inverseMultilevel3D(gopCo, gopWidth, gopHeight, gopSize, header.decompLevels, temporalLevels, header.waveletFilter, header.temporalMotionCoder)
        DwtUtil.inverseMultilevel3D(gopCg, gopWidth, gopHeight, gopSize, header.decompLevels, temporalLevels, header.waveletFilter, header.temporalMotionCoder)

        // Convert each frame to RGBA
        return (0 until gopSize).map { t ->
            ycocgrToRgba(gopY[t], gopCo[t], gopCg[t], width, height, header.channelLayout)
        }
    }

    /** Decode unified GOP block to per-frame per-channel ShortArrays.
*/ + private fun decodeGopUnifiedBlock( + data: ByteArray, numFrames: Int, numPixels: Int, header: TavHeader + ): Array> { + val output = Array(numFrames) { Array(3) { ShortArray(numPixels) } } + + if (header.entropyCoder == 1) { + // EZBC: [frame_size(4)][frame_ezbc]... + var ptr2 = 0 + for (frame in 0 until numFrames) { + if (ptr2 + 4 > data.size) break + val frameSize = readInt32LE(data, ptr2); ptr2 += 4 + if (ptr2 + frameSize > data.size) break + EzbcDecode.decode2D(data, ptr2, header.channelLayout, + output[frame][0], output[frame][1], output[frame][2], null) + ptr2 += frameSize + } + } else { + // 2-bit significance map (legacy), all frames concatenated + decodeSigMapGop(data, numFrames, numPixels, header.channelLayout, output) + } + + return output + } + + private fun decodeSigMapGop( + data: ByteArray, numFrames: Int, numPixels: Int, channelLayout: Int, + output: Array> + ) { + val hasY = (channelLayout and 4) == 0 + val hasCoCg = (channelLayout and 2) == 0 + val mapBytesPerFrame = (numPixels * 2 + 7) / 8 + + var readPtr = 0 + val yMapsStart = if (hasY) { val s = readPtr; readPtr += mapBytesPerFrame * numFrames; s } else -1 + val coMapsStart = if (hasCoCg) { val s = readPtr; readPtr += mapBytesPerFrame * numFrames; s } else -1 + val cgMapsStart = if (hasCoCg) { val s = readPtr; readPtr += mapBytesPerFrame * numFrames; s } else -1 + + var yOthers = 0; var coOthers = 0; var cgOthers = 0 + + fun countOthers(mapsStart: Int): Int { + var cnt = 0 + for (frame in 0 until numFrames) { + val frameMapOffset = frame * mapBytesPerFrame + for (i in 0 until numPixels) { + val bitPos = i * 2; val byteIdx = bitPos / 8; val bitOffset = bitPos % 8 + val byteVal = data.getOrElse(mapsStart + frameMapOffset + byteIdx) { 0 }.toInt() and 0xFF + var code = (byteVal shr bitOffset) and 0x03 + if (bitOffset == 7 && byteIdx + 1 < mapBytesPerFrame) { + val nb = data.getOrElse(mapsStart + frameMapOffset + byteIdx + 1) { 0 }.toInt() and 0xFF + code = (code and 0x01) or ((nb and 0x01) shl 
1) + } + if (code == 3) cnt++ + } + } + return cnt + } + + if (hasY) yOthers = countOthers(yMapsStart) + if (hasCoCg) { coOthers = countOthers(coMapsStart); cgOthers = countOthers(cgMapsStart) } + + val yValStart = readPtr; readPtr += yOthers * 2 + val coValStart = readPtr; readPtr += coOthers * 2 + val cgValStart = readPtr + + var yVIdx = 0; var coVIdx = 0; var cgVIdx = 0 + + for (frame in 0 until numFrames) { + val frameMapOffset = frame * mapBytesPerFrame + for (i in 0 until numPixels) { + val bitPos = i * 2; val byteIdx = bitPos / 8; val bitOffset = bitPos % 8 + + fun getCode(mapsStart: Int): Int { + val byteVal = data.getOrElse(mapsStart + frameMapOffset + byteIdx) { 0 }.toInt() and 0xFF + var code = (byteVal shr bitOffset) and 0x03 + if (bitOffset == 7 && byteIdx + 1 < mapBytesPerFrame) { + val nb = data.getOrElse(mapsStart + frameMapOffset + byteIdx + 1) { 0 }.toInt() and 0xFF + code = (code and 0x01) or ((nb and 0x01) shl 1) + } + return code + } + + fun readVal(valStart: Int, vIdx: Int): Short { + val vp = valStart + vIdx * 2 + return if (vp + 1 < data.size) { + val lo = data[vp ].toInt() and 0xFF + val hi = data[vp+1].toInt() + ((hi shl 8) or lo).toShort() + } else 0 + } + + if (hasY) { + output[frame][0][i] = when (getCode(yMapsStart)) { + 1 -> 1; 2 -> (-1).toShort(); 3 -> { val v = readVal(yValStart, yVIdx); yVIdx++; v }; else -> 0 + } + } + if (hasCoCg) { + output[frame][1][i] = when (getCode(coMapsStart)) { + 1 -> 1; 2 -> (-1).toShort(); 3 -> { val v = readVal(coValStart, coVIdx); coVIdx++; v }; else -> 0 + } + output[frame][2][i] = when (getCode(cgMapsStart)) { + 1 -> 1; 2 -> (-1).toShort(); 3 -> { val v = readVal(cgValStart, cgVIdx); cgVIdx++; v }; else -> 0 + } + } + } + } + } + +}