mirror of
https://github.com/curioustorvald/tsvm.git
synced 2026-03-07 11:51:49 +09:00
TAV: 3D DWT makes coherent picture at least
This commit is contained in:
@@ -72,6 +72,7 @@ import kotlin.intArrayOf
|
||||
import kotlin.let
|
||||
import kotlin.longArrayOf
|
||||
import kotlin.math.*
|
||||
import kotlin.math.pow
|
||||
import kotlin.repeat
|
||||
import kotlin.text.format
|
||||
import kotlin.text.lowercase
|
||||
@@ -4538,30 +4539,18 @@ class GraphicsJSR223Delegate(private val vm: VM) {
|
||||
* - Frames 8-15: Level 2 (tHH - highest frequency)
|
||||
*/
|
||||
private fun getTemporalSubbandLevel(frameIdx: Int, numFrames: Int, temporalLevels: Int): Int {
    // No temporal decomposition: every frame belongs to the base level.
    if (temporalLevels == 0) return 0

    // Match encoder logic exactly (encoder_tav.c:1487-1501)
    // After temporal DWT with 2 levels:
    //   Frames 0 until numFrames / 2^2      = tLL (temporal low-low, coarsest, level 0)
    //   Frames in first half but after tLL  = tLH (level 1)
    //   Remaining frames                    = tH from first level (level 2, finest)
    //
    // NOTE(review): the numFrames/2 boundary and the maximum result of 2 are
    // written for temporalLevels == 2; presumably the encoder never emits any
    // other non-zero value - confirm against the encoder before relying on it.
    val framesPerLevel0 = numFrames shr temporalLevels // e.g., 16 >> 2 = 4, or 8 >> 2 = 2
    val firstHalfEnd = numFrames shr 1

    return when {
        frameIdx < framesPerLevel0 -> 0 // Coarsest temporal level (tLL)
        frameIdx < firstHalfEnd -> 1    // First level high-pass (tLH)
        else -> 2                       // Finest level high-pass (tH from level 1)
    }
}
|
||||
|
||||
@@ -4575,9 +4564,10 @@ class GraphicsJSR223Delegate(private val vm: VM) {
|
||||
* - Level 2 (tHH): 1.0 × 2^1.6 = 3.03
|
||||
*/
|
||||
private fun getTemporalQuantizerScale(temporalLevel: Int): Float {
    // scale = TEMPORAL_BASE_SCALE * 2^(BETA * temporalLevel^KAPPA)
    //
    // NOTE(review): the worked examples in the KDoc above this function
    // (e.g. "2^1.6 = 3.03" for level 2) appear to predate these constants.
    // With the values below, level 0 gives 1.0, level 1 gives about 1.52
    // and level 2 gives about 2.50.
    val BETA = 0.6f  // Temporal scaling exponent (aggressive for temporal high-pass)
    val KAPPA = 1.14f // Exponent applied to the level index itself
    val TEMPORAL_BASE_SCALE = 1.0f // Don't reduce tLL quantization (same as intra)

    // `pow` is a member call and binds tighter than `*`: BETA * (level^KAPPA).
    val exponent = BETA * temporalLevel.toFloat().pow(KAPPA)
    return TEMPORAL_BASE_SCALE * 2.0f.pow(exponent)
}
|
||||
|
||||
// level is one-based index
|
||||
@@ -6251,8 +6241,8 @@ class GraphicsJSR223Delegate(private val vm: VM) {
|
||||
* @param compressedDataPtr Pointer to compressed Zstd data
|
||||
* @param compressedSize Size of compressed data
|
||||
* @param gopSize Number of frames in GOP (1-16)
|
||||
* @param motionVectorsX X motion vectors in quarter-pixel units
|
||||
* @param motionVectorsY Y motion vectors in quarter-pixel units
|
||||
* @param motionVectorsX X motion vectors in 1/16-pixel units
|
||||
* @param motionVectorsY Y motion vectors in 1/16-pixel units
|
||||
* @param outputRGBAddrs Array of output RGB buffer addresses
|
||||
* @param width Frame width
|
||||
* @param height Frame height
|
||||
@@ -6363,10 +6353,10 @@ class GraphicsJSR223Delegate(private val vm: VM) {
|
||||
tavApplyInverse3DDWT(gopCg, width, height, gopSize, spatialLevels, temporalLevels, spatialFilter)
|
||||
|
||||
// Step 7: Apply inverse motion compensation (shift frames back)
|
||||
// Note: Motion vectors are in quarter-pixel units
|
||||
// Note: Motion vectors are in 1/16-pixel units, cumulative relative to frame 0
|
||||
for (t in 1 until gopSize) { // Skip frame 0 (reference)
|
||||
val dx = motionVectorsX[t] / 4 // Convert to pixel units
|
||||
val dy = motionVectorsY[t] / 4
|
||||
val dx = motionVectorsX[t] / 16 // Convert to pixel units
|
||||
val dy = motionVectorsY[t] / 16
|
||||
|
||||
if (dx != 0 || dy != 0) {
|
||||
applyInverseTranslation(gopY[t], width, height, dx, dy)
|
||||
@@ -6486,36 +6476,30 @@ class GraphicsJSR223Delegate(private val vm: VM) {
|
||||
|
||||
// Haar wavelet inverse 1D transform
|
||||
// The simplest wavelet: reverses averages and differences
|
||||
// MUST match encoder's dwt_haar_inverse_1d exactly (encoder_tav.c:1265-1284)
|
||||
private fun tavApplyDWTHaarInverse1D(data: FloatArray, length: Int) {
    // In-place inverse Haar step over the first `length` entries of `data`.
    // Input layout: indices 0 until half hold the low-pass values (averages),
    // indices half until length hold the high-pass values (differences).
    // Elements of `data` at index >= length are left untouched.
    if (length < 2) return

    val half = (length + 1) / 2
    val temp = FloatArray(length)

    // Inverse Haar transform: reconstruct from averages and differences.
    // Read directly from data (already has low-pass then high-pass layout)
    // and write into temp, so no input is overwritten before it is consumed.
    for (i in 0 until half) {
        val avg = data[i]
        if (2 * i + 1 < length) {
            // Reconstruct adjacent pair from average and difference
            val diff = data[half + i]
            temp[2 * i] = avg + diff
            temp[2 * i + 1] = avg - diff
        } else {
            // Odd length: the last sample has no difference partner and
            // comes from the low-pass value only
            temp[2 * i] = avg
        }
    }

    // Copy reconstructed data back (temp.size == length)
    temp.copyInto(data)
}
|
||||
|
||||
// =============================================================================
|
||||
|
||||
Reference in New Issue
Block a user