fft: data struct optimisation

This commit is contained in:
minjaesong
2023-11-26 17:33:29 +09:00
parent 3b38958a08
commit 1d727397b4
10 changed files with 94 additions and 131 deletions

View File

@@ -1,15 +1,20 @@
package net.torvald.terrarum.audio
import org.apache.commons.math3.exception.MathIllegalStateException
import org.apache.commons.math3.transform.DftNormalization
import org.apache.commons.math3.transform.TransformType
import org.apache.commons.math3.util.FastMath
data class FComplex(var re: Float = 0f, var im: Float = 0f) {
operator fun times(other: FComplex) = FComplex(
this.re * other.re - this.im * other.im,
this.re * other.im + this.im * other.re
)
class ComplexArray(val res: FloatArray, val ims: FloatArray) {
val indices: IntProgression
get() = 0 until size
val size: Int
get() = res.size
operator fun times(other: ComplexArray): ComplexArray {
val l = size
val re = FloatArray(l) { res[it] * other.res[it] - ims[it] * other.ims[it] }
val im = FloatArray(l) { res[it] * other.ims[it] + ims[it] * other.res[it] }
return ComplexArray(re, im)
}
}
/**
@@ -20,53 +25,16 @@ data class FComplex(var re: Float = 0f, var im: Float = 0f) {
object FFT {
// org.apache.commons.math3.transform.FastFouriesTransformer.java:370
fun fft(signal: FloatArray): Array<FComplex> {
val dataRI = arrayOf(signal.copyOf(), FloatArray(signal.size))
fun fft(signal: FloatArray): ComplexArray {
val dataRI = ComplexArray(signal.copyOf(), FloatArray(signal.size))
transformInPlace(dataRI, DftNormalization.STANDARD, TransformType.FORWARD)
val output = dataRI.toComplexArray()
return getComplex(output, false)
return dataRI
}
// org.apache.commons.math3.transform.FastFouriesTransformer.java:404
fun ifftAndGetReal(y: Array<FComplex>): FloatArray {
val dataRI = Array<FloatArray>(2) { FloatArray(y.size) }
for (i in y.indices) {
dataRI[0][i] = y[i].re
dataRI[1][i] = y[i].im
}
transformInPlace(dataRI, DftNormalization.STANDARD, TransformType.INVERSE)
return dataRI[0]
}
private fun Array<FloatArray>.toComplexArray(): Array<FComplex> {
return Array(this[0].size) {
FComplex(this[0][it], this[1][it])
}
}
// com.github.psambit9791.jdsp.transform.FastFourier.java:190
/**
* Returns the complex value of the fast fourier transformed sequence
* @param onlyPositive Set to True if non-mirrored output is required
* @throws java.lang.ExceptionInInitializerError if called before executing transform() method
* @return Complex[] The complex FFT output
*/
@Throws(ExceptionInInitializerError::class)
fun getComplex(output: Array<FComplex>, onlyPositive: Boolean): Array<FComplex> {
val dftout: Array<FComplex> = if (onlyPositive) {
val numBins: Int = output.size / 2 + 1
Array<FComplex>(numBins) { FComplex() }
}
else {
Array<FComplex>(output.size) { FComplex() }
}
System.arraycopy(output, 0, dftout, 0, dftout.size)
return dftout
fun ifftAndGetReal(y: ComplexArray): FloatArray {
transformInPlace(y, DftNormalization.STANDARD, TransformType.INVERSE)
return y.res
}
// org.apache.commons.math3.transform.FastFouriesTransformer.java:214
@@ -86,12 +54,12 @@ object FFT {
* @throws MathIllegalArgumentException if the number of data points is not
* a power of two
*/
private fun transformInPlace(dataRI: Array<FloatArray>, normalization: DftNormalization, type: TransformType) {
val dataR = dataRI[0]
val dataI = dataRI[1]
private fun transformInPlace(dataRI: ComplexArray, normalization: DftNormalization, type: TransformType) {
val dataR = dataRI.res
val dataI = dataRI.ims
val n = dataR.size
if (n == 1) {
/*if (n == 1) {
return
}
else if (n == 2) {
@@ -108,7 +76,7 @@ object FFT {
dataI[1] = srcI0 - srcI1
normalizeTransformedData(dataRI, normalization, type)
return
}
}*/
bitReversalShuffle2(dataR, dataI)
@@ -230,25 +198,26 @@ object FFT {
* @param type the type of transform (forward, inverse) which resulted in the specified data
*/
private fun normalizeTransformedData(
dataRI: Array<FloatArray>,
dataRI: ComplexArray,
normalization: DftNormalization, type: TransformType
) {
val dataR = dataRI[0]
val dataI = dataRI[1]
val dataR = dataRI.res
val dataI = dataRI.ims
val n = dataR.size
assert(dataI.size == n)
when (normalization) {
DftNormalization.STANDARD -> if (type == TransformType.INVERSE) {
val scaleFactor = 1f / n.toFloat()
var i = 0
while (i < n) {
dataR[i] *= scaleFactor
dataI[i] *= scaleFactor
i++
// assert(dataI.size == n)
// when (normalization) {
// DftNormalization.STANDARD ->
if (type == TransformType.INVERSE) {
val scaleFactor = 1f / n.toFloat()
var i = 0
while (i < n) {
dataR[i] *= scaleFactor
dataI[i] *= scaleFactor
i++
}
}
}
DftNormalization.UNITARY -> {
/* DftNormalization.UNITARY -> {
val scaleFactor = (1.0 / FastMath.sqrt(n.toDouble())).toFloat()
var i = 0
while (i < n) {
@@ -259,7 +228,7 @@ object FFT {
}
else -> throw MathIllegalStateException()
}
}*/
}
/**