TAV: EZBC entropy coding

This commit is contained in:
minjaesong
2025-10-20 16:40:45 +09:00
parent 019f0aaed5
commit 9553b281af
3 changed files with 1596 additions and 1044 deletions

View File

@@ -4230,6 +4230,335 @@ class GraphicsJSR223Delegate(private val vm: VM) {
}
}
/**
* EZBC (Embedded Zero Block Coding) decoder for a single channel
* Decodes hierarchical zero-block coded bitstream
*
* @param ezbc_data EZBC-encoded bitstream
* @param offset Starting offset in ezbc_data
* @param size Size of EZBC data in bytes
* @param outputCoeffs Output array for decoded coefficients
*/
private fun decodeChannelEZBC(ezbcData: ByteArray, offset: Int, size: Int, outputCoeffs: ShortArray) {
var bytePos = offset
var bitPos = 0
// Helper: read N bits from bitstream
var hitEndOfStream = false
fun readBits(numBits: Int): Int {
var result = 0
for (i in 0 until numBits) {
if (bytePos >= offset + size) {
if (!hitEndOfStream) {
println("[EZBC-BITS] HIT END OF STREAM at byte $bytePos (size=$size, requested $numBits bits)")
hitEndOfStream = true
}
return result
}
val bit = (ezbcData[bytePos].toInt() shr bitPos) and 1
result = result or (bit shl i)
bitPos++
if (bitPos == 8) {
bitPos = 0
bytePos++
}
}
return result
}
// Debug: print raw bytes before reading header
if (ezbcData.size >= offset + 9) {
println("[EZBC-DEC] First 9 bytes at offset $offset: ${(0..8).map {
String.format("%02X", ezbcData[offset + it].toInt() and 0xFF)
}.joinToString(" ")}")
}
// Read header: MSB bitplane, width, height
val msbBitplane = readBits(8)
val width = readBits(16)
val height = readBits(16)
println("[EZBC-DEC] Decoded header: MSB=$msbBitplane, width=$width, height=$height")
if (width * height != outputCoeffs.size) {
System.err.println("EZBC dimension mismatch: ${width}x${height} != ${outputCoeffs.size}")
return
}
// Initialize coefficient state tracking
val significant = BooleanArray(outputCoeffs.size)
val firstBitplane = IntArray(outputCoeffs.size)
// Initialize output to zero
outputCoeffs.fill(0)
var totalSignificantCoeffs = 0
// Queue structures for block processing
data class Block(val x: Int, val y: Int, val width: Int, val height: Int)
var insignificantQueue = ArrayList<Block>()
var nextInsignificant = ArrayList<Block>()
var significantQueue = ArrayList<Block>()
var nextSignificant = ArrayList<Block>()
// Start with root block
insignificantQueue.add(Block(0, 0, width, height))
// Recursive function to process a significant block and its children
fun processSignificantBlockRecursive(block: Block, bitplane: Int, threshold: Int): Int {
var signBitsRead = 0
// If 1x1 block: read sign bit and add to significant queue
if (block.width == 1 && block.height == 1) {
val idx = block.y * width + block.x
val signBit = readBits(1)
signBitsRead++
// Set coefficient to threshold value with sign
outputCoeffs[idx] = (if (signBit == 1) -threshold else threshold).toShort()
significant[idx] = true
firstBitplane[idx] = bitplane
nextSignificant.add(block)
return signBitsRead
}
// Block is > 1x1: subdivide and recursively process children
var midX = block.width / 2
var midY = block.height / 2
if (midX == 0) midX = 1
if (midY == 0) midY = 1
// Top-left child
val tl = Block(block.x, block.y, midX, midY)
val tlFlag = readBits(1)
if (tlFlag == 1) {
signBitsRead += processSignificantBlockRecursive(tl, bitplane, threshold)
} else {
nextInsignificant.add(tl)
}
// Top-right child (if exists)
if (block.width > midX) {
val tr = Block(block.x + midX, block.y, block.width - midX, midY)
val trFlag = readBits(1)
if (trFlag == 1) {
signBitsRead += processSignificantBlockRecursive(tr, bitplane, threshold)
} else {
nextInsignificant.add(tr)
}
}
// Bottom-left child (if exists)
if (block.height > midY) {
val bl = Block(block.x, block.y + midY, midX, block.height - midY)
val blFlag = readBits(1)
if (blFlag == 1) {
signBitsRead += processSignificantBlockRecursive(bl, bitplane, threshold)
} else {
nextInsignificant.add(bl)
}
}
// Bottom-right child (if exists)
if (block.width > midX && block.height > midY) {
val br = Block(block.x + midX, block.y + midY, block.width - midX, block.height - midY)
val brFlag = readBits(1)
if (brFlag == 1) {
signBitsRead += processSignificantBlockRecursive(br, bitplane, threshold)
} else {
nextInsignificant.add(br)
}
}
return signBitsRead
}
// Process bitplanes from MSB to LSB
for (bitplane in msbBitplane downTo 0) {
val threshold = 1 shl bitplane
val insignifCountBefore = insignificantQueue.size
val signifCountBefore = significantQueue.size
// Process insignificant blocks
for (block in insignificantQueue) {
val flag = readBits(1)
if (flag == 0) {
// Still insignificant
nextInsignificant.add(block)
} else {
// Became significant - use recursive processing
val signBitsRead = processSignificantBlockRecursive(block, bitplane, threshold)
totalSignificantCoeffs += signBitsRead
}
}
// Process significant 1x1 blocks (refinement)
for (block in significantQueue) {
val idx = block.y * width + block.x
val refineBit = readBits(1)
// Add refinement bit at current bitplane
if (refineBit == 1) {
val bitValue = 1 shl bitplane
if (outputCoeffs[idx] < 0) {
outputCoeffs[idx] = (outputCoeffs[idx] - bitValue).toShort()
} else {
outputCoeffs[idx] = (outputCoeffs[idx] + bitValue).toShort()
}
}
// Keep in significant queue
nextSignificant.add(block)
}
// Swap queues
insignificantQueue = nextInsignificant
significantQueue = nextSignificant
nextInsignificant = ArrayList()
nextSignificant = ArrayList()
if (bitplane == msbBitplane || bitplane == 0 || (msbBitplane - bitplane) % 3 == 0) {
println("[EZBC-BP] Bitplane $bitplane: threshold=$threshold, insignif=${insignifCountBefore}->${insignificantQueue.size}, signif=${signifCountBefore}->${significantQueue.size}, totalSig=$totalSignificantCoeffs")
}
}
// Debug summary
println("[EZBC-CH] Decoded $totalSignificantCoeffs significant coefficients out of ${outputCoeffs.size}")
val nonZeroCount = outputCoeffs.count { it != 0.toShort() }
println("[EZBC-CH] Non-zero coefficients: $nonZeroCount")
val maxVal = outputCoeffs.maxOrNull() ?: 0
val minVal = outputCoeffs.minOrNull() ?: 0
println("[EZBC-CH] Value range: [$minVal, $maxVal]")
}
/**
* EZBC decoder wrapper for variable channel layout
* Detects and decodes EZBC-encoded significance maps
*
* Format: [size_y(4)][ezbc_y][size_co(4)][ezbc_co][size_cg(4)][ezbc_cg]...
*/
private fun postprocessCoefficientsEZBC(compressedData: ByteArray, compressedOffset: Int, coeffCount: Int,
channelLayout: Int, outputY: ShortArray?, outputCo: ShortArray?,
outputCg: ShortArray?, outputAlpha: ShortArray?) {
// Determine active channels based on channel_layout bitfield
// Bit 2 (value 4): 0=has Y/I, 1=no Y/I
// Bit 1 (value 2): 0=has Co/Cg or Ct/Cp, 1=no chroma
// Bit 0 (value 1): 1=has alpha, 0=no alpha
val hasY = (channelLayout and 4) == 0
val hasCo = (channelLayout and 2) == 0
val hasCg = (channelLayout and 2) == 0 // Same as Co - both chroma channels present together
val hasAlpha = (channelLayout and 1) != 0
println("[EZBC] Decoding: coeffCount=$coeffCount, channelLayout=$channelLayout, hasY=$hasY, hasCo=$hasCo, hasCg=$hasCg")
var offset = compressedOffset
// Decode Y channel
if (hasY && outputY != null) {
val size = ((compressedData[offset].toInt() and 0xFF) or
((compressedData[offset + 1].toInt() and 0xFF) shl 8) or
((compressedData[offset + 2].toInt() and 0xFF) shl 16) or
((compressedData[offset + 3].toInt() and 0xFF) shl 24))
println("[EZBC] Y channel: size=$size, offset=$offset")
offset += 4
decodeChannelEZBC(compressedData, offset, size, outputY)
println("[EZBC] Y channel decoded: first 10 values = ${outputY.take(10)}")
offset += size
}
// Decode Co channel
if (hasCo && outputCo != null) {
val size = ((compressedData[offset].toInt() and 0xFF) or
((compressedData[offset + 1].toInt() and 0xFF) shl 8) or
((compressedData[offset + 2].toInt() and 0xFF) shl 16) or
((compressedData[offset + 3].toInt() and 0xFF) shl 24))
println("[EZBC] Co channel: size=$size, offset=$offset")
offset += 4
decodeChannelEZBC(compressedData, offset, size, outputCo)
println("[EZBC] Co channel decoded: first 10 values = ${outputCo.take(10)}")
offset += size
}
// Decode Cg channel
if (hasCg && outputCg != null) {
val size = ((compressedData[offset].toInt() and 0xFF) or
((compressedData[offset + 1].toInt() and 0xFF) shl 8) or
((compressedData[offset + 2].toInt() and 0xFF) shl 16) or
((compressedData[offset + 3].toInt() and 0xFF) shl 24))
println("[EZBC] Cg channel: size=$size, offset=$offset")
offset += 4
decodeChannelEZBC(compressedData, offset, size, outputCg)
println("[EZBC] Cg channel decoded: first 10 values = ${outputCg.take(10)}")
offset += size
}
// Decode Alpha channel
if (hasAlpha && outputAlpha != null) {
val size = ((compressedData[offset].toInt() and 0xFF) or
((compressedData[offset + 1].toInt() and 0xFF) shl 8) or
((compressedData[offset + 2].toInt() and 0xFF) shl 16) or
((compressedData[offset + 3].toInt() and 0xFF) shl 24))
offset += 4
decodeChannelEZBC(compressedData, offset, size, outputAlpha)
}
}
/**
* Auto-detecting coefficient decoder wrapper
* Detects EZBC vs twobit-map format and calls appropriate decoder
*/
private fun postprocessCoefficientsAuto(compressedData: ByteArray, compressedOffset: Int, coeffCount: Int,
channelLayout: Int, outputY: ShortArray?, outputCo: ShortArray?,
outputCg: ShortArray?, outputAlpha: ShortArray?): Boolean {
// TEMPORARY: Force EZBC mode until entropy coding method flag is added
val isEZBC = true
/* Auto-detection disabled for now - will use entropy coding method flag later
// Better auto-detection: Check EZBC header structure
// EZBC format: [size(4)][msb_bitplane(1)][width(2)][height(2)][bits...]
// Twobit-map format: [2-bit map + values...]
val isEZBC = if (compressedData.size >= compressedOffset + 9) {
// Read first uint32 (should be EZBC channel size)
val possibleSize = ((compressedData[compressedOffset].toInt() and 0xFF) or
((compressedData[compressedOffset + 1].toInt() and 0xFF) shl 8) or
((compressedData[compressedOffset + 2].toInt() and 0xFF) shl 16) or
((compressedData[compressedOffset + 3].toInt() and 0xFF) shl 24))
val msbBitplane = compressedData[compressedOffset + 4].toInt() and 0xFF
val width = ((compressedData[compressedOffset + 5].toInt() and 0xFF) or
((compressedData[compressedOffset + 6].toInt() and 0xFF) shl 8))
val height = ((compressedData[compressedOffset + 7].toInt() and 0xFF) or
((compressedData[compressedOffset + 8].toInt() and 0xFF) shl 8))
println("[AUTO] Checking EZBC: possibleSize=$possibleSize, msb=$msbBitplane, w=$width, h=$height, coeffCount=$coeffCount")
// Valid EZBC header should have reasonable size, MSB bitplane, and dimensions matching coeffCount
val detected = possibleSize in 10..(coeffCount * 4) && msbBitplane < 20 && width > 0 && height > 0 && width * height == coeffCount
println("[AUTO] Detection result: isEZBC=$detected")
detected
} else {
println("[AUTO] Not enough data for EZBC detection, using twobit-map")
false
}
*/
if (isEZBC) {
println("[AUTO] Using EZBC decoder (FORCED)")
postprocessCoefficientsEZBC(compressedData, compressedOffset, coeffCount,
channelLayout, outputY, outputCo, outputCg, outputAlpha)
} else {
println("[AUTO] Using twobit-map decoder")
postprocessCoefficientsVariableLayout(compressedData, compressedOffset, coeffCount,
channelLayout, outputY, outputCo, outputCg, outputAlpha)
}
return isEZBC
}
/**
* Reconstruct per-frame coefficients from unified GOP block (2-bit format)
* Reverse of encoder's preprocess_gop_unified()
@@ -4408,6 +4737,95 @@ class GraphicsJSR223Delegate(private val vm: VM) {
return output
}
/**
* Reconstruct per-frame coefficients from unified GOP block (EZBC format)
* Format: [frame0_size(4)][frame0_ezbc][frame1_size(4)][frame1_ezbc]...
*
* @param decompressedData Unified EZBC block data (after Zstd decompression)
* @param numFrames Number of frames in GOP
* @param numPixels Pixels per frame (width × height)
* @param channelLayout Channel layout (0=YCoCg, 2=Y-only, etc)
* @return Array of [frame][channel] where channel: 0=Y, 1=Co, 2=Cg
*/
private fun tavPostprocessGopEZBC(
decompressedData: ByteArray,
numFrames: Int,
numPixels: Int,
channelLayout: Int
): Array<Array<ShortArray>> {
// Allocate output arrays
val output = Array(numFrames) { Array(3) { ShortArray(numPixels) } }
var offset = 0
for (frame in 0 until numFrames) {
if (offset + 4 > decompressedData.size) break
// Read frame size
val frameSize = ((decompressedData[offset].toInt() and 0xFF) or
((decompressedData[offset + 1].toInt() and 0xFF) shl 8) or
((decompressedData[offset + 2].toInt() and 0xFF) shl 16) or
((decompressedData[offset + 3].toInt() and 0xFF) shl 24))
offset += 4
if (offset + frameSize > decompressedData.size) break
// Decode this frame with EZBC
postprocessCoefficientsEZBC(
decompressedData, offset, numPixels, channelLayout,
output[frame][0], output[frame][1], output[frame][2], null
)
offset += frameSize
}
return output
}
/**
* Auto-detecting GOP postprocessor
* Detects EZBC vs twobit-map format and calls appropriate decoder
*/
private fun tavPostprocessGopAuto(
decompressedData: ByteArray,
numFrames: Int,
numPixels: Int,
channelLayout: Int
): Pair<Boolean, Array<Array<ShortArray>>> {
// TEMPORARY: Force EZBC mode until entropy coding method flag is added
val isEZBC = true
/* Auto-detection disabled for now - will use entropy coding method flag later
// Auto-detect: EZBC format has frame size headers
// Check if first 4 bytes look like a reasonable frame size
val isEZBC = if (decompressedData.size >= 8) {
val possibleSize = ((decompressedData[0].toInt() and 0xFF) or
((decompressedData[1].toInt() and 0xFF) shl 8) or
((decompressedData[2].toInt() and 0xFF) shl 16) or
((decompressedData[3].toInt() and 0xFF) shl 24))
// Check if this looks like an EZBC header (size followed by MSB bitplane)
if (possibleSize in 10..(numPixels * 16)) {
val msbBitplane = decompressedData[4].toInt() and 0xFF
msbBitplane < 20 // Valid MSB bitplane
} else {
false
}
} else {
false
}
*/
println("[GOP AUTO] Using ${if (isEZBC) "EZBC (FORCED)" else "twobit-map"} decoder")
val data = if (isEZBC) {
tavPostprocessGopEZBC(decompressedData, numFrames, numPixels, channelLayout)
} else {
tavPostprocessGopUnified(decompressedData, numFrames, numPixels, channelLayout)
}
return Pair(isEZBC, data)
}
// TAV Simulated overlapping tiles constants (must match encoder)
private val TAV_TILE_SIZE_X = 640
private val TAV_TILE_SIZE_Y = 540
@@ -4658,7 +5076,8 @@ class GraphicsJSR223Delegate(private val vm: VM) {
}
private fun dequantiseDWTSubbandsPerceptual(qIndex: Int, qYGlobal: Int, quantised: ShortArray, dequantised: FloatArray,
subbands: List<DWTSubbandInfo>, baseQuantiser: Float, isChroma: Boolean, decompLevels: Int) {
subbands: List<DWTSubbandInfo>, baseQuantiser: Float, isChroma: Boolean, decompLevels: Int,
isEZBC: Boolean) {
// CRITICAL FIX: Encoder stores coefficients in LINEAR order, not subband-mapped order!
// The subband layout calculation is only used for determining perceptual weights,
@@ -4681,10 +5100,23 @@ class GraphicsJSR223Delegate(private val vm: VM) {
}
// Apply linear dequantisation with perceptual weights (matching encoder's linear storage)
// EZBC mode: coefficients are ALREADY DENORMALIZED by encoder
// e.g., encoder: coeff=377 → quantize: 377/48=7.85→8 → denormalize: 8*48=384 → store 384
// decoder: read 384 → pass through as-is (already in correct range for IDWT)
// Significance-map mode: coefficients are normalized (quantized only)
// e.g., encoder stores 8 = round(377/48)
// decoder must multiply: 8 * 48 = 384 (denormalize for IDWT)
for (i in quantised.indices) {
if (i < dequantised.size) {
val effectiveQuantiser = baseQuantiser * weights[i]
dequantised[i] = quantised[i] * effectiveQuantiser
dequantised[i] = if (isEZBC) {
// EZBC mode: pass through as-is (coefficients already denormalized)
quantised[i].toFloat()
} else {
// Significance-map mode: multiply to denormalize (coefficients are normalized)
quantised[i] * effectiveQuantiser
}
}
}
@@ -4696,11 +5128,14 @@ class GraphicsJSR223Delegate(private val vm: VM) {
val weightRange = if (weightStats.isNotEmpty())
"weights: ${weightStats.first()}-${weightStats.last()}" else "no weights"
for (coeff in quantised) {
if (coeff != 0.toShort()) nonZeroCoeffs++
for (i in quantised.indices) {
if (quantised[i] != 0.toShort()) {
nonZeroCoeffs++
}
}
println("LINEAR PERCEPTUAL DEQUANT: $channelType - coeffs=${quantised.size}, nonzero=$nonZeroCoeffs, $weightRange")
val mode = if (isEZBC) "EZBC (pass-through)" else "Sigmap (multiply)"
println("LINEAR PERCEPTUAL DEQUANT: $channelType - mode=$mode, coeffs=${quantised.size}, nonzero=$nonZeroCoeffs, $weightRange")
}
}
@@ -4971,8 +5406,8 @@ class GraphicsJSR223Delegate(private val vm: VM) {
return count
}
// Use variable channel layout concatenated maps format
postprocessCoefficientsVariableLayout(coeffBuffer, 0, coeffCount, channelLayout, quantisedY, quantisedCo, quantisedCg, quantisedAlpha)
// Use auto-detecting decoder (EZBC or variable channel layout concatenated maps)
val isEZBCMode = postprocessCoefficientsAuto(coeffBuffer, 0, coeffCount, channelLayout, quantisedY, quantisedCo, quantisedCg, quantisedAlpha)
// Calculate total size for variable channel layout format
val numChannels = when (channelLayout) {
@@ -5013,9 +5448,9 @@ class GraphicsJSR223Delegate(private val vm: VM) {
val tileHeight = if (isMonoblock) height else TAV_PADDED_TILE_SIZE_Y
val subbands = calculateSubbandLayout(tileWidth, tileHeight, decompLevels)
dequantiseDWTSubbandsPerceptual(qIndex, qYGlobal, quantisedY, yTile, subbands, qY.toFloat(), false, decompLevels)
dequantiseDWTSubbandsPerceptual(qIndex, qYGlobal, quantisedCo, coTile, subbands, qCo.toFloat(), true, decompLevels)
dequantiseDWTSubbandsPerceptual(qIndex, qYGlobal, quantisedCg, cgTile, subbands, qCg.toFloat(), true, decompLevels)
dequantiseDWTSubbandsPerceptual(qIndex, qYGlobal, quantisedY, yTile, subbands, qY.toFloat(), false, decompLevels, isEZBCMode)
dequantiseDWTSubbandsPerceptual(qIndex, qYGlobal, quantisedCo, coTile, subbands, qCo.toFloat(), true, decompLevels, isEZBCMode)
dequantiseDWTSubbandsPerceptual(qIndex, qYGlobal, quantisedCg, cgTile, subbands, qCg.toFloat(), true, decompLevels, isEZBCMode)
// Remove grain synthesis from Y channel (must happen after dequantization, before inverse DWT)
// Use perceptual weights since this is the perceptual quantization path
@@ -5584,8 +6019,8 @@ class GraphicsJSR223Delegate(private val vm: VM) {
return count
}
// Use variable channel layout concatenated maps format for deltas
postprocessCoefficientsVariableLayout(coeffBuffer, 0, coeffCount, channelLayout, deltaY, deltaCo, deltaCg, deltaAlpha)
// Use auto-detecting decoder for deltas (EZBC or variable channel layout concatenated maps)
postprocessCoefficientsAuto(coeffBuffer, 0, coeffCount, channelLayout, deltaY, deltaCo, deltaCg, deltaAlpha)
// Calculate total size for variable channel layout format (deltas)
val numChannels = when (channelLayout) {
@@ -6352,8 +6787,8 @@ class GraphicsJSR223Delegate(private val vm: VM) {
return arrayOf(0, dbgOut)
}
// Step 2: Postprocess unified block to per-frame coefficients
val quantizedCoeffs = tavPostprocessGopUnified(
// Step 2: Postprocess unified block to per-frame coefficients (auto-detect EZBC vs twobit-map)
val (isEZBCMode, quantizedCoeffs) = tavPostprocessGopAuto(
decompressedData,
gopSize,
canvasPixels, // Use expanded canvas size
@@ -6382,19 +6817,22 @@ class GraphicsJSR223Delegate(private val vm: VM) {
dequantiseDWTSubbandsPerceptual(
qIndex, qYGlobal,
quantizedCoeffs[t][0], gopY[t],
subbands, baseQY, false, spatialLevels // isChroma=false
subbands, baseQY, false, spatialLevels, // isChroma=false
isEZBCMode
)
dequantiseDWTSubbandsPerceptual(
qIndex, qYGlobal,
quantizedCoeffs[t][1], gopCo[t],
subbands, baseQCo, true, spatialLevels // isChroma=true
subbands, baseQCo, true, spatialLevels, // isChroma=true
isEZBCMode
)
dequantiseDWTSubbandsPerceptual(
qIndex, qYGlobal,
quantizedCoeffs[t][2], gopCg[t],
subbands, baseQCg, true, spatialLevels // isChroma=true
subbands, baseQCg, true, spatialLevels, // isChroma=true
isEZBCMode
)
}

File diff suppressed because it is too large Load Diff

View File

@@ -9,413 +9,6 @@
extern "C" {
// Helper: Compute SAD (Sum of Absolute Differences) for a block
static int compute_sad(
const unsigned char *ref, const unsigned char *cur,
int ref_x, int ref_y, int cur_x, int cur_y,
int width, int height, int block_size
) {
int sad = 0;
for (int by = 0; by < block_size; by++) {
for (int bx = 0; bx < block_size; bx++) {
int ry = ref_y + by;
int rx = ref_x + bx;
int cy = cur_y + by;
int cx = cur_x + bx;
// Boundary check
if (rx < 0 || rx >= width || ry < 0 || ry >= height ||
cx < 0 || cx >= width || cy < 0 || cy >= height) {
sad += 255; // Penalty for out-of-bounds
continue;
}
int ref_val = ref[ry * width + rx];
int cur_val = cur[cy * width + cx];
sad += abs(ref_val - cur_val);
}
}
return sad;
}
// Parabolic interpolation for sub-pixel refinement
// Given SAD values at positions (-1, 0, +1), estimate peak location
static float parabolic_interp(int sad_m1, int sad_0, int sad_p1) {
// Fit parabola: y = a*x^2 + b*x + c
// Peak at x = -b/(2a) = (sad_m1 - sad_p1) / (2*(sad_m1 - 2*sad_0 + sad_p1))
int denom = 2 * (sad_m1 - 2 * sad_0 + sad_p1);
if (denom == 0) return 0.0f;
float offset = (float)(sad_m1 - sad_p1) / denom;
// Clamp to ±0.5 for reasonable sub-pixel values
if (offset < -0.5f) offset = -0.5f;
if (offset > 0.5f) offset = 0.5f;
return offset;
}
// Diamond search pattern for integer-pixel motion estimation
static void diamond_search(
const unsigned char *ref, const unsigned char *cur,
int cx, int cy, int width, int height, int block_size,
int search_range, int *best_dx, int *best_dy
) {
// Large diamond pattern (distance 2)
const int large_diamond[8][2] = {
{0, -2}, {-1, -1}, {1, -1}, {-2, 0},
{2, 0}, {-1, 1}, {1, 1}, {0, 2}
};
// Small diamond pattern (distance 1)
const int small_diamond[4][2] = {
{0, -1}, {-1, 0}, {1, 0}, {0, 1}
};
int dx = 0, dy = 0;
int best_sad = compute_sad(ref, cur, cx + dx, cy + dy, cx, cy, width, height, block_size);
// Large diamond search
bool improved = true;
while (improved) {
improved = false;
for (int i = 0; i < 8; i++) {
int test_dx = dx + large_diamond[i][0];
int test_dy = dy + large_diamond[i][1];
if (abs(test_dx) > search_range || abs(test_dy) > search_range) {
continue;
}
int sad = compute_sad(ref, cur, cx + test_dx, cy + test_dy, cx, cy, width, height, block_size);
if (sad < best_sad) {
best_sad = sad;
dx = test_dx;
dy = test_dy;
improved = true;
break;
}
}
}
// Small diamond refinement
improved = true;
while (improved) {
improved = false;
for (int i = 0; i < 4; i++) {
int test_dx = dx + small_diamond[i][0];
int test_dy = dy + small_diamond[i][1];
if (abs(test_dx) > search_range || abs(test_dy) > search_range) {
continue;
}
int sad = compute_sad(ref, cur, cx + test_dx, cy + test_dy, cx, cy, width, height, block_size);
if (sad < best_sad) {
best_sad = sad;
dx = test_dx;
dy = test_dy;
improved = true;
break;
}
}
}
*best_dx = dx;
*best_dy = dy;
}
// Sub-pixel refinement using parabolic interpolation
static void subpixel_refinement(
const unsigned char *ref, const unsigned char *cur,
int cx, int cy, int width, int height, int block_size,
int int_dx, int int_dy, // Integer-pixel motion
float *subpix_dx, float *subpix_dy // Output: 1/4-pixel precision
) {
// Get SAD at integer position and neighbors
int sad_0_0 = compute_sad(ref, cur, cx + int_dx, cy + int_dy, cx, cy, width, height, block_size);
// Horizontal neighbors
int sad_m1_0 = compute_sad(ref, cur, cx + int_dx - 1, cy + int_dy, cx, cy, width, height, block_size);
int sad_p1_0 = compute_sad(ref, cur, cx + int_dx + 1, cy + int_dy, cx, cy, width, height, block_size);
// Vertical neighbors
int sad_0_m1 = compute_sad(ref, cur, cx + int_dx, cy + int_dy - 1, cx, cy, width, height, block_size);
int sad_0_p1 = compute_sad(ref, cur, cx + int_dx, cy + int_dy + 1, cx, cy, width, height, block_size);
// Parabolic interpolation
float offset_x = parabolic_interp(sad_m1_0, sad_0_0, sad_p1_0);
float offset_y = parabolic_interp(sad_0_m1, sad_0_0, sad_0_p1);
// Quantize to 1/4-pixel precision
*subpix_dx = int_dx + roundf(offset_x * 4.0f) / 4.0f;
*subpix_dy = int_dy + roundf(offset_y * 4.0f) / 4.0f;
}
// MPEG-style bidirectional motion estimation
// Uses variable block sizes (16×16, optionally split to 8×8)
// 4-pixel overlap between blocks to reduce blocking artifacts
// Diamond search + parabolic sub-pixel refinement
void estimate_motion_optical_flow(
const unsigned char *frame1_rgb, const unsigned char *frame2_rgb,
int width, int height,
float **out_flow_x, float **out_flow_y
) {
// Convert RGB to grayscale
unsigned char *gray1 = (unsigned char*)std::malloc(width * height);
unsigned char *gray2 = (unsigned char*)std::malloc(width * height);
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
int idx = y * width + x;
int rgb_idx = idx * 3;
gray1[idx] = (unsigned char)(0.299f * frame1_rgb[rgb_idx] +
0.587f * frame1_rgb[rgb_idx + 1] +
0.114f * frame1_rgb[rgb_idx + 2]);
gray2[idx] = (unsigned char)(0.299f * frame2_rgb[rgb_idx] +
0.587f * frame2_rgb[rgb_idx + 1] +
0.114f * frame2_rgb[rgb_idx + 2]);
}
}
*out_flow_x = (float*)std::malloc(width * height * sizeof(float));
*out_flow_y = (float*)std::malloc(width * height * sizeof(float));
std::memset(*out_flow_x, 0, width * height * sizeof(float));
std::memset(*out_flow_y, 0, width * height * sizeof(float));
// Block parameters
const int block_size = 16;
const int overlap = 4;
const int stride = block_size - overlap; // 12 pixels
const int search_range = 16; // ±16 pixels
// Process overlapping blocks
for (int by = 0; by < height; by += stride) {
for (int bx = 0; bx < width; bx += stride) {
int actual_block_size = block_size;
// Clamp block to frame boundary
if (bx + block_size > width || by + block_size > height) {
continue; // Skip partial blocks at edges
}
// Integer-pixel diamond search
int int_dx = 0, int_dy = 0;
diamond_search(gray1, gray2, bx, by, width, height,
actual_block_size, search_range, &int_dx, &int_dy);
// Sub-pixel refinement
float subpix_dx = 0.0f, subpix_dy = 0.0f;
subpixel_refinement(gray1, gray2, bx, by, width, height,
actual_block_size, int_dx, int_dy,
&subpix_dx, &subpix_dy);
// Fill motion vectors for block with distance-weighted blending in overlap regions
for (int y = by; y < by + actual_block_size && y < height; y++) {
for (int x = bx; x < bx + actual_block_size && x < width; x++) {
int idx = y * width + x;
// Distance from block center for blending weight
float dx_from_center = (x - (bx + actual_block_size / 2));
float dy_from_center = (y - (by + actual_block_size / 2));
float dist = sqrtf(dx_from_center * dx_from_center +
dy_from_center * dy_from_center);
// Weight decreases with distance from center (for smooth blending in overlaps)
float weight = 1.0f / (1.0f + dist / actual_block_size);
// Accumulate weighted motion (will be normalized later)
(*out_flow_x)[idx] += subpix_dx * weight;
(*out_flow_y)[idx] += subpix_dy * weight;
}
}
}
}
std::free(gray1);
std::free(gray2);
}
// Build distortion mesh from dense optical flow field
void build_mesh_from_flow(
const float *flow_x, const float *flow_y,
int width, int height,
int mesh_w, int mesh_h,
short *mesh_dx, short *mesh_dy // Output: 1/8 pixel precision
) {
int cell_w = width / mesh_w;
int cell_h = height / mesh_h;
for (int my = 0; my < mesh_h; my++) {
for (int mx = 0; mx < mesh_w; mx++) {
// Cell center coordinates
int cx = mx * cell_w + cell_w / 2;
int cy = my * cell_h + cell_h / 2;
// Sample flow at cell center (5×5 neighborhood for robustness)
float sum_dx = 0.0f, sum_dy = 0.0f;
int count = 0;
for (int dy = -2; dy <= 2; dy++) {
for (int dx = -2; dx <= 2; dx++) {
int px = cx + dx;
int py = cy + dy;
if (px >= 0 && px < width && py >= 0 && py < height) {
int idx = py * width + px;
sum_dx += flow_x[idx];
sum_dy += flow_y[idx];
count++;
}
}
}
float avg_dx = (count > 0) ? (sum_dx / count) : 0.0f;
float avg_dy = (count > 0) ? (sum_dy / count) : 0.0f;
int mesh_idx = my * mesh_w + mx;
mesh_dx[mesh_idx] = (short)(avg_dx * 4.0f); // 1/4 pixel precision
mesh_dy[mesh_idx] = (short)(avg_dy * 4.0f);
}
}
}
// Laplacian smoothing for mesh spatial coherence
void smooth_mesh_laplacian(
short *mesh_dx, short *mesh_dy,
int mesh_width, int mesh_height,
float smoothness, int iterations
) {
short *temp_dx = (short*)std::malloc(mesh_width * mesh_height * sizeof(short));
short *temp_dy = (short*)std::malloc(mesh_width * mesh_height * sizeof(short));
for (int iter = 0; iter < iterations; iter++) {
std::memcpy(temp_dx, mesh_dx, mesh_width * mesh_height * sizeof(short));
std::memcpy(temp_dy, mesh_dy, mesh_width * mesh_height * sizeof(short));
for (int my = 0; my < mesh_height; my++) {
for (int mx = 0; mx < mesh_width; mx++) {
int idx = my * mesh_width + mx;
float neighbor_dx = 0.0f, neighbor_dy = 0.0f;
int neighbor_count = 0;
int neighbors[4][2] = {{0, -1}, {0, 1}, {-1, 0}, {1, 0}};
for (int n = 0; n < 4; n++) {
int nx = mx + neighbors[n][0];
int ny = my + neighbors[n][1];
if (nx >= 0 && nx < mesh_width && ny >= 0 && ny < mesh_height) {
int nidx = ny * mesh_width + nx;
neighbor_dx += temp_dx[nidx];
neighbor_dy += temp_dy[nidx];
neighbor_count++;
}
}
if (neighbor_count > 0) {
neighbor_dx /= neighbor_count;
neighbor_dy /= neighbor_count;
float data_weight = 1.0f - smoothness;
mesh_dx[idx] = (short)(data_weight * temp_dx[idx] + smoothness * neighbor_dx);
mesh_dy[idx] = (short)(data_weight * temp_dy[idx] + smoothness * neighbor_dy);
}
}
}
}
std::free(temp_dx);
std::free(temp_dy);
}
// Bilinear mesh warp
void warp_frame_with_mesh(
const float *src_frame, int width, int height,
const short *mesh_dx, const short *mesh_dy,
int mesh_width, int mesh_height,
float *dst_frame
) {
int cell_w = width / mesh_width;
int cell_h = height / mesh_height;
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
int cell_x = x / cell_w;
int cell_y = y / cell_h;
if (cell_x >= mesh_width - 1) cell_x = mesh_width - 2;
if (cell_y >= mesh_height - 1) cell_y = mesh_height - 2;
if (cell_x < 0) cell_x = 0;
if (cell_y < 0) cell_y = 0;
int idx_00 = cell_y * mesh_width + cell_x;
int idx_10 = idx_00 + 1;
int idx_01 = (cell_y + 1) * mesh_width + cell_x;
int idx_11 = idx_01 + 1;
float cp_x0 = cell_x * cell_w + cell_w / 2.0f;
float cp_y0 = cell_y * cell_h + cell_h / 2.0f;
float cp_x1 = (cell_x + 1) * cell_w + cell_w / 2.0f;
float cp_y1 = (cell_y + 1) * cell_h + cell_h / 2.0f;
float alpha = (x - cp_x0) / (cp_x1 - cp_x0);
float beta = (y - cp_y0) / (cp_y1 - cp_y0);
if (alpha < 0.0f) alpha = 0.0f;
if (alpha > 1.0f) alpha = 1.0f;
if (beta < 0.0f) beta = 0.0f;
if (beta > 1.0f) beta = 1.0f;
float dx_00 = mesh_dx[idx_00] / 4.0f;
float dy_00 = mesh_dy[idx_00] / 4.0f;
float dx_10 = mesh_dx[idx_10] / 4.0f;
float dy_10 = mesh_dy[idx_10] / 4.0f;
float dx_01 = mesh_dx[idx_01] / 4.0f;
float dy_01 = mesh_dy[idx_01] / 4.0f;
float dx_11 = mesh_dx[idx_11] / 4.0f;
float dy_11 = mesh_dy[idx_11] / 4.0f;
float dx = (1 - alpha) * (1 - beta) * dx_00 +
alpha * (1 - beta) * dx_10 +
(1 - alpha) * beta * dx_01 +
alpha * beta * dx_11;
float dy = (1 - alpha) * (1 - beta) * dy_00 +
alpha * (1 - beta) * dy_10 +
(1 - alpha) * beta * dy_01 +
alpha * beta * dy_11;
float src_x = x + dx;
float src_y = y + dy;
int sx0 = (int)std::floor(src_x);
int sy0 = (int)std::floor(src_y);
int sx1 = sx0 + 1;
int sy1 = sy0 + 1;
if (sx0 < 0) sx0 = 0;
if (sy0 < 0) sy0 = 0;
if (sx1 >= width) sx1 = width - 1;
if (sy1 >= height) sy1 = height - 1;
if (sx0 >= width) sx0 = width - 1;
if (sy0 >= height) sy0 = height - 1;
float fx = src_x - sx0;
float fy = src_y - sy0;
float val_00 = src_frame[sy0 * width + sx0];
float val_10 = src_frame[sy0 * width + sx1];
float val_01 = src_frame[sy1 * width + sx0];
float val_11 = src_frame[sy1 * width + sx1];
float val = (1 - fx) * (1 - fy) * val_00 +
fx * (1 - fy) * val_10 +
(1 - fx) * fy * val_01 +
fx * fy * val_11;
dst_frame[y * width + x] = val;
}
}
}
// Dense optical flow estimation using Farneback algorithm
// Computes flow at every pixel, then samples at block centers for motion vectors
// Much more spatially coherent than independent block matching
@@ -491,4 +84,100 @@ void estimate_optical_flow_motion(
}
}
// Block-based motion compensation with bilinear interpolation (sub-pixel precision)
// MVs are in 1/4-pixel units
// This implements the warp() function from MC-EZBC pseudocode
void warp_block_motion(
const float *src, // Source frame
int width, int height,
const int16_t *mvs_x, // Motion vectors X (1/4-pixel units)
const int16_t *mvs_y, // Motion vectors Y (1/4-pixel units)
int block_size, // Block size (e.g., 16)
float *dst // Output warped frame
) {
int num_blocks_x = (width + block_size - 1) / block_size;
int num_blocks_y = (height + block_size - 1) / block_size;
// Process each block
for (int by = 0; by < num_blocks_y; by++) {
for (int bx = 0; bx < num_blocks_x; bx++) {
int block_idx = by * num_blocks_x + bx;
// Get motion vector for this block (in 1/4-pixel units)
float mv_x = mvs_x[block_idx] / 4.0f; // Convert to pixels
float mv_y = mvs_y[block_idx] / 4.0f;
// Block boundaries in destination frame
int block_x_start = bx * block_size;
int block_y_start = by * block_size;
int block_x_end = std::min(block_x_start + block_size, width);
int block_y_end = std::min(block_y_start + block_size, height);
// Warp each pixel in the block
for (int y = block_y_start; y < block_y_end; y++) {
for (int x = block_x_start; x < block_x_end; x++) {
// Source position (backward warping)
float src_x = x - mv_x;
float src_y = y - mv_y;
// Clamp to valid range
src_x = std::max(0.0f, std::min((float)(width - 1), src_x));
src_y = std::max(0.0f, std::min((float)(height - 1), src_y));
// Bilinear interpolation
int x0 = (int)src_x;
int y0 = (int)src_y;
int x1 = std::min(x0 + 1, width - 1);
int y1 = std::min(y0 + 1, height - 1);
float fx = src_x - x0;
float fy = src_y - y0;
float val00 = src[y0 * width + x0];
float val10 = src[y0 * width + x1];
float val01 = src[y1 * width + x0];
float val11 = src[y1 * width + x1];
float val_top = (1.0f - fx) * val00 + fx * val10;
float val_bot = (1.0f - fx) * val01 + fx * val11;
float val = (1.0f - fy) * val_top + fy * val_bot;
dst[y * width + x] = val;
}
}
}
}
}
// Bidirectional motion compensation for MC-EZBC predict step
// Implements: prediction = 0.5 * (warp(f0, MV_fwd) + warp(f1, MV_bwd))
void warp_bidirectional(
const float *f0, const float *f1,
int width, int height,
const int16_t *mvs_fwd_x, const int16_t *mvs_fwd_y, // F0 → F1
const int16_t *mvs_bwd_x, const int16_t *mvs_bwd_y, // F1 → F0
int block_size,
float *prediction // Output: 0.5 * (warped_f0 + warped_f1)
) {
int num_pixels = width * height;
// Allocate temporary buffers
float *warped_f0 = new float[num_pixels];
float *warped_f1 = new float[num_pixels];
// Warp f0 forward using forward MVs
warp_block_motion(f0, width, height, mvs_fwd_x, mvs_fwd_y, block_size, warped_f0);
// Warp f1 backward using backward MVs
warp_block_motion(f1, width, height, mvs_bwd_x, mvs_bwd_y, block_size, warped_f1);
// Average the two warped frames
for (int i = 0; i < num_pixels; i++) {
prediction[i] = 0.5f * (warped_f0[i] + warped_f1[i]);
}
delete[] warped_f0;
delete[] warped_f1;
}
} // extern "C"