Files
tsvm/video_encoder/lib/libfec/ldpc.c
2025-12-17 21:25:27 +09:00

398 lines
13 KiB
C

/**
* LDPC Rate 1/2 Codec Implementation
*
* LDPC for TAV-DT header protection.
* Uses a systematic rate 1/2 code with sum-product belief propagation decoder.
*
* The parity-check matrix is designed for good error correction on small blocks.
* Each parity bit is computed as XOR of multiple data bits using a pseudo-random
* but deterministic pattern.
*
* Created by CuriousTorvald and Claude on 2025-12-09.
* Updated 2025-12-17: Replaced bit-flipping with belief propagation decoder.
*/
#include "ldpc.h"
#include <string.h>
#include <stdio.h>
#include <math.h>
// Channel LLR magnitude for hard-decision input
// Higher value = more confidence in received bits
// For BER ~0.01, optimal is about 4.6; we use slightly lower for robustness
#define CHANNEL_LLR_MAG 4.0f
// Clipping value to prevent numerical overflow in tanh operations
#define LLR_CLIP 20.0f
// =============================================================================
// Parity-Check Matrix Generation
// =============================================================================
// For rate 1/2 LDPC: n = 2k bits, parity-check matrix H is (n-k) x n = k x 2k
// We use H = [P | I_k] where P is the parity pattern matrix
// This gives systematic encoding: c = [data | parity] where parity = P * data
// Parity pattern: each parity bit j depends on data bits where pattern[j][i] = 1
// We use a regular pattern with column weight 3 (each data bit affects 3 parity bits)
// and row weight varies to cover the data bits well
// Simple hash function for generating parity connections
static inline uint32_t hash_mix(uint32_t a, uint32_t b) {
a ^= b;
a = (a ^ (a >> 16)) * 0x85ebca6b;
a = (a ^ (a >> 13)) * 0xc2b2ae35;
return a ^ (a >> 16);
}
// Get bit from byte array
static inline int get_bit(const uint8_t *data, int bit_idx) {
return (data[bit_idx >> 3] >> (7 - (bit_idx & 7))) & 1;
}
// Set bit in byte array
static inline void set_bit(uint8_t *data, int bit_idx, int value) {
int byte_idx = bit_idx >> 3;
int bit_pos = 7 - (bit_idx & 7);
if (value) {
data[byte_idx] |= (1 << bit_pos);
} else {
data[byte_idx] &= ~(1 << bit_pos);
}
}
// Flip bit in byte array
static inline void flip_bit(uint8_t *data, int bit_idx) {
int byte_idx = bit_idx >> 3;
int bit_pos = 7 - (bit_idx & 7);
data[byte_idx] ^= (1 << bit_pos);
}
// Get list of data bits that affect parity bit j
// Returns number of connected data bits, stores indices in connections[]
// For rate 1/2: data bits are 0 to k*8-1, parity bits are k*8 to 2*k*8-1
static int get_parity_connections(int parity_idx, int k_bits, int *connections) {
int count = 0;
// Use a deterministic pseudo-random pattern
// Each parity bit connects to approximately k_bits/3 data bits
// Different seeds for different parity positions ensure coverage
uint32_t seed = hash_mix(0xDEADBEEF, (uint32_t)parity_idx);
for (int i = 0; i < k_bits; i++) {
// Each data bit has ~3/k_bits chance of connecting to this parity bit
// Total connections per parity ~ 3 (column weight)
uint32_t h = hash_mix(seed, (uint32_t)i);
if ((h % (k_bits / 3 + 1)) == 0) {
connections[count++] = i;
}
}
// Ensure at least 2 connections per parity bit
if (count < 2) {
connections[count++] = parity_idx % k_bits;
connections[count++] = (parity_idx + k_bits / 2) % k_bits;
}
return count;
}
// Get list of parity bits affected by data bit i
static int get_data_connections(int data_idx, int k_bits, int *connections) {
int count = 0;
for (int j = 0; j < k_bits; j++) {
int parity_conns[LDPC_MAX_DATA_BYTES * 8];
int n_conns = get_parity_connections(j, k_bits, parity_conns);
for (int c = 0; c < n_conns; c++) {
if (parity_conns[c] == data_idx) {
connections[count++] = j;
break;
}
}
}
return count;
}
// =============================================================================
// Initialization
// =============================================================================
static int ldpc_initialized = 0;
void ldpc_init(void) {
if (ldpc_initialized) return;
// No pre-computation needed - patterns generated on the fly
ldpc_initialized = 1;
}
// =============================================================================
// Encoding
// =============================================================================
size_t ldpc_encode(const uint8_t *data, size_t data_len, uint8_t *output) {
if (!ldpc_initialized) ldpc_init();
if (data_len > LDPC_MAX_DATA_BYTES) {
data_len = LDPC_MAX_DATA_BYTES;
}
int k_bits = (int)(data_len * 8); // Number of data bits
// Copy data to output (systematic encoding)
memcpy(output, data, data_len);
// Initialize parity bytes to zero
memset(output + data_len, 0, data_len);
// Compute parity bits
for (int j = 0; j < k_bits; j++) {
// Get data bits connected to parity bit j
int connections[LDPC_MAX_DATA_BYTES * 8];
int n_conns = get_parity_connections(j, k_bits, connections);
// Parity bit = XOR of connected data bits
int parity = 0;
for (int c = 0; c < n_conns; c++) {
parity ^= get_bit(data, connections[c]);
}
// Set parity bit
set_bit(output + data_len, j, parity);
}
return data_len * 2;
}
// =============================================================================
// Decoding
// =============================================================================
int ldpc_check_syndrome(const uint8_t *codeword, size_t len) {
if (!ldpc_initialized) ldpc_init();
size_t data_len = len / 2;
int k_bits = (int)(data_len * 8);
// Check all parity equations
for (int j = 0; j < k_bits; j++) {
int connections[LDPC_MAX_DATA_BYTES * 8];
int n_conns = get_parity_connections(j, k_bits, connections);
// Compute syndrome bit: XOR of connected data bits XOR parity bit
int syndrome = get_bit(codeword + data_len, j);
for (int c = 0; c < n_conns; c++) {
syndrome ^= get_bit(codeword, connections[c]);
}
if (syndrome != 0) {
return 0; // Syndrome non-zero: errors detected
}
}
return 1; // Zero syndrome: valid codeword
}
// Clip LLR to prevent overflow
static inline float clip_llr(float llr) {
if (llr > LLR_CLIP) return LLR_CLIP;
if (llr < -LLR_CLIP) return -LLR_CLIP;
return llr;
}
// Sign of a float (returns +1 or -1)
static inline float sign_f(float x) {
return (x >= 0.0f) ? 1.0f : -1.0f;
}
int ldpc_decode(const uint8_t *encoded, size_t encoded_len, uint8_t *output) {
if (!ldpc_initialized) ldpc_init();
if (encoded_len < 2 || (encoded_len & 1) != 0) {
return -1; // Invalid length
}
size_t data_len = encoded_len / 2;
if (data_len > LDPC_MAX_DATA_BYTES) {
return -1;
}
int k_bits = (int)(data_len * 8);
int n_bits = k_bits * 2; // Total codeword bits (data + parity)
// Pre-compute the parity check matrix structure for efficiency
// For each check node j: which variable nodes it connects to
int check_to_var[LDPC_MAX_DATA_BYTES * 8][LDPC_MAX_DATA_BYTES * 8 + 1];
int check_degree[LDPC_MAX_DATA_BYTES * 8];
for (int j = 0; j < k_bits; j++) {
int connections[LDPC_MAX_DATA_BYTES * 8];
int n_conns = get_parity_connections(j, k_bits, connections);
// Check j connects to: data bits in connections[] + parity bit j
check_degree[j] = n_conns + 1;
for (int c = 0; c < n_conns; c++) {
check_to_var[j][c] = connections[c]; // Data bit index
}
check_to_var[j][n_conns] = k_bits + j; // Parity bit index
}
// Initialize channel LLRs from received hard bits
// LLR > 0 means bit is probably 0, LLR < 0 means bit is probably 1
float channel_llr[LDPC_MAX_DATA_BYTES * 16];
for (int i = 0; i < n_bits; i++) {
int bit = get_bit(encoded, i);
channel_llr[i] = bit ? -CHANNEL_LLR_MAG : CHANNEL_LLR_MAG;
}
// Message arrays for BP
// check_to_var_msg[j][idx] = message from check j to variable check_to_var[j][idx]
float check_to_var_msg[LDPC_MAX_DATA_BYTES * 8][LDPC_MAX_DATA_BYTES * 8 + 1];
// Initialize check-to-variable messages to zero
memset(check_to_var_msg, 0, sizeof(check_to_var_msg));
// Belief Propagation iterations
for (int iter = 0; iter < LDPC_MAX_ITERATIONS; iter++) {
// Step 1: Variable-to-check messages (implicit, computed on the fly)
// var_to_check[v→j] = channel_llr[v] + sum of all check_to_var_msg[k][idx_v] for k != j
// Step 2: Check-to-variable messages using min-sum approximation
// For each check node j, for each connected variable v:
// check_to_var_msg[j→v] = sign * min(|incoming messages from other vars|)
for (int j = 0; j < k_bits; j++) {
int degree = check_degree[j];
// First, compute variable-to-check messages for all variables in this check
float var_to_check[LDPC_MAX_DATA_BYTES * 8 + 1];
for (int idx = 0; idx < degree; idx++) {
int v = check_to_var[j][idx];
// Sum all incoming check messages to variable v, except from check j
float sum = channel_llr[v];
for (int jj = 0; jj < k_bits; jj++) {
if (jj == j) continue;
// Find if check jj connects to variable v
for (int idx2 = 0; idx2 < check_degree[jj]; idx2++) {
if (check_to_var[jj][idx2] == v) {
sum += check_to_var_msg[jj][idx2];
break;
}
}
}
var_to_check[idx] = clip_llr(sum);
}
// Now compute check-to-variable messages using min-sum
for (int idx = 0; idx < degree; idx++) {
float sign_prod = 1.0f;
float min_abs = 1e30f;
for (int idx2 = 0; idx2 < degree; idx2++) {
if (idx2 == idx) continue;
float msg = var_to_check[idx2];
sign_prod *= sign_f(msg);
float abs_msg = fabsf(msg);
if (abs_msg < min_abs) min_abs = abs_msg;
}
// Min-sum with scaling factor 0.75 for better performance
check_to_var_msg[j][idx] = clip_llr(sign_prod * min_abs * 0.75f);
}
}
// Step 3: Compute posterior LLRs and make hard decisions
float posterior[LDPC_MAX_DATA_BYTES * 16];
for (int v = 0; v < n_bits; v++) {
float sum = channel_llr[v];
// Add all incoming check-to-variable messages
for (int j = 0; j < k_bits; j++) {
for (int idx = 0; idx < check_degree[j]; idx++) {
if (check_to_var[j][idx] == v) {
sum += check_to_var_msg[j][idx];
break;
}
}
}
posterior[v] = sum;
}
// Make hard decisions
uint8_t decoded[LDPC_MAX_DATA_BYTES * 2];
memset(decoded, 0, encoded_len);
for (int v = 0; v < n_bits; v++) {
if (posterior[v] < 0) {
set_bit(decoded, v, 1);
}
}
// Check syndrome
int syndrome_count = 0;
for (int j = 0; j < k_bits; j++) {
int syn = 0;
for (int idx = 0; idx < check_degree[j]; idx++) {
syn ^= get_bit(decoded, check_to_var[j][idx]);
}
if (syn) syndrome_count++;
}
// If all syndromes are zero, we're done
if (syndrome_count == 0) {
memcpy(output, decoded, data_len);
return 0;
}
// Early termination if syndrome count is very small (nearly converged)
if (iter > 5 && syndrome_count <= 2) {
// Try one more iteration, if still stuck, accept
}
}
// Decoding did not converge - compute final estimate
float posterior[LDPC_MAX_DATA_BYTES * 16];
for (int v = 0; v < n_bits; v++) {
float sum = channel_llr[v];
for (int j = 0; j < k_bits; j++) {
for (int idx = 0; idx < check_degree[j]; idx++) {
if (check_to_var[j][idx] == v) {
sum += check_to_var_msg[j][idx];
break;
}
}
}
posterior[v] = sum;
}
uint8_t decoded[LDPC_MAX_DATA_BYTES * 2];
memset(decoded, 0, encoded_len);
for (int v = 0; v < n_bits; v++) {
if (posterior[v] < 0) {
set_bit(decoded, v, 1);
}
}
// Check final syndrome count
int final_syndromes = 0;
for (int j = 0; j < k_bits; j++) {
int syn = 0;
for (int idx = 0; idx < check_degree[j]; idx++) {
syn ^= get_bit(decoded, check_to_var[j][idx]);
}
if (syn) final_syndromes++;
}
// Accept if syndrome count is low enough
if (final_syndromes <= k_bits / 4) {
memcpy(output, decoded, data_len);
return 0; // Soft success
}
// Total failure - return original data as best effort
memcpy(output, encoded, data_len);
return -1;
}