mirror of
https://github.com/curioustorvald/tsvm.git
synced 2026-03-07 11:51:49 +09:00
TAD: embedded zero tree coding (basically 1D EZBC)
This commit is contained in:
@@ -447,11 +447,11 @@ class AudioAdapter(val vm: VM) : PeriBase(VM.PERITYPE_SOUND) {
|
||||
// decode(y) = sign(y) * |y|^(1/γ) where γ=0.5
|
||||
val x = left[i]
|
||||
val a = kotlin.math.abs(x)
|
||||
left[i] = signum(x) * a.pow(1.4142f)
|
||||
left[i] = signum(x) * a * a
|
||||
|
||||
val y = right[i]
|
||||
val b = kotlin.math.abs(y)
|
||||
right[i] = signum(y) * b.pow(1.4142f)
|
||||
right[i] = signum(y) * b * b
|
||||
}
|
||||
}
|
||||
|
||||
@@ -540,6 +540,218 @@ class AudioAdapter(val vm: VM) : PeriBase(VM.PERITYPE_SOUND) {
|
||||
}
|
||||
}
|
||||
|
||||
//=============================================================================
|
||||
// Binary Tree EZBC Decoder (1D Variant for TAD)
|
||||
//=============================================================================
|
||||
|
||||
// Bitstream reader for EZBC
|
||||
private class TadBitstreamReader(private val data: ByteArray) {
|
||||
private var bytePos = 0
|
||||
private var bitPos = 0
|
||||
|
||||
fun readBit(): Int {
|
||||
if (bytePos >= data.size) {
|
||||
println("ERROR: Bitstream underflow")
|
||||
return 0
|
||||
}
|
||||
|
||||
val bit = ((data[bytePos].toInt() and 0xFF) shr bitPos) and 1
|
||||
|
||||
bitPos++
|
||||
if (bitPos == 8) {
|
||||
bitPos = 0
|
||||
bytePos++
|
||||
}
|
||||
|
||||
return bit
|
||||
}
|
||||
|
||||
fun readBits(numBits: Int): Int {
|
||||
var value = 0
|
||||
for (i in 0 until numBits) {
|
||||
value = value or (readBit() shl i)
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
fun getBytesConsumed(): Int {
|
||||
return bytePos + if (bitPos > 0) 1 else 0
|
||||
}
|
||||
}
|
||||
|
||||
// Block structure for 1D binary tree
|
||||
private data class TadBlock(val start: Int, val length: Int)
|
||||
|
||||
// Queue for block processing
|
||||
private class TadBlockQueue {
|
||||
private val blocks = ArrayList<TadBlock>()
|
||||
|
||||
fun push(block: TadBlock) {
|
||||
blocks.add(block)
|
||||
}
|
||||
|
||||
fun get(index: Int): TadBlock = blocks[index]
|
||||
|
||||
val size: Int get() = blocks.size
|
||||
|
||||
fun clear() {
|
||||
blocks.clear()
|
||||
}
|
||||
}
|
||||
|
||||
// Track coefficient state for refinement
|
||||
private data class TadCoeffState(var significant: Boolean = false, var firstBitplane: Int = 0)
|
||||
|
||||
// Check if all coefficients in block have |coeff| < threshold
|
||||
private fun tadIsZeroBlock(coeffs: ByteArray, block: TadBlock, threshold: Int): Boolean {
|
||||
for (i in block.start until block.start + block.length) {
|
||||
if (kotlin.math.abs(coeffs[i].toInt()) >= threshold) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Get MSB position (bitplane number)
|
||||
private fun tadGetMsbBitplane(value: Int): Int {
|
||||
if (value == 0) return 0
|
||||
var bitplane = 0
|
||||
var v = value
|
||||
while (v > 1) {
|
||||
v = v shr 1
|
||||
bitplane++
|
||||
}
|
||||
return bitplane
|
||||
}
|
||||
|
||||
// Recursively decode a significant block - subdivide until size 1
|
||||
private fun tadDecodeSignificantBlockRecursive(
|
||||
bs: TadBitstreamReader,
|
||||
coeffs: ByteArray,
|
||||
states: Array<TadCoeffState>,
|
||||
bitplane: Int,
|
||||
block: TadBlock,
|
||||
nextInsignificant: TadBlockQueue,
|
||||
nextSignificant: TadBlockQueue
|
||||
) {
|
||||
// If size 1: read sign bit and reconstruct value
|
||||
if (block.length == 1) {
|
||||
val idx = block.start
|
||||
val signBit = bs.readBit()
|
||||
|
||||
// Reconstruct absolute value from bitplane
|
||||
val absVal = 1 shl bitplane
|
||||
|
||||
// Apply sign
|
||||
coeffs[idx] = (if (signBit != 0) -absVal else absVal).toByte()
|
||||
|
||||
states[idx].significant = true
|
||||
states[idx].firstBitplane = bitplane
|
||||
nextSignificant.push(block)
|
||||
return
|
||||
}
|
||||
|
||||
// Block is > 1: subdivide into left and right halves
|
||||
val mid = block.length / 2.coerceAtLeast(1)
|
||||
|
||||
// Process left child
|
||||
val left = TadBlock(block.start, mid)
|
||||
val leftSig = bs.readBit()
|
||||
if (leftSig != 0) {
|
||||
tadDecodeSignificantBlockRecursive(bs, coeffs, states, bitplane, left, nextInsignificant, nextSignificant)
|
||||
} else {
|
||||
nextInsignificant.push(left)
|
||||
}
|
||||
|
||||
// Process right child (if exists)
|
||||
if (block.length > mid) {
|
||||
val right = TadBlock(block.start + mid, block.length - mid)
|
||||
val rightSig = bs.readBit()
|
||||
if (rightSig != 0) {
|
||||
tadDecodeSignificantBlockRecursive(bs, coeffs, states, bitplane, right, nextInsignificant, nextSignificant)
|
||||
} else {
|
||||
nextInsignificant.push(right)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Binary tree EZBC decoding for a single channel (1D variant)
|
||||
private fun tadDecodeChannelEzbc(input: ByteArray, inputSize: Int, coeffs: ByteArray): Int {
|
||||
val bs = TadBitstreamReader(input)
|
||||
|
||||
// Read header: MSB bitplane and length
|
||||
val msbBitplane = bs.readBits(8)
|
||||
val count = bs.readBits(16)
|
||||
|
||||
// Initialize coefficient array to zero
|
||||
coeffs.fill(0)
|
||||
|
||||
// Track coefficient significance
|
||||
val states = Array(count) { TadCoeffState() }
|
||||
|
||||
// Initialize queues
|
||||
val insignificantQueue = TadBlockQueue()
|
||||
val nextInsignificant = TadBlockQueue()
|
||||
val significantQueue = TadBlockQueue()
|
||||
val nextSignificant = TadBlockQueue()
|
||||
|
||||
// Start with root block as insignificant
|
||||
val root = TadBlock(0, count)
|
||||
insignificantQueue.push(root)
|
||||
|
||||
// Process bitplanes from MSB to LSB
|
||||
for (bitplane in msbBitplane downTo 0) {
|
||||
val threshold = 1 shl bitplane
|
||||
|
||||
// Process insignificant blocks
|
||||
for (i in 0 until insignificantQueue.size) {
|
||||
val block = insignificantQueue.get(i)
|
||||
|
||||
val sig = bs.readBit()
|
||||
if (sig == 0) {
|
||||
// Still insignificant
|
||||
nextInsignificant.push(block)
|
||||
} else {
|
||||
// Became significant: recursively decode
|
||||
tadDecodeSignificantBlockRecursive(
|
||||
bs, coeffs, states, bitplane, block,
|
||||
nextInsignificant, nextSignificant
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Refinement pass: read next bit for already-significant coefficients
|
||||
for (i in 0 until significantQueue.size) {
|
||||
val block = significantQueue.get(i)
|
||||
val idx = block.start
|
||||
|
||||
val bit = bs.readBit()
|
||||
|
||||
// Add this bit to the coefficient's magnitude
|
||||
if (bit != 0) {
|
||||
val sign = if (coeffs[idx] < 0) -1 else 1
|
||||
val absVal = kotlin.math.abs(coeffs[idx].toInt())
|
||||
coeffs[idx] = (sign * (absVal or (1 shl bitplane))).toByte()
|
||||
}
|
||||
}
|
||||
|
||||
// Swap queues for next bitplane
|
||||
insignificantQueue.clear()
|
||||
for (i in 0 until nextInsignificant.size) {
|
||||
insignificantQueue.push(nextInsignificant.get(i))
|
||||
}
|
||||
nextInsignificant.clear()
|
||||
|
||||
significantQueue.clear()
|
||||
for (i in 0 until nextSignificant.size) {
|
||||
significantQueue.push(nextSignificant.get(i))
|
||||
}
|
||||
nextSignificant.clear()
|
||||
}
|
||||
|
||||
return bs.getBytesConsumed()
|
||||
}
|
||||
|
||||
private fun decodeTad() {
|
||||
tadBusy = true
|
||||
try {
|
||||
@@ -571,9 +783,23 @@ class AudioAdapter(val vm: VM) : PeriBase(VM.PERITYPE_SOUND) {
|
||||
return
|
||||
}
|
||||
|
||||
// Decode raw int8_t storage (no significance map - encoder uses raw format)
|
||||
val quantMid = payload.sliceArray(0 until sampleCount)
|
||||
val quantSide = payload.sliceArray(sampleCount until sampleCount*2)
|
||||
// Decode using binary tree EZBC
|
||||
val quantMid = ByteArray(sampleCount)
|
||||
val quantSide = ByteArray(sampleCount)
|
||||
|
||||
// Decode Mid channel
|
||||
val midBytesConsumed = tadDecodeChannelEzbc(
|
||||
payload,
|
||||
payload.size,
|
||||
quantMid
|
||||
)
|
||||
|
||||
// Decode Side channel (starts after Mid channel data)
|
||||
val sideBytesConsumed = tadDecodeChannelEzbc(
|
||||
payload.sliceArray(midBytesConsumed until payload.size),
|
||||
payload.size - midBytesConsumed,
|
||||
quantSide
|
||||
)
|
||||
|
||||
// Calculate DWT levels from sample count
|
||||
val dwtLevels = calculateDwtLevels(sampleCount)
|
||||
|
||||
Reference in New Issue
Block a user