Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Specgrams #3

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 129 additions & 21 deletions lib/nd_list.dart
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
library nd_list;

import 'dart:math';

export './spectral.dart';

List<int> unsqueezeShape(List<int> shape, int axis) {
if (axis < 0) {
axis += shape.length + 1;
Expand All @@ -8,7 +12,9 @@ List<int> unsqueezeShape(List<int> shape, int axis) {
}

List<int> squeezeShape(List<int> shape) {
return shape.where((element) => element != 1).toList();
final squeezedShape = shape.where((element) => element != 1).toList();

return (squeezedShape.isEmpty) ? [1] : squeezedShape;
}

int getLinearIndex(List<int> shape, List<int> index) {
Expand Down Expand Up @@ -75,8 +81,31 @@ class NDIndexResult<X> {
/// In the end, `sliced` would represent `[[4.0, 5.0], [7.0, 8.0]]`.
class NDList<X> {
final List<X> _list = [];
List<X> get list => _list;
final List<int> _shape = [];

bool get is1D {
return squeezeShape(_shape).length == 1;
}

NDList<X> transpose([int otherAxis = 1]) {
final newShape = List<int>.from(_shape);
final otherLength = _shape[otherAxis];
final axis0Length = _shape[0];
newShape[0] = otherLength;
newShape[otherAxis] = axis0Length;

final newIndicesList = [
for (int i = 0; i < otherLength; i++)
_intIndexWithAxis(NDIndexResult.from(this), i, otherAxis).evaluate()
];
return NDList.from<NDList<X>>(newIndicesList).cemented().reshape(newShape);
}

List<X> toFlattenedList() {
return _list;
}

List toIteratedList() {
// Note! Originally from tflite_flutter's ListShape extension.
// Since this is the only method using tflite_flutter, and we are both using Apache 2.0, I have copied the code here. All rights to the original author(s).
Expand Down Expand Up @@ -303,6 +332,7 @@ class NDList<X> {
final parts = slice.split(':');
final start = parts[0].isEmpty ? 0 : int.parse(parts[0]);
final end = parts[1].isEmpty ? null : int.parse(parts[1]);

return (start, end);
} catch (e) {
return null;
Expand Down Expand Up @@ -430,8 +460,7 @@ class NDList<X> {

static NDIndexResult<Y> _stringIndex<Y>(
NDIndexResult<Y> priorResult, String index, int axis) {
// TODO: remove print
print('string index: "$index"');
// print("string index: $index");
try {
// is it just an int in string format?
// .parse throws if cannot be parsed as an int
Expand All @@ -450,16 +479,10 @@ class NDList<X> {
/// This method is used to index the NDList with a list of valid indices, i.e. ints and formatted slice strings.
static NDIndexResult<X> _listIndex<X>(
NDIndexResult<X> priorResult, List index) {
// TODO: remove print
print('list index: $index');
for (var i = 0; i < index.length; i++) {
if (index[i] is String) {
// TODO: remove print
print('string index: "${index[i]}"');
priorResult = _stringIndex(priorResult, index[i], i);
} else if (index[i] is int) {
// TODO: remove print
print("int index: ${index[i]}");
priorResult = _intIndexWithAxis(priorResult, index[i], i);
} else {
throw ArgumentError(
Expand All @@ -471,8 +494,6 @@ class NDList<X> {

static NDIndexResult<X> _intIndex<X>(
NDIndexResult<X> priorResult, int index) {
// TODO: remove print
print('int index: $index');
if (priorResult.shape.isEmpty) {
throw ArgumentError('Cannot index an empty NDList');
}
Expand Down Expand Up @@ -500,8 +521,6 @@ class NDList<X> {
/// This builds on the base case of an axis-0 int index, and allows for indexing on any axis.
static NDIndexResult<X> _intIndexWithAxis<X>(
NDIndexResult<X> priorResult, int index, int axis) {
// TODO: remove print
print('int index with axis: $index, axis $axis');
return _slice(priorResult, index, index + 1, axis: axis);
}

Expand All @@ -513,12 +532,13 @@ class NDList<X> {
NDIndexResult<Y> priorResult, int start, int end,
{required int axis}) {
// TODO: uncomment and fix, test
// if (start < 0) {
// start += priorResult.shape[axis];
// }
// if (end < 0) {
// end += priorResult.shape[axis];
// }
// support for negative indices
if (start < 0) {
start %= priorResult.shape[axis];
}
if (end < 0) {
end %= priorResult.shape[axis];
}
// if (end < start) {
// return _slice(priorResult, end, start, axis: axis);
// }
Expand Down Expand Up @@ -627,10 +647,11 @@ class NDList<X> {
if (_list.isEmpty) return NDList._([], newShape);
throw ArgumentError('New shape cannot have a dimension of 0');
}
final positiveDims = newShape.where((element) => element < 1).toList();
if (positiveDims.length > 1) {
final impliedDims = newShape.where((element) => element == -1).toList();
if (impliedDims.length > 1) {
throw ArgumentError('Only one dimension can be -1');
}
final positiveDims = newShape.where((element) => element > 0).toList();
final nSpecified = _product(positiveDims);
if (count % nSpecified != 0) {
throw ArgumentError('New shape must have the same number of elements');
Expand Down Expand Up @@ -705,6 +726,36 @@ extension NumNDList on NDList {
}
}

class RollingResult<X> {
List<NDIndexResult<X>> slices;
NDList<X> baseArray;

RollingResult._(this.slices, this.baseArray);

factory RollingResult(NDList<X> baseArray, int windowSize,
{int step = 1, int axis = 0}) {
// allow for negative axis to mean "from end"
// eg -1 means "last axis"
axis = axis % baseArray.nDims;
final slices = [
for (int i = windowSize - 1; i < baseArray.shape[axis]; i = i + step)
NDList._stringIndex(NDIndexResult.from(baseArray),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this line requires this class to be in the same file as NDList, so that we can call the private method.

'${i - windowSize + 1}:${i + 1}', axis)
];
return RollingResult._(slices, baseArray);
}

NDList<Y> reduce<Y>(Y Function(NDList<X>) f) {
return NDList.from<Y>(slices.map((e) => f(e.evaluate())).toList());
}
}

extension Rolling<X> on NDList<X> {
RollingResult<X> rolling(int windowSize, {int step = 1, int axis = 0}) {
return RollingResult(this, windowSize, step: step, axis: axis);
}
}

extension ArithmeticNDList<X extends num> on NDList<X> {
NDList<X> zipWith(NDList<X> other, X Function(X, X) f) {
if (!_shapeMatches(other)) {
Expand Down Expand Up @@ -733,6 +784,63 @@ extension ArithmeticNDList<X extends num> on NDList<X> {
operator /(NDList<X> other) {
return this.zipWith(other, ((p0, p1) => (p0 / p1) as X));
}

X sum() {
return _list.reduce((value, element) => value + element as X);
}

double mean() {
return sum() / count;
}

X quantile(double q) {
final sorted = _list..sort();
final index = (count - 1) * q;
final lower = sorted[index.floor()];
final upper = sorted[index.ceil()];
return lower + (upper - lower) * (index - index.floor()) as X;
}

X median() {
final sorted = _list..sort();
final mid = count ~/ 2;
return (count.isEven ? (sorted[mid - 1] + sorted[mid]) / 2 : sorted[mid])
as X;
}

X max() {
return _list.reduce((value, element) => value > element ? value : element);
}

X min() {
return _list.reduce((value, element) => value < element ? value : element);
}

NDList<X> iqrNormalizationAdaptive(
{required int windowSize, double lowerQ = 0.25, double upperQ = 0.75}) {
final lowerQResult = rolling(windowSize).quantile(0.25);
final higherQResult = rolling(windowSize).quantile(0.75);
final iqr = higherQResult - lowerQResult;

return (this - lowerQResult) / iqr * 2;
}

NDList<X> sumAlong({int axis = 0}) {
return rolling(1, axis: axis).sum();
}
}

extension NumericalAggregation<X extends num> on RollingResult<X> {
NDList<X> sum() => reduce((e) => e.sum());

NDList<double> mean() => reduce((e) => e.mean());

NDList<X> median() => reduce((e) => e.median());

NDList<X> quantile(double q) => reduce((e) => e.quantile(q));

NDList<X> max() => reduce((e) => e.max());
NDList<X> min() => reduce((e) => e.min());
}

extension MultiLinear<X> on NDList<NDList<X>> {
Expand Down
112 changes: 112 additions & 0 deletions lib/spectral.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import 'dart:math';
import 'package:complex/complex.dart';
import 'package:nd_list/nd_list.dart';

/// For an input NDList<double> computes a spectrogram. The spectrogram is taken as follows:
/// 1. If the input is 1D, this is just a sliding FFT
/// 2. If the input is 2D, this is a sliding FFT along the columns (axis 1), meaning each array[[:, i]] is taken as input
/// 3. If the input is 3D, this a stacked spectrogram along axis 2, meaning each array[:, :, i] is taken as input to the previous case.
///
/// The pattern continues into higher dimensions, where the last axis is taken as the input to the previous case.

extension SpectralAnalysis on NDList<double> {
/// Calculates the twiddle factor for a given index and length.
Complex twiddle(int k, int N) {
final angle = -2 * pi * k / N;
return Complex.polar(1, angle);
}

/// The Radix-2 split-radix FFT algorithm for real-valued data.
List<Complex> splitRadixFFT(List<double> data) {
final N = data.length;
if (N <= 1) {
return [Complex(data[0], 0)];
} else if (N == 2) {
// Handle N == 2 case separately
final e = data[0];
final o = data[1];
return [Complex(e + o, 0), Complex(e - o, 0)];
} else if (N == 4) {
// Handle N == 4 case separately
final e0 = data[0];
final e1 = data[1];
final o0 = data[2];
final o1 = data[3];

final t0 = twiddle(0, 4);
final t1 = twiddle(1, 4);

return [
Complex(e0 + e1 + o0 + o1, 0),
Complex(e0 - e1, 0) + t1 * Complex(0, o0 - o1),
Complex(e0 + e1 - o0 - o1, 0),
Complex(e0 - e1, 0) - t1 * Complex(0, o0 - o1),
];
}

// Split into even and odd indices
final even = data.sublist(0, N ~/ 2);
final odd = data.sublist(N ~/ 2);

// Recursively compute FFTs of even and odd parts
final evenFFT = splitRadixFFT(even);
final oddFFT = splitRadixFFT(odd);

// Combine results
final result = List.generate(N, (i) => Complex(0, 0));
for (int k = 0; k < N ~/ 4; k++) {
final t = twiddle(k, N);
final e = evenFFT[k];
final o = oddFFT[k];
final o1 = oddFFT[N ~/ 4 - k - 1].conjugate();
result[k] = e + t * o;
result[k + N ~/ 4] = e - t * o;
result[k + N ~/ 2] = e + t * o1;
result[k + 3 * N ~/ 4] = e - t * o1;
}

return result;
}

/// Computes the Fast Fourier Transform of the NDList<double>
List<Complex> _fft(List<double> x, {bool isReal = true}) {
int N = x.length;

final z = List<Complex>.generate(N, (index) => Complex(x[index], 0));
if (N <= 1) return z;

// Cooley-Tukey FFT algorithm optimized for real input
List<Complex> even =
_fft([for (int i = 0; i < N ~/ 2; i++) x[2 * i]], isReal: false);
List<Complex> odd =
_fft([for (int i = 0; i < N ~/ 2; i++) x[2 * i + 1]], isReal: false);

List<Complex> result = List<Complex>.filled(N, Complex.zero);
for (int k = 0; k < N ~/ 2; k++) {
Complex t = Complex.polar(1.0, -2 * pi * k / N) * odd[k];
result[k] = even[k] + t;
result[k + N ~/ 2] = even[k] - t;
}

return result;
}

NDList<Complex> fft() {
var complexOutput = splitRadixFFT(list);

return NDList.from(complexOutput);
}

NDList<double> spectrogram(int nFFT, {int hopLength = 1}) {
if (!is1D) {
return rolling(nFFT, axis: -1)
.reduce((a) => a.spectrogram(nFFT, hopLength: hopLength))
.cemented();
}

return reshape([-1])
.rolling(nFFT, step: hopLength, axis: 0)
.reduce((e) => e.fft().map((e) => e.abs() * e.abs()))
.cemented();
}
}
8 changes: 8 additions & 0 deletions pubspec.lock
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ packages:
url: "https://pub.dev"
source: hosted
version: "1.18.0"
complex:
dependency: "direct main"
description:
name: complex
sha256: dba084899c0a4bd2fcba9a36760409171d7bee7c35a749cc4451348270361325
url: "https://pub.dev"
source: hosted
version: "0.7.2"
convert:
dependency: transitive
description:
Expand Down
2 changes: 2 additions & 0 deletions pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ environment:
dev_dependencies:
lints: ^4.0.0
test: ^1.16.0
dependencies:
complex: ^0.7.2
Loading