diff --git a/tsvm_core/src/net/torvald/tsvm/GraphicsJSR223Delegate.kt b/tsvm_core/src/net/torvald/tsvm/GraphicsJSR223Delegate.kt index c0ae0b0..e145c19 100644 --- a/tsvm_core/src/net/torvald/tsvm/GraphicsJSR223Delegate.kt +++ b/tsvm_core/src/net/torvald/tsvm/GraphicsJSR223Delegate.kt @@ -16,6 +16,11 @@ import kotlin.math.* class GraphicsJSR223Delegate(private val vm: VM) { + // TAV Simulated overlapping tiles constants (must match encoder) + private val TILE_SIZE = 112 + private val TILE_MARGIN = 32 // 32-pixel margin for 3 DWT levels (4 * 2^3 = 32px) + private val PADDED_TILE_SIZE = TILE_SIZE + 2 * TILE_MARGIN // 112 + 64 = 176px + // Reusable working arrays to reduce allocation overhead private val idct8TempBuffer = FloatArray(64) private val idct16TempBuffer = FloatArray(256) // For 16x16 IDCT @@ -3978,62 +3983,78 @@ class GraphicsJSR223Delegate(private val vm: VM) { println("TAV decode error: ${e.message}") } - // Apply deblocking filter if enabled to reduce DWT quantization artifacts // if (enableDeblocking) { -// tavDeblockingFilter(currentRGBAddr, width, height) +// tavAdaptiveDeblockingFilter(currentRGBAddr, width, height) // } } private fun decodeDWTIntraTileRGB(readPtr: Long, tileX: Int, tileY: Int, currentRGBAddr: Long, width: Int, height: Int, qY: Int, qCo: Int, qCg: Int, rcf: Float, waveletFilter: Int, decompLevels: Int, isLossless: Boolean, tavVersion: Int): Long { - val tileSize = 112 - val coeffCount = tileSize * tileSize + // Now reading padded coefficient tiles (176x176) instead of core tiles (112x112) + val paddedSize = PADDED_TILE_SIZE + val paddedCoeffCount = paddedSize * paddedSize var ptr = readPtr - // Read quantized DWT coefficients for Y, Co, Cg channels - val quantizedY = ShortArray(coeffCount) - val quantizedCo = ShortArray(coeffCount) - val quantizedCg = ShortArray(coeffCount) + // Read quantized DWT coefficients for padded tile Y, Co, Cg channels (176x176) + val quantizedY = ShortArray(paddedCoeffCount) + val quantizedCo = 
ShortArray(paddedCoeffCount) + val quantizedCg = ShortArray(paddedCoeffCount) - // Read Y coefficients - for (i in 0 until coeffCount) { + // Read Y coefficients (176x176) + for (i in 0 until paddedCoeffCount) { quantizedY[i] = vm.peekShort(ptr) ptr += 2 } - // Read Co coefficients - for (i in 0 until coeffCount) { + // Read Co coefficients (176x176) + for (i in 0 until paddedCoeffCount) { quantizedCo[i] = vm.peekShort(ptr) ptr += 2 } - // Read Cg coefficients - for (i in 0 until coeffCount) { + // Read Cg coefficients (176x176) + for (i in 0 until paddedCoeffCount) { quantizedCg[i] = vm.peekShort(ptr) ptr += 2 } - // Dequantize and apply inverse DWT - val yTile = FloatArray(coeffCount) - val coTile = FloatArray(coeffCount) - val cgTile = FloatArray(coeffCount) + // Dequantize padded coefficient tiles (176x176) + val yPaddedTile = FloatArray(paddedCoeffCount) + val coPaddedTile = FloatArray(paddedCoeffCount) + val cgPaddedTile = FloatArray(paddedCoeffCount) - for (i in 0 until coeffCount) { - yTile[i] = quantizedY[i] * qY * rcf - coTile[i] = quantizedCo[i] * qCo * rcf - cgTile[i] = quantizedCg[i] * qCg * rcf + for (i in 0 until paddedCoeffCount) { + yPaddedTile[i] = quantizedY[i] * qY * rcf + coPaddedTile[i] = quantizedCo[i] * qCo * rcf + cgPaddedTile[i] = quantizedCg[i] * qCg * rcf } - // Apply inverse DWT using specified filter with decomposition levels + // Apply inverse DWT on full padded tiles (176x176) if (isLossless) { - applyDWTInverseMultiLevel(yTile, tileSize, tileSize, decompLevels, 0) - applyDWTInverseMultiLevel(coTile, tileSize, tileSize, decompLevels, 0) - applyDWTInverseMultiLevel(cgTile, tileSize, tileSize, decompLevels, 0) + applyDWTInverseMultiLevel(yPaddedTile, paddedSize, paddedSize, decompLevels, 0) + applyDWTInverseMultiLevel(coPaddedTile, paddedSize, paddedSize, decompLevels, 0) + applyDWTInverseMultiLevel(cgPaddedTile, paddedSize, paddedSize, decompLevels, 0) } else { - applyDWTInverseMultiLevel(yTile, tileSize, tileSize, decompLevels, 
/**
 * Fills the first [size] entries of [window] with a raised-cosine ramp for smooth blending.
 *
 * Entry `i` receives `0.5 * (1 - cos(PI * t))` with `t = i / (size - 1)`, so the values rise
 * from 0.0 at index 0 to 1.0 at index `size - 1`.
 *
 * All arithmetic is done in `Float`: multiplying the `Double` constant `kotlin.math.PI` by a
 * `Float` would promote the whole expression to `Double`, which cannot be stored in a
 * `FloatArray` element.
 *
 * @param window destination array; must hold at least [size] elements
 * @param size number of entries to fill; nothing is written when it is zero or negative.
 *             A single-entry window is set to 0.0 (the value every longer window has at
 *             index 0) instead of the NaN that `0 / 0` would produce.
 */
private fun generateWindowFunction(window: FloatArray, size: Int) {
    val pi = kotlin.math.PI.toFloat()
    // Guard the denominator so size == 1 does not evaluate 0f / 0.
    val lastIndex = if (size > 1) (size - 1).toFloat() else 1.0f

    for (i in 0 until size) {
        val t = i.toFloat() / lastIndex
        window[i] = 0.5f * (1.0f - kotlin.math.cos(pi * t))
    }
}
/**
 * Smooths the seams between neighbouring tiles of an RGB frame, in place.
 *
 * The frame is accessed through `vm.peek` / `vm.poke` as 3 bytes per pixel (R, G, B),
 * row-major, starting at [rgbAddr]. Every vertical seam (between horizontally adjacent
 * tiles) is processed first, then every horizontal seam.
 *
 * @param rgbAddr address of the first byte of the frame
 * @param width frame width in pixels
 * @param height frame height in pixels
 */
private fun tavAdaptiveDeblockingFilter(rgbAddr: Long, width: Int, height: Int) {
    // Use the shared class constant so the seam positions always follow the tile grid.
    val tileSize = TILE_SIZE
    val tilesX = (width + tileSize - 1) / tileSize
    val tilesY = (height + tileSize - 1) / tileSize

    // Vertical seams: the boundary in front of every tile column except the first.
    for (tileY in 0 until tilesY) {
        for (tileX in 1 until tilesX) {
            deblockVerticalSeamStrong(rgbAddr, width, height, tileX * tileSize, tileY * tileSize, tileSize)
        }
    }

    // Horizontal seams: the boundary above every tile row except the first.
    for (tileY in 1 until tilesY) {
        for (tileX in 0 until tilesX) {
            deblockHorizontalSeamStrong(rgbAddr, width, height, tileX * tileSize, tileY * tileSize, tileSize)
        }
    }
}

/**
 * Filters one tile-high stretch of the vertical seam at column [seamX].
 *
 * Rows `startY` (inclusive) to `min(startY + tileHeight, height)` (exclusive) are processed
 * one at a time by [deblockSeamLineStrong], which filters along the row, across the seam.
 * Does nothing when [seamX] lies outside the frame.
 */
private fun deblockVerticalSeamStrong(rgbAddr: Long, width: Int, height: Int, seamX: Int, startY: Int, tileHeight: Int) {
    if (seamX >= width) return

    val endY = minOf(startY + tileHeight, height)
    val pixelStride = 3L

    for (y in startY until endY) {
        // Byte offset of pixel (0, y); stepping by one pixel walks along the row.
        val rowBase = y.toLong() * width * 3L
        deblockSeamLineStrong(rgbAddr, rowBase, pixelStride, seamX, width)
    }
}

/**
 * Filters one tile-wide stretch of the horizontal seam at row [seamY].
 *
 * Columns `startX` (inclusive) to `min(startX + tileWidth, width)` (exclusive) are processed
 * one at a time by [deblockSeamLineStrong], which filters along the column, across the seam.
 * Does nothing when [seamY] lies outside the frame.
 */
private fun deblockHorizontalSeamStrong(rgbAddr: Long, width: Int, height: Int, startX: Int, seamY: Int, tileWidth: Int) {
    if (seamY >= height) return

    val endX = minOf(startX + tileWidth, width)
    val rowStride = width * 3L

    for (x in startX until endX) {
        // Byte offset of pixel (x, 0); stepping by one row walks down the column.
        val columnBase = x * 3L
        deblockSeamLineStrong(rgbAddr, columnBase, rowStride, seamY, height)
    }
}

/**
 * Shared seam filter for a single line of pixels that crosses a tile seam.
 *
 * The pixel at position `p` on the line lives at `rgbAddr + lineBase + p * stride`.
 * The two pixels touching the seam (`seamPos - 1` and `seamPos`) are compared; the line is
 * filtered only when their largest per-channel difference is in `2 until 120`. The filter
 * radius shrinks as that difference grows (6, 4, 3 or 2 pixels). Each pixel within the radius
 * is replaced by a blend of itself and a weighted average of its neighbours on the line,
 * where a neighbour's weight falls with distance and with colour difference. The blend is
 * strongest on the seam and falls off with distance from it.
 *
 * Pixels are rewritten in place while the line is walked from `seamPos - radius` upwards, so
 * later pixels see the already filtered values of earlier ones.
 *
 * @param rgbAddr address of the first byte of the frame
 * @param lineBase byte offset of position 0 on this line
 * @param stride byte distance between consecutive positions on this line
 * @param seamPos position of the first pixel after the seam
 * @param limit number of valid positions on the line (frame width or height)
 */
private fun deblockSeamLineStrong(rgbAddr: Long, lineBase: Long, stride: Long, seamPos: Int, limit: Int) {
    // Both pixels touching the seam must exist.
    if (seamPos - 1 < 0 || seamPos >= limit) return

    fun addrOf(pos: Int): Long = rgbAddr + lineBase + pos * stride
    fun channel(addr: Long): Int = vm.peek(addr).toInt() and 0xFF

    // Discontinuity strength across the seam.
    val beforeAddr = addrOf(seamPos - 1)
    val afterAddr = addrOf(seamPos)
    val maxDiff = maxOf(
        abs(channel(beforeAddr) - channel(afterAddr)),
        abs(channel(beforeAddr + 1) - channel(afterAddr + 1)),
        abs(channel(beforeAddr + 2) - channel(afterAddr + 2))
    )

    // Too small to matter, or large enough to be treated as a real edge.
    if (maxDiff !in 2 until 120) return

    // Wider filter for smooth gradients, narrower for sharp transitions.
    val filterRadius = when {
        maxDiff <= 15 -> 6 // 13 pixels
        maxDiff <= 30 -> 4 // 9 pixels
        maxDiff <= 60 -> 3 // 7 pixels
        else -> 2          // 5 pixels
    }

    for (delta in -filterRadius..filterRadius) {
        val pos = seamPos + delta
        if (pos !in 0 until limit) continue

        val addr = addrOf(pos)
        val currentR = channel(addr)
        val currentG = channel(addr + 1)
        val currentB = channel(addr + 2)

        var sumR = 0.0f
        var sumG = 0.0f
        var sumB = 0.0f
        var weightSum = 0.0f

        // Weighted average over the neighbours on the line.
        for (samplePos in maxOf(0, pos - filterRadius)..minOf(limit - 1, pos + filterRadius)) {
            val sampleAddr = addrOf(samplePos)
            val sR = channel(sampleAddr)
            val sG = channel(sampleAddr + 1)
            val sB = channel(sampleAddr + 2)

            // Spatial weight: closer samples count more.
            val spatialWeight = 1.0f / (1.0f + abs(samplePos - pos))

            // Intensity weight: similar colours count more.
            val dR = sR - currentR
            val dG = sG - currentG
            val dB = sB - currentB
            val colorDiff = sqrt((dR * dR + dG * dG + dB * dB).toFloat())
            val intensityWeight = exp(-colorDiff / 30.0f)

            val totalWeight = spatialWeight * intensityWeight
            sumR += sR * totalWeight
            sumG += sG * totalWeight
            sumB += sB * totalWeight
            weightSum += totalWeight
        }

        if (weightSum > 0) {
            val filteredR = (sumR / weightSum).toInt()
            val filteredG = (sumG / weightSum).toInt()
            val filteredB = (sumB / weightSum).toInt()

            // Concentrate the blur at the seam and fade it out with distance.
            val distance = abs(delta).toFloat()
            val blendWeight = when (abs(delta)) {
                0 -> 0.95f
                1 -> 0.8f
                2 -> 0.5f
                else -> exp(-distance * distance / 1.5f) * 0.3f
            }

            val finalR = (currentR * (1 - blendWeight) + filteredR * blendWeight).toInt().coerceIn(0, 255)
            val finalG = (currentG * (1 - blendWeight) + filteredG * blendWeight).toInt().coerceIn(0, 255)
            val finalB = (currentB * (1 - blendWeight) + filteredB * blendWeight).toInt().coerceIn(0, 255)

            vm.poke(addr, finalR.toByte())
            vm.poke(addr + 1, finalG.toByte())
            vm.poke(addr + 2, finalB.toByte())
        }
    }
}

/**
 * Measures how busy the image is around ([centerX], [centerY]).
 *
 * For every in-frame pixel of the 9x9 window centred on that point, the absolute luma
 * difference to its left neighbour and to its top neighbour (when those exist in the frame)
 * is accumulated; the result is the mean of those differences, or 0.0 when none were taken.
 * Luma is `0.299 R + 0.587 G + 0.114 B`.
 *
 * Despite the local name `totalVariance`, this is a mean absolute difference, not a
 * statistical variance. [isVerticalSeam] is accepted but not used by the computation.
 */
private fun analyzeTextureComplexity(rgbAddr: Long, width: Int, height: Int, centerX: Int, centerY: Int, isVerticalSeam: Boolean): Float {
    val radius = 4
    var totalVariance = 0.0f
    var count = 0

    fun lumaAt(px: Int, py: Int): Float {
        val offset = (py * width + px) * 3L
        val r = vm.peek(rgbAddr + offset).toInt() and 0xFF
        val g = vm.peek(rgbAddr + offset + 1).toInt() and 0xFF
        val b = vm.peek(rgbAddr + offset + 2).toInt() and 0xFF
        return 0.299f * r + 0.587f * g + 0.114f * b
    }

    for (dy in -radius..radius) {
        for (dx in -radius..radius) {
            val x = centerX + dx
            val y = centerY + dy
            if (x !in 0 until width || y !in 0 until height) continue

            val luma = lumaAt(x, y)

            // Compare with the left and top neighbours to measure local activity.
            if (x > 0) {
                totalVariance += abs(luma - lumaAt(x - 1, y))
                count++
            }
            if (y > 0) {
                totalVariance += abs(luma - lumaAt(x, y - 1))
                count++
            }
        }
    }

    return if (count > 0) totalVariance / count else 0.0f
}
topLuma) + count++ + } + } + } + } + + return if (count > 0) totalVariance / count else 0.0f + } + private fun bilinearInterpolate( dataPtr: Long, width: Int, height: Int, x: Float, y: Float diff --git a/video_encoder/encoder_tav.c b/video_encoder/encoder_tav.c index 5653ab8..9149f5f 100644 --- a/video_encoder/encoder_tav.c +++ b/video_encoder/encoder_tav.c @@ -16,6 +16,10 @@ #include #include +#ifndef PI +#define PI 3.14159265358979323846f +#endif + // TSVM Advanced Video (TAV) format constants #define TAV_MAGIC "\x1F\x54\x53\x56\x4D\x54\x41\x56" // "\x1FTSVM TAV" // TAV version - dynamic based on color space mode @@ -40,6 +44,12 @@ #define MAX_DECOMP_LEVELS 6 // Can go deeper: 112→56→28→14→7→3→1 #define DEFAULT_DECOMP_LEVELS 5 // Increased default for better compression +// Simulated overlapping tiles settings for seamless DWT processing +#define DWT_FILTER_HALF_SUPPORT 4 // For 9/7 filter (filter lengths 9,7 → L=4) +#define TILE_MARGIN_LEVELS 3 // Use margin for 3 levels: 4 * (2^3) = 4 * 8 = 32px +#define TILE_MARGIN (DWT_FILTER_HALF_SUPPORT * (1 << TILE_MARGIN_LEVELS)) // 4 * 8 = 32px +#define PADDED_TILE_SIZE (TILE_SIZE + 2 * TILE_MARGIN) // 112 + 64 = 176px + // Wavelet filter types #define WAVELET_5_3_REVERSIBLE 0 // Lossless capable #define WAVELET_9_7_IRREVERSIBLE 1 // Higher compression @@ -478,6 +488,92 @@ static void dwt_97_forward_1d(float *data, int length) { free(temp); } +// Extract padded tile with margins for seamless DWT processing (correct implementation) +static void extract_padded_tile(tav_encoder_t *enc, int tile_x, int tile_y, + float *padded_y, float *padded_co, float *padded_cg) { + const int core_start_x = tile_x * TILE_SIZE; + const int core_start_y = tile_y * TILE_SIZE; + + // Extract padded tile: margin + core + margin + for (int py = 0; py < PADDED_TILE_SIZE; py++) { + for (int px = 0; px < PADDED_TILE_SIZE; px++) { + // Map padded coordinates to source image coordinates + int src_x = core_start_x + px - TILE_MARGIN; + int src_y = 
core_start_y + py - TILE_MARGIN; + + // Handle boundary conditions with mirroring + if (src_x < 0) src_x = -src_x; + else if (src_x >= enc->width) src_x = enc->width - 1 - (src_x - enc->width); + + if (src_y < 0) src_y = -src_y; + else if (src_y >= enc->height) src_y = enc->height - 1 - (src_y - enc->height); + + // Clamp to valid bounds + src_x = CLAMP(src_x, 0, enc->width - 1); + src_y = CLAMP(src_y, 0, enc->height - 1); + + int src_idx = src_y * enc->width + src_x; + int padded_idx = py * PADDED_TILE_SIZE + px; + + padded_y[padded_idx] = enc->current_frame_y[src_idx]; + padded_co[padded_idx] = enc->current_frame_co[src_idx]; + padded_cg[padded_idx] = enc->current_frame_cg[src_idx]; + } + } +} + + +// 2D DWT forward transform for padded tile +static void dwt_2d_forward_padded(float *tile_data, int levels, int filter_type) { + const int size = PADDED_TILE_SIZE; + float *temp_row = malloc(size * sizeof(float)); + float *temp_col = malloc(size * sizeof(float)); + + for (int level = 0; level < levels; level++) { + int current_size = size >> level; + if (current_size < 1) break; + + // Row transform + for (int y = 0; y < current_size; y++) { + for (int x = 0; x < current_size; x++) { + temp_row[x] = tile_data[y * size + x]; + } + + if (filter_type == WAVELET_5_3_REVERSIBLE) { + dwt_53_forward_1d(temp_row, current_size); + } else { + dwt_97_forward_1d(temp_row, current_size); + } + + for (int x = 0; x < current_size; x++) { + tile_data[y * size + x] = temp_row[x]; + } + } + + // Column transform + for (int x = 0; x < current_size; x++) { + for (int y = 0; y < current_size; y++) { + temp_col[y] = tile_data[y * size + x]; + } + + if (filter_type == WAVELET_5_3_REVERSIBLE) { + dwt_53_forward_1d(temp_col, current_size); + } else { + dwt_97_forward_1d(temp_col, current_size); + } + + for (int y = 0; y < current_size; y++) { + tile_data[y * size + x] = temp_col[y]; + } + } + } + + free(temp_row); + free(temp_col); +} + + + // 2D DWT forward transform for 112x112 tile static 
void dwt_2d_forward(float *tile_data, int levels, int filter_type) { const int size = TILE_SIZE; @@ -560,8 +656,8 @@ static size_t serialize_tile_data(tav_encoder_t *enc, int tile_x, int tile_y, return offset; } - // Quantize and serialize DWT coefficients - const int tile_size = TILE_SIZE * TILE_SIZE; + // Quantize and serialize DWT coefficients (full padded tile: 176x176) + const int tile_size = PADDED_TILE_SIZE * PADDED_TILE_SIZE; int16_t *quantized_y = malloc(tile_size * sizeof(int16_t)); int16_t *quantized_co = malloc(tile_size * sizeof(int16_t)); int16_t *quantized_cg = malloc(tile_size * sizeof(int16_t)); @@ -604,8 +700,8 @@ static size_t serialize_tile_data(tav_encoder_t *enc, int tile_x, int tile_y, // Compress and write frame data static size_t compress_and_write_frame(tav_encoder_t *enc, uint8_t packet_type) { - // Calculate total uncompressed size - const size_t max_tile_size = 9 + (TILE_SIZE * TILE_SIZE * 3 * sizeof(int16_t)); // header + 3 channels of coefficients + // Calculate total uncompressed size (for padded tile coefficients: 176x176) + const size_t max_tile_size = 9 + (PADDED_TILE_SIZE * PADDED_TILE_SIZE * 3 * sizeof(int16_t)); // header + 3 channels of coefficients const size_t total_uncompressed_size = enc->tiles_x * enc->tiles_y * max_tile_size; // Allocate buffer for uncompressed tile data @@ -620,31 +716,13 @@ static size_t compress_and_write_frame(tav_encoder_t *enc, uint8_t packet_type) // Determine tile mode (simplified) uint8_t mode = TAV_MODE_INTRA; // For now, all tiles are INTRA - // Extract tile data (already processed) - float tile_y_data[TILE_SIZE * TILE_SIZE]; - float tile_co_data[TILE_SIZE * TILE_SIZE]; - float tile_cg_data[TILE_SIZE * TILE_SIZE]; + // Extract padded tile data (176x176) with neighbor context for overlapping tiles + float tile_y_data[PADDED_TILE_SIZE * PADDED_TILE_SIZE]; + float tile_co_data[PADDED_TILE_SIZE * PADDED_TILE_SIZE]; + float tile_cg_data[PADDED_TILE_SIZE * PADDED_TILE_SIZE]; - // Extract tile data 
from frame buffers - for (int y = 0; y < TILE_SIZE; y++) { - for (int x = 0; x < TILE_SIZE; x++) { - int src_x = tile_x * TILE_SIZE + x; - int src_y = tile_y * TILE_SIZE + y; - int src_idx = src_y * enc->width + src_x; - int tile_idx_local = y * TILE_SIZE + x; - - if (src_x < enc->width && src_y < enc->height) { - tile_y_data[tile_idx_local] = enc->current_frame_y[src_idx]; - tile_co_data[tile_idx_local] = enc->current_frame_co[src_idx]; - tile_cg_data[tile_idx_local] = enc->current_frame_cg[src_idx]; - } else { - // Pad with zeros if tile extends beyond frame - tile_y_data[tile_idx_local] = 0.0f; - tile_co_data[tile_idx_local] = 0.0f; - tile_cg_data[tile_idx_local] = 0.0f; - } - } - } + // Extract padded tiles using context from neighbors + extract_padded_tile(enc, tile_x, tile_y, tile_y_data, tile_co_data, tile_cg_data); // Debug: check input data before DWT /*if (tile_x == 0 && tile_y == 0) { @@ -655,10 +733,10 @@ static size_t compress_and_write_frame(tav_encoder_t *enc, uint8_t packet_type) printf("\n"); }*/ - // Apply DWT transform to each channel - dwt_2d_forward(tile_y_data, enc->decomp_levels, enc->wavelet_filter); - dwt_2d_forward(tile_co_data, enc->decomp_levels, enc->wavelet_filter); - dwt_2d_forward(tile_cg_data, enc->decomp_levels, enc->wavelet_filter); + // Apply DWT transform to each padded channel (176x176) + dwt_2d_forward_padded(tile_y_data, enc->decomp_levels, enc->wavelet_filter); + dwt_2d_forward_padded(tile_co_data, enc->decomp_levels, enc->wavelet_filter); + dwt_2d_forward_padded(tile_cg_data, enc->decomp_levels, enc->wavelet_filter); // Serialize tile size_t tile_size = serialize_tile_data(enc, tile_x, tile_y,