TAV update: CDF 5/3 for motion coder

This commit is contained in:
minjaesong
2025-11-23 18:16:12 +09:00
parent e928d2d3ec
commit 1c7ab17b1c
6 changed files with 174 additions and 95 deletions

View File

@@ -993,11 +993,34 @@ static void dwt_97_inverse_1d(float *data, int length) {
free(temp);
}
// 5/3 inverse DWT (simplified - uses 9/7 for now)
// 5/3 inverse DWT using lifting scheme (JPEG 2000 reversible filter)
static void dwt_53_inverse_1d(float *data, int length) {
if (length < 2) return;
// TODO: Implement proper 5/3 from TSVM if needed
dwt_97_inverse_1d(data, length);
float *temp = malloc(length * sizeof(float));
int half = (length + 1) / 2;
// Copy low-pass and high-pass subbands to temp
memcpy(temp, data, length * sizeof(float));
// Undo update step (low-pass)
for (int i = 0; i < half; i++) {
float update = 0.25f * ((i > 0 ? temp[half + i - 1] : 0) +
(i < half - 1 ? temp[half + i] : 0));
temp[i] -= update;
}
// Undo predict step (high-pass) and interleave samples
for (int i = 0; i < half; i++) {
data[2 * i] = temp[i]; // Even samples (low-pass)
int idx = 2 * i + 1;
if (idx < length) {
float pred = 0.5f * (temp[i] + (i < half - 1 ? temp[i + 1] : temp[i]));
data[idx] = temp[half + i] + pred; // Odd samples (high-pass)
}
}
free(temp);
}
// Multi-level inverse DWT (matches TSVM exactly with correct non-power-of-2 handling)
@@ -1180,7 +1203,8 @@ static void dwt_haar_inverse_1d(float *data, int length) {
// Order: SPATIAL first (each frame), then TEMPORAL (across frames)
static void apply_inverse_3d_dwt(float **gop_y, float **gop_co, float **gop_cg,
int width, int height, int gop_size,
int spatial_levels, int temporal_levels, int filter_type) {
int spatial_levels, int temporal_levels, int filter_type,
int temporal_motion_coder) {
// Step 1: Apply inverse 2D spatial DWT to each frame
for (int t = 0; t < gop_size; t++) {
apply_inverse_dwt_multilevel(gop_y[t], width, height, spatial_levels, filter_type);
@@ -1212,7 +1236,12 @@ static void apply_inverse_3d_dwt(float **gop_y, float **gop_co, float **gop_cg,
for (int level = temporal_levels - 1; level >= 0; level--) {
const int level_frames = temporal_lengths[level];
if (level_frames >= 2) {
dwt_haar_inverse_1d(temporal_line, level_frames);
// Use selected temporal wavelet (0=Haar, 1=CDF 5/3)
if (temporal_motion_coder == 0) {
dwt_haar_inverse_1d(temporal_line, level_frames);
} else {
dwt_53_inverse_1d(temporal_line, level_frames);
}
}
}
for (int t = 0; t < gop_size; t++) {
@@ -1226,7 +1255,12 @@ static void apply_inverse_3d_dwt(float **gop_y, float **gop_co, float **gop_cg,
for (int level = temporal_levels - 1; level >= 0; level--) {
const int level_frames = temporal_lengths[level];
if (level_frames >= 2) {
dwt_haar_inverse_1d(temporal_line, level_frames);
// Use selected temporal wavelet (0=Haar, 1=CDF 5/3)
if (temporal_motion_coder == 0) {
dwt_haar_inverse_1d(temporal_line, level_frames);
} else {
dwt_53_inverse_1d(temporal_line, level_frames);
}
}
}
for (int t = 0; t < gop_size; t++) {
@@ -1240,7 +1274,12 @@ static void apply_inverse_3d_dwt(float **gop_y, float **gop_co, float **gop_cg,
for (int level = temporal_levels - 1; level >= 0; level--) {
const int level_frames = temporal_lengths[level];
if (level_frames >= 2) {
dwt_haar_inverse_1d(temporal_line, level_frames);
// Use selected temporal wavelet (0=Haar, 1=CDF 5/3)
if (temporal_motion_coder == 0) {
dwt_haar_inverse_1d(temporal_line, level_frames);
} else {
dwt_53_inverse_1d(temporal_line, level_frames);
}
}
}
for (int t = 0; t < gop_size; t++) {
@@ -1706,6 +1745,7 @@ typedef struct {
int frame_count;
int frame_size;
int is_monoblock; // True if version 3-6 (single tile mode)
int temporal_motion_coder; // Temporal wavelet: 0=Haar, 1=CDF 5/3 (extracted from version)
// Screen masking (letterbox/pillarbox) - array of geometry changes
screen_mask_entry_t *screen_masks;
@@ -1942,7 +1982,11 @@ static tav_decoder_t* tav_decoder_init(const char *input_file, const char *outpu
}
decoder->frame_size = decoder->header.width * decoder->header.height;
decoder->is_monoblock = (decoder->header.version >= 3 && decoder->header.version <= 6);
// Extract temporal motion coder from version (versions 9-16 use CDF 5/3, 1-8 use Haar)
decoder->temporal_motion_coder = (decoder->header.version > 8) ? 1 : 0;
// Extract base version for determining monoblock mode
uint8_t base_version = (decoder->header.version > 8) ? (decoder->header.version - 8) : decoder->header.version;
decoder->is_monoblock = (base_version >= 3 && base_version <= 6);
decoder->audio_file_path = strdup(audio_file);
// Phase 2: Initialize decoding dimensions to full frame (will be updated by Screen Mask packets)
@@ -2337,7 +2381,9 @@ static int decode_i_or_p_frame(tav_decoder_t *decoder, uint8_t packet_type, uint
// Dequantise (perceptual for versions 5-8, uniform for 1-4)
// Phase 2: Use decoding dimensions and temporary buffers
const int is_perceptual = (decoder->header.version >= 5 && decoder->header.version <= 8);
// Extract base version for perceptual check
uint8_t base_version = (decoder->header.version > 8) ? (decoder->header.version - 8) : decoder->header.version;
const int is_perceptual = (base_version >= 5 && base_version <= 8);
const int is_ezbc = (decoder->header.entropy_coder == 1);
if (is_ezbc && is_perceptual) {
@@ -2472,7 +2518,9 @@ static int decode_i_or_p_frame(tav_decoder_t *decoder, uint8_t packet_type, uint
}
// Convert YCoCg-R/ICtCp to RGB for cropped region
const int is_ictcp = (decoder->header.version % 2 == 0);
// Extract base version for ICtCp check (even versions use ICtCp)
uint8_t base_version_rgb = (decoder->header.version > 8) ? (decoder->header.version - 8) : decoder->header.version;
const int is_ictcp = (base_version_rgb % 2 == 0);
for (int i = 0; i < decoding_pixels; i++) {
uint8_t r, g, b;
@@ -2936,7 +2984,9 @@ int main(int argc, char *argv[]) {
}
// Dequantise with temporal scaling (perceptual quantisation for versions 5-8)
const int is_perceptual = (decoder->header.version >= 5 && decoder->header.version <= 8);
// Extract base version for perceptual check
uint8_t base_version_gop = (decoder->header.version > 8) ? (decoder->header.version - 8) : decoder->header.version;
const int is_perceptual = (base_version_gop >= 5 && base_version_gop <= 8);
const int is_ezbc = (decoder->header.entropy_coder == 1);
const int temporal_levels = 2; // Fixed for TAV GOP encoding
@@ -3034,7 +3084,7 @@ int main(int argc, char *argv[]) {
// Phase 2: Use GOP dimensions (may be cropped) for inverse DWT
apply_inverse_3d_dwt(gop_y, gop_co, gop_cg, gop_width, gop_height,
gop_size, decoder->header.decomp_levels, temporal_levels,
decoder->header.wavelet_filter);
decoder->header.wavelet_filter, decoder->temporal_motion_coder);
// Debug: Check Y values after inverse DWT
if (verbose && decoder->frame_count == 0) {