TAD: Terrarum Advanced Audio to use with video compression

This commit is contained in:
minjaesong
2025-10-23 18:56:57 +09:00
parent 6f669f4fd9
commit a9319fd812
10 changed files with 1887 additions and 22 deletions

View File

@@ -82,6 +82,7 @@ class AudioJSR223Delegate(private val vm: VM) {
// fun mp2DecodeFrame(mp2: MP2Env.MP2, framePtr: Long?, pcm: Boolean, outL: Long, outR: Long) = getFirstSnd()?.mp2Env?.decodeFrame(mp2, framePtr, pcm, outL, outR)
fun getBaseAddr(): Int? = getFirstSnd()?.let { return it.vm.findPeriSlotNum(it)?.times(-131072)?.minus(1) }
fun getMemAddr(): Int? = getFirstSnd()?.let { return it.vm.findPeriSlotNum(it)?.times(-1048576)?.minus(1) }
fun mp2Init() = getFirstSnd()?.mmio_write(40L, 16)
fun mp2Decode() = getFirstSnd()?.mmio_write(40L, 1)
fun mp2InitThenDecode() = getFirstSnd()?.mmio_write(40L, 17)
@@ -93,6 +94,39 @@ class AudioJSR223Delegate(private val vm: VM) {
}
}
// TAD (Terrarum Advanced Audio) decoder functions
fun tadSetQuality(quality: Int) {
getFirstSnd()?.mmio_write(43L, quality.toByte())
}
fun tadGetQuality() = getFirstSnd()?.mmio_read(43L)?.toInt()
fun tadDecode() {
getFirstSnd()?.mmio_write(42L, 1)
}
fun tadIsBusy() = getFirstSnd()?.mmio_read(44L)?.toInt() == 1
fun tadUploadDecoded(playhead: Int) {
getFirstSnd()?.let { snd ->
val ba = ByteArray(65536) // 32768 samples * 2 channels
UnsafeHelper.memcpyRaw(null, snd.tadDecodedBin.ptr, ba, UnsafeHelper.getArrayOffset(ba), 65536)
snd.playheads[playhead].pcmQueue.addLast(ba)
}
}
fun putTadDataByPtr(ptr: Int, length: Int, destOffset: Int) {
getFirstSnd()?.let { snd ->
val vkMult = if (ptr >= 0) 1 else -1
for (k in 0L until length) {
val vk = k * vkMult
snd.tadInputBin[k + destOffset] = vm.peek(ptr + vk)!!
}
}
}
fun getTadData(index: Int) = getFirstSnd()?.tadDecodedBin?.get(index.toLong())
/*

View File

@@ -4,6 +4,7 @@ import com.badlogic.gdx.Gdx
import com.badlogic.gdx.backends.lwjgl3.audio.OpenALLwjgl3Audio
import com.badlogic.gdx.utils.GdxRuntimeException
import com.badlogic.gdx.utils.Queue
import io.airlift.compress.zstd.ZstdInputStream
import net.torvald.UnsafeHelper
import net.torvald.UnsafePtr
import net.torvald.terrarum.modulecomputers.virtualcomputer.tvd.toUint
@@ -11,6 +12,7 @@ import net.torvald.tsvm.ThreeFiveMiniUfloat
import net.torvald.tsvm.VM
import net.torvald.tsvm.getHashStr
import net.torvald.tsvm.toInt
import java.io.ByteArrayInputStream
private class RenderRunnable(val playhead: AudioAdapter.Playhead) : Runnable {
private fun printdbg(msg: Any) {
@@ -125,6 +127,12 @@ class AudioAdapter(val vm: VM) : PeriBase(VM.PERITYPE_SOUND) {
@Volatile private var mp2Busy = false
// TAD (Terrarum Advanced Audio) decoder buffers
internal val tadInputBin = UnsafeHelper.allocate(65536L, this) // Input: compressed TAD chunk (max 64KB)
internal val tadDecodedBin = UnsafeHelper.allocate(65536L, this) // Output: PCMu8 stereo (32768 samples * 2 channels)
internal var tadQuality = 2 // Quality level used during encoding (0-5)
@Volatile private var tadBusy = false
private val renderRunnables: Array<RenderRunnable>
private val renderThreads: Array<Thread>
private val writeQueueingRunnables: Array<WriteQueueingRunnable>
@@ -216,7 +224,9 @@ class AudioAdapter(val vm: VM) : PeriBase(VM.PERITYPE_SOUND) {
in 0..114687 -> sampleBin[addr]
in 114688..131071 -> (adi - 114688).let { instruments[it / 64].getByte(it % 64) }
in 131072..262143 -> (adi - 131072).let { playdata[it / (8*64)][(it / 8) % 64].getByte(it % 8) }
else -> peek(addr % 262144)
in 262144..327679 -> tadInputBin[addr - 262144] // TAD input buffer (65536 bytes)
in 327680..393215 -> tadDecodedBin[addr - 327680] // TAD decoded output (65536 bytes)
else -> peek(addr % 393216)
}
}
@@ -227,6 +237,8 @@ class AudioAdapter(val vm: VM) : PeriBase(VM.PERITYPE_SOUND) {
in 0..114687 -> { sampleBin[addr] = byte }
in 114688..131071 -> (adi - 114688).let { instruments[it / 64].setByte(it % 64, bi) }
in 131072..262143 -> (adi - 131072).let { playdata[it / (8*64)][(it / 8) % 64].setByte(it % 8, bi) }
in 262144..327679 -> tadInputBin[addr - 262144] = byte // TAD input buffer
in 327680..393215 -> tadDecodedBin[addr - 327680] = byte // TAD decoded output
}
}
@@ -239,6 +251,9 @@ class AudioAdapter(val vm: VM) : PeriBase(VM.PERITYPE_SOUND) {
in 30..39 -> playheads[3].read(adi - 30)
40 -> -1
41 -> mp2Busy.toInt().toByte()
42 -> -1 // TAD control (write-only)
43 -> tadQuality.toByte()
44 -> tadBusy.toInt().toByte()
in 64..2367 -> mediaDecodedBin[addr - 64]
in 2368..4095 -> mediaFrameBin[addr - 2368]
in 4096..4097 -> 0
@@ -265,6 +280,14 @@ class AudioAdapter(val vm: VM) : PeriBase(VM.PERITYPE_SOUND) {
if (bi and 16 != 0) { mp2Context = mp2Env.initialise() }
if (bi and 1 != 0) decodeMp2()
}
42 -> {
// TAD control: bit 0 = decode
if (bi and 1 != 0) decodeTad()
}
43 -> {
// TAD quality (0-5)
tadQuality = bi.coerceIn(0, 5)
}
in 64..2367 -> { mediaDecodedBin[addr - 64] = byte }
in 2368..4095 -> { mediaFrameBin[addr - 2368] = byte }
in 32768..65535 -> { (adi - 32768).let {
@@ -287,6 +310,8 @@ class AudioAdapter(val vm: VM) : PeriBase(VM.PERITYPE_SOUND) {
pcmBin.destroy()
mediaFrameBin.destroy()
mediaDecodedBin.destroy()
tadInputBin.destroy()
tadDecodedBin.destroy()
}
else {
System.err.println("AudioAdapter already disposed")
@@ -304,6 +329,250 @@ class AudioAdapter(val vm: VM) : PeriBase(VM.PERITYPE_SOUND) {
mp2Env.decodeFrameU8(mp2Context, periMmioBase - 2368, true, periMmioBase - 64)
}
//=============================================================================
// TAD (Terrarum Advanced Audio) Decoder
//=============================================================================
private fun decodeTad() {
tadBusy = true
try {
// Read chunk header from tadInputBin
var offset = 0L
val sampleCount = (
(tadInputBin[offset++].toInt() and 0xFF) or
((tadInputBin[offset++].toInt() and 0xFF) shl 8)
)
val payloadSize = (
(tadInputBin[offset++].toInt() and 0xFF) or
((tadInputBin[offset++].toInt() and 0xFF) shl 8) or
((tadInputBin[offset++].toInt() and 0xFF) shl 16) or
((tadInputBin[offset++].toInt() and 0xFF) shl 24)
)
// Decompress payload if needed
val compressed = ByteArray(payloadSize)
UnsafeHelper.memcpyRaw(null, tadInputBin.ptr + offset, compressed, UnsafeHelper.getArrayOffset(compressed), payloadSize.toLong())
val payload: ByteArray = try {
ZstdInputStream(ByteArrayInputStream(compressed)).use { zstd ->
zstd.readBytes()
}
} catch (e: Exception) {
println("ERROR: Zstd decompression failed: ${e.message}")
} as ByteArray
// Decode significance maps
val quantMid = ShortArray(sampleCount)
val quantSide = ShortArray(sampleCount)
var payloadOffset = 0
val midBytes = decodeSigmap2bit(payload, payloadOffset, quantMid, sampleCount)
payloadOffset += midBytes
val sideBytes = decodeSigmap2bit(payload, payloadOffset, quantSide, sampleCount)
// Calculate DWT levels from sample count
val dwtLevels = calculateDwtLevels(sampleCount)
// Dequantize
val dwtMid = FloatArray(sampleCount)
val dwtSide = FloatArray(sampleCount)
dequantizeDwtCoefficients(quantMid, dwtMid, sampleCount, tadQuality, dwtLevels)
dequantizeDwtCoefficients(quantSide, dwtSide, sampleCount, tadQuality, dwtLevels)
// Inverse DWT
dwtDD4InverseMultilevel(dwtMid, sampleCount, dwtLevels)
dwtDD4InverseMultilevel(dwtSide, sampleCount, dwtLevels)
// Convert to signed PCM8
val pcm8Mid = ByteArray(sampleCount)
val pcm8Side = ByteArray(sampleCount)
for (i in 0 until sampleCount) {
pcm8Mid[i] = dwtMid[i].coerceIn(-128f, 127f).toInt().toByte()
pcm8Side[i] = dwtSide[i].coerceIn(-128f, 127f).toInt().toByte()
}
// M/S to L/R correlation and write to tadDecodedBin
for (i in 0 until sampleCount) {
val m = pcm8Mid[i].toInt()
val s = pcm8Side[i].toInt()
var l = m + s
var r = m - s
if (l < -128) l = -128
if (l > 127) l = 127
if (r < -128) r = -128
if (r > 127) r = 127
tadDecodedBin[i * 2L] = (l + 128).toByte() // Left (PCMu8)
tadDecodedBin[i * 2L + 1] = (r + 128).toByte() // Right (PCMu8)
}
} catch (e: Exception) {
e.printStackTrace()
} finally {
tadBusy = false
}
}
private fun decodeSigmap2bit(input: ByteArray, offset: Int, values: ShortArray, count: Int): Int {
val mapBytes = (count * 2 + 7) / 8
var readPtr = offset + mapBytes
var otherIdx = 0
for (i in 0 until count) {
val bitPos = i * 2
val byteIdx = offset + bitPos / 8
val bitOffset = bitPos % 8
var code = ((input[byteIdx].toInt() and 0xFF) shr bitOffset) and 0x03
// Handle bit spillover
if (bitOffset == 7) {
code = ((input[byteIdx].toInt() and 0xFF) shr 7) or
(((input[byteIdx + 1].toInt() and 0xFF) and 0x01) shl 1)
}
values[i] = when (code) {
0 -> 0
1 -> 1
2 -> (-1).toShort()
3 -> {
val v = ((input[readPtr].toInt() and 0xFF) or
((input[readPtr + 1].toInt() and 0xFF) shl 8)).toShort()
readPtr += 2
otherIdx++
v
}
else -> 0
}
}
return mapBytes + otherIdx * 2
}
private fun calculateDwtLevels(chunkSize: Int): Int {
if (chunkSize < 1024) {
throw IllegalArgumentException("Chunk size $chunkSize is below minimum 1024")
}
var levels = 0
var size = chunkSize
while (size > 1) {
size = size shr 1
levels++
}
return levels - 2 // Maximum decomposition leaves 4-sample approximation
}
private fun getQuantizationWeights(quality: Int, dwtLevels: Int): FloatArray {
// Extended base weights to support up to 16 DWT levels
val baseWeights = arrayOf(
/* 0*/floatArrayOf(1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f),
/* 1*/floatArrayOf(1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f),
/* 2*/floatArrayOf(1.0f, 1.0f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f),
/* 3*/floatArrayOf(0.2f, 1.0f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f),
/* 4*/floatArrayOf(0.2f, 0.8f, 1.0f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f),
/* 5*/floatArrayOf(0.2f, 0.8f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f),
/* 6*/floatArrayOf(0.2f, 0.2f, 0.8f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f),
/* 7*/floatArrayOf(0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f),
/* 8*/floatArrayOf(0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f),
/* 9*/floatArrayOf(0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f),
/*10*/floatArrayOf(0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f),
/*11*/floatArrayOf(0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f),
/*12*/floatArrayOf(0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f),
/*13*/floatArrayOf(0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f, 1.5f),
/*14*/floatArrayOf(0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f, 1.5f),
/*15*/floatArrayOf(0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f, 1.5f),
/*16*/floatArrayOf(0.2f, 0.2f, 0.8f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.25f, 1.5f)
)
val qualityScale = 1.0f + ((3 - quality) * 0.5f).coerceAtLeast(0.0f)
return FloatArray(dwtLevels) { i -> (baseWeights[dwtLevels][i.coerceIn(0, 15)] * qualityScale).coerceAtLeast(1.0f) }
}
private fun dequantizeDwtCoefficients(quantized: ShortArray, coeffs: FloatArray, count: Int, quality: Int, dwtLevels: Int) {
val weights = getQuantizationWeights(quality, dwtLevels)
// Calculate sideband boundaries dynamically based on chunk size and DWT levels
val firstBandSize = count shr dwtLevels
val sidebandStarts = IntArray(dwtLevels + 2)
sidebandStarts[0] = 0
sidebandStarts[1] = firstBandSize
for (i in 2..dwtLevels + 1) {
sidebandStarts[i] = sidebandStarts[i - 1] + (firstBandSize shl (i - 2))
}
for (i in 0 until count) {
var sideband = dwtLevels
for (s in 0 until dwtLevels + 1) {
if (i < sidebandStarts[s + 1]) {
sideband = s
break
}
}
val weightIdx = if (sideband == 0) 0 else sideband - 1
val weight = weights[weightIdx.coerceIn(0, dwtLevels - 1)]
coeffs[i] = quantized[i].toFloat() * weight
}
}
private fun dwtDD4Inverse1d(data: FloatArray, length: Int) {
if (length < 2) return
val temp = FloatArray(length)
val half = (length + 1) / 2
// Split into low and high parts
for (i in 0 until half) {
temp[i] = data[i] // Even (low-pass)
}
for (i in 0 until length / 2) {
temp[half + i] = data[half + i] // Odd (high-pass)
}
// Undo update step: s[i] -= 0.25 * (d[i-1] + d[i])
for (i in 0 until half) {
val dCurr = if (i < length / 2) temp[half + i] else 0.0f
val dPrev = if (i > 0 && i - 1 < length / 2) temp[half + i - 1] else 0.0f
temp[i] -= 0.25f * (dPrev + dCurr)
}
// Undo prediction step: d[i] += P(s[i-1], s[i], s[i+1], s[i+2])
for (i in 0 until length / 2) {
val sM1 = if (i > 0) temp[i - 1] else temp[0] // mirror boundary
val s0 = temp[i]
val s1 = if (i + 1 < half) temp[i + 1] else temp[half - 1]
val s2 = if (i + 2 < half) temp[i + 2] else if (half > 1) temp[half - 2] else temp[half - 1]
val prediction = (-1.0f/16.0f)*sM1 + (9.0f/16.0f)*s0 + (9.0f/16.0f)*s1 + (-1.0f/16.0f)*s2
temp[half + i] += prediction
}
// Merge evens and odds back
for (i in 0 until half) {
data[2 * i] = temp[i]
if (2 * i + 1 < length)
data[2 * i + 1] = temp[half + i]
}
}
private fun dwtDD4InverseMultilevel(data: FloatArray, length: Int, levels: Int) {
// Calculate the length at the deepest level
var currentLength = length
for (level in 0 until levels) {
currentLength = (currentLength + 1) / 2
}
// Inverse transform: double size FIRST, then apply inverse DWT
for (level in levels - 1 downTo 0) {
currentLength *= 2 // MULTIPLY FIRST
if (currentLength > length) currentLength = length
dwtDD4Inverse1d(data, currentLength) // THEN apply inverse
}
}