-
Notifications
You must be signed in to change notification settings - Fork 4
/
util_cochlear_model.py
1378 lines (1221 loc) · 72.2 KB
/
util_cochlear_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
This script is closely based on pycochleagram and tfcochleagram,
which have been previously released:
https://github.com/mcdermottLab/pycochleagram
https://github.com/jenellefeather/tfcochleagram
Minor modifications have been made here to provide a single script
containing all functions needed to build the cochlear model used
in this project.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import warnings
import functools
import numpy as np
import tensorflow as tf
import scipy.signal as signal
import matplotlib.pyplot as plt
def freq2erb(freq_hz):
    """Convert frequency in Hz to the human ERB-number scale.

    Uses the formula of Glasberg and Moore.

    Args:
        freq_hz (array_like): Frequency value(s) in Hz.

    Returns:
        ndarray: **n_erb** -- Position(s) of the input on the ERB scale.
    """
    # 24.7 * 9.265 is the Glasberg & Moore constant (in Hz) that sets the
    # transition between the linear and logarithmic regimes of the scale.
    erb_break_hz = 24.7 * 9.265
    return 9.265 * np.log(1 + freq_hz / erb_break_hz)
def erb2freq(n_erb):
    """Convert human ERB-number values back to frequency in Hz.

    Inverse of ``freq2erb``, using the formula of Glasberg and Moore.

    Args:
        n_erb (array_like): ERB-scale value(s) to convert.

    Returns:
        ndarray: **freq_hz** -- Frequency representation of the input, in Hz.
    """
    # Same Glasberg & Moore constant as in freq2erb; exp inverts the log.
    erb_break_hz = 24.7 * 9.265
    return erb_break_hz * (np.exp(n_erb / 9.265) - 1)
def get_freq_rand_conversions(xp, seed=0, minval=0.0, maxval=1.0):
    """Generates freq2rand and rand2freq conversion functions.

    Args:
        xp (array_like): xvals for freq2rand linear interpolation.
        seed (int): seed used to generate yvals for linear interpolation.
        minval (float): yvals for linear interpolation are scaled to [minval, maxval].
        maxval (float): yvals for linear interpolation are scaled to [minval, maxval].

    Returns:
        freq2rand (function): converts Hz to random frequency scale
        rand2freq (function): converts random frequency scale to Hz
    """
    # Use a local RandomState rather than np.random.seed so the *global*
    # NumPy RNG state is left untouched by this call. The legacy RandomState
    # stream is bit-identical to `np.random.seed(seed); np.random.poisson(...)`,
    # so generated filterbanks are unchanged.
    rng = np.random.RandomState(seed)
    yp = np.cumsum(rng.poisson(size=xp.shape))
    # Rescale the cumulative counts to [minval, maxval]. The cumsum of
    # non-negative draws is monotonic (non-decreasing), so the two interps
    # below form a (pseudo-)inverse pair; ties from zero draws make the
    # inverse approximate at those points.
    # NOTE(review): if every draw is zero, yp.max() == yp.min() and this
    # divides by zero -- assumed not to occur for the seeds used in practice.
    yp = ((maxval - minval) * (yp - yp.min())) / (yp.max() - yp.min()) + minval
    freq2rand = lambda x: np.interp(x, xp, yp)
    rand2freq = lambda y: np.interp(y, yp, xp)
    return freq2rand, rand2freq
def make_cosine_filter(freqs, l, h, convert_to_erb=True):
    """Generate a single half-cosine filter (one cochleagram subband).

    The filter is evaluated only at the elements of ``freqs`` lying strictly
    between the cutoffs ``l`` and ``h``, and is centered at the midpoint of
    that interval. Values of ``freqs`` outside (l, h) are discarded: e.g. for
    freqs = [1, 2, ..., 10], l = 4.5, h = 8, the filter is defined on
    [5, 6, 7] and the returned array has 3 elements.

    Args:
        freqs (array_like): Domain of the filter, in ERB space (see
            ``convert_to_erb``). Only the values strictly between ``l`` and
            ``h`` contribute to the output.
        l (float): Lower cutoff of the half-cosine filter (ERB space unless
            ``convert_to_erb`` is True).
        h (float): Upper cutoff of the half-cosine filter (ERB space unless
            ``convert_to_erb`` is True).
        convert_to_erb (bool, default=True): When True, ``freqs``, ``l`` and
            ``h`` are converted from Hz to ERB space before the filter is
            built; when False they are assumed to already be in ERB space.

    Returns:
        ndarray: **half_cos_filter** -- Half-cosine filter values at the
        in-range elements of ``freqs``.
    """
    if convert_to_erb:
        domain = freq2erb(freqs)
        lo = freq2erb(l)
        hi = freq2erb(h)
    else:
        domain = freqs
        lo = l
        hi = h
    center = (lo + hi) / 2  # filter peaks at the center of the interval
    width = hi - lo         # full width of the filter
    in_range = domain[(domain > lo) & (domain < hi)]
    # Map the open interval (lo, hi) onto (-pi/2, pi/2): the response peaks
    # at the center frequency and falls to zero at the cutoffs.
    return np.cos((in_range - center) / width * np.pi)
def make_full_filter_set(filts, signal_length=None):
    """Extend a positive-frequency filterbank to the negative FFT frequencies.

    Args:
        filts (array_like): Cochlear filterbank in frequency space, i.e. the
            output of make_cos_filters_nx. Each ROW is one filter; COLUMNS
            index (positive) frequency.
        signal_length (int, optional): Length of the signal to be filtered
            with this filterbank. Should equal 2 * filts.shape[1] - 1, which
            is what is computed when this argument is None.

    Returns:
        ndarray: **full_filter_set** -- The complete filterbank in frequency
        space, transposed so it can be applied directly to the FFT of a
        signal (frequency along rows).
    """
    if signal_length is None:
        signal_length = 2 * filts.shape[1] - 1
    # Mirror the positive-frequency rows onto the negative frequencies.
    # Even-length signal: the DC row and the (unique) Nyquist row must not be
    # duplicated; odd-length signal: only the DC row is excluded.
    if signal_length % 2 == 0:
        mirrored = np.flipud(filts[1:-1, :])
    else:
        mirrored = np.flipud(filts[1:, :])
    stacked = np.vstack((filts, mirrored))
    # Transpose so the representation matches the FFT of the signal
    # (frequencies along the first axis).
    return stacked.T
def make_cos_filters_nx(signal_length, sr, n, low_lim, hi_lim, sample_factor,
                        padding_size=None, full_filter=True, strict=True,
                        bandwidth_scale_factor=1.0, include_lowpass=True,
                        include_highpass=True, filter_spacing='erb'):
    """Create cosine filters, oversampled by a factor provided by "sample_factor"

    Args:
        signal_length (int): Length of signal to be filtered with the generated
            filterbank. The signal length determines the length of the filters.
        sr (int): Sampling rate associated with the signal waveform.
        n (int): Number of filters (subbands) to be generated with standard
            sampling (i.e., using a sampling factor of 1). Note, the actual number of
            filters in the generated filterbank depends on the sampling factor, and
            may optionally include lowpass and highpass filters that allow for
            perfect reconstruction of the input signal (the exact number of lowpass
            and highpass filters is determined by the sampling factor). The
            number of filters in the generated filterbank is given below:

            +---------------+---------------+-+------------+---+---------------------+
            | sample factor | n_out         |=| bandpass   |\ +| highpass + lowpass  |
            +===============+===============+=+============+===+=====================+
            | 1             | n+2           |=| n          |\ +| 1 + 1               |
            +---------------+---------------+-+------------+---+---------------------+
            | 2             | 2*n+1+4       |=| 2*n+1      |\ +| 2 + 2               |
            +---------------+---------------+-+------------+---+---------------------+
            | 4             | 4*n+3+8       |=| 4*n+3      |\ +| 4 + 4               |
            +---------------+---------------+-+------------+---+---------------------+
            | s             | s*(n+1)-1+2*s |=| s*(n+1)-1  |\ +| s + s               |
            +---------------+---------------+-+------------+---+---------------------+
        low_lim (int): Lower limit of frequency range. Filters will not be defined
            below this limit.
        hi_lim (int): Upper limit of frequency range. Filters will not be defined
            above this limit.
        sample_factor (int): Positive integer that determines how densely ERB function
            will be sampled to create bandpass filters. 1 represents standard sampling;
            adjacent bandpass filters will overlap by 50%. 2 represents 2x overcomplete sampling;
            adjacent bandpass filters will overlap by 75%. 4 represents 4x overcomplete sampling;
            adjacent bandpass filters will overlap by 87.5%.
        padding_size (int, optional): If None (default), the signal will not be padded
            before filtering. Otherwise, the filters will be created assuming the
            waveform signal will be padded to length padding_size*signal_length.
        full_filter (bool, default=True): If True (default), the complete filter that
            is ready to apply to the signal is returned. If False, only the first
            half of the filter is returned (likely positive terms of FFT).
        strict (bool, default=True): If True (default), will throw an error if
            sample_factor is not a power of two. This facilitates comparison across
            sample_factors. Also, if True, will throw an error if provided hi_lim
            is greater than the Nyquist rate.
        bandwidth_scale_factor (float, default=1.0): scales the bandpass filter bandwidths.
            bandwidth_scale_factor=2.0 means half-cosine filters will be twice as wide.
            Note that values < 1 will cause frequency gaps between the filters.
            bandwidth_scale_factor requires sample_factor=1, include_lowpass=False, include_highpass=False.
        include_lowpass (bool, default=True): if set to False, lowpass filter will be discarded.
        include_highpass (bool, default=True): if set to False, highpass filter will be discarded.
        filter_spacing (str, default='erb'): Specifies the type of reference spacing for the
            half-cosine filters. Options include 'erb' and 'linear'.

    Returns:
        tuple:
        A tuple containing the output:

        * **filts** (*array*)-- The filterbank consisting of filters have
            cosine-shaped frequency responses, with center frequencies equally
            spaced from low_lim to hi_lim on a scale specified by filter_spacing
        * **center_freqs** (*array*) -- center frequencies of filterbank in filts
        * **freqs** (*array*) -- freq vector in Hz, same frequency dimension as filts

    Raises:
        ValueError: Various value errors for bad choices of sample_factor or frequency
            limits; see description for strict parameter.
        UserWarning: Raises warning if cochlear filters exceed the Nyquist
            limit or go below 0.
        NotImplementedError: Raises error if specified filter_spacing is not implemented
    """
    # Select the pair of conversion functions (_freq2ref, _ref2freq) that map
    # between Hz and the reference spacing scale on which cutoffs are linear.
    if filter_spacing == 'erb':
        _freq2ref = freq2erb
        _ref2freq = erb2freq
    elif filter_spacing == 'erb_r':
        # Reversed ERB spacing: mirror the ERB scale about hi_lim.
        _freq2ref = lambda x: freq2erb(hi_lim) - freq2erb(hi_lim - x)
        _ref2freq = lambda x: hi_lim - erb2freq(freq2erb(hi_lim) - x)
    elif (filter_spacing == 'lin') or (filter_spacing == 'linear'):
        _freq2ref = lambda x: x
        _ref2freq = lambda x: x
    elif 'random' in filter_spacing:
        # Spacing mode strings look like 'random-seed<INT>'; the seed is
        # parsed from the portion after the first '-'.
        _freq2ref, _ref2freq = get_freq_rand_conversions(
            np.linspace(low_lim, hi_lim, n),
            seed=int(filter_spacing.split('-')[1].replace('seed', '')),
            minval=freq2erb(low_lim),
            maxval=freq2erb(hi_lim))
    else:
        raise NotImplementedError('unrecognized spacing mode: %s' % filter_spacing)
    print('[make_cos_filters_nx] using filter_spacing=`{}`'.format(filter_spacing))
    # bandwidth_scale_factor changes filter widths, which is only supported
    # for the simplest configuration (no oversampling, no lowpass/highpass).
    if not bandwidth_scale_factor == 1.0:
        assert sample_factor == 1, "bandwidth_scale_factor only supports sample_factor=1"
        assert include_lowpass == False, "bandwidth_scale_factor only supports include_lowpass=False"
        assert include_highpass == False, "bandwidth_scale_factor only supports include_highpass=False"
    if not isinstance(sample_factor, int):
        raise ValueError('sample_factor must be an integer, not %s' % type(sample_factor))
    if sample_factor <= 0:
        raise ValueError('sample_factor must be positive')
    if sample_factor != 1 and np.remainder(sample_factor, 2) != 0:
        msg = 'sample_factor odd, and will change filter widths. Use even sample factors for comparison.'
        if strict:
            raise ValueError(msg)
        else:
            warnings.warn(msg, RuntimeWarning, stacklevel=2)
    if padding_size is not None and padding_size >= 1:
        signal_length += padding_size
    # The filters are defined on the positive-frequency half of the FFT grid.
    if np.remainder(signal_length, 2) == 0:  # even length
        n_freqs = signal_length // 2  # .0 does not include DC, likely the sampling grid
        max_freq = sr / 2  # go all the way to nyquist
    else:  # odd length
        n_freqs = (signal_length - 1) // 2  # .0
        max_freq = sr * (signal_length - 1) / 2 / signal_length  # just under nyquist
    # verify the high limit is allowed by the sampling rate
    if hi_lim > sr / 2:
        hi_lim = max_freq
        msg = 'input arg "hi_lim" exceeds nyquist limit for max frequency; ignore with "strict=False"'
        if strict:
            raise ValueError(msg)
        else:
            warnings.warn(msg, RuntimeWarning, stacklevel=2)
    # changing the sampling density without changing the filter locations
    # (and, thereby changing their widths) requires that a certain number of filters
    # be used.
    n_filters = sample_factor * (n + 1) - 1
    n_lp_hp = 2 * sample_factor
    freqs = np.linspace(0, max_freq, n_freqs + 1)
    # One column per filter (bandpass + lowpass + highpass); rows index freq.
    filts = np.zeros((n_freqs + 1, n_filters + n_lp_hp))
    # cutoffs are evenly spaced on the scale specified by filter_spacing; for ERB scale,
    # interpolate linearly in erb space then convert back.
    # Also return the actual spacing used to generate the sequence (in case numpy does
    # something weird)
    center_freqs, step_spacing = np.linspace(_freq2ref(low_lim), _freq2ref(hi_lim), n_filters + 2, retstep=True)  # +2 for bin endpoints
    # we need to exclude the endpoints
    center_freqs = center_freqs[1:-1]
    freqs_ref = _freq2ref(freqs)
    for i in range(n_filters):
        i_offset = i + sample_factor
        # Cutoffs are sample_factor steps on either side of the center so that
        # adjacent filters overlap appropriately; bandwidth_scale_factor
        # widens/narrows the half-cosine.
        l = center_freqs[i] - sample_factor * bandwidth_scale_factor * step_spacing
        h = center_freqs[i] + sample_factor * bandwidth_scale_factor * step_spacing
        if _ref2freq(h) > sr/2:
            cf = _ref2freq(center_freqs[i])
            msg = "High ERB cutoff of filter with cf={:.2f}Hz exceeds {:.2f}Hz (Nyquist frequency)"
            warnings.warn(msg.format(cf, sr/2))
        if _ref2freq(l) < 0:
            cf = _ref2freq(center_freqs[i])
            msg = 'Low ERB cutoff of filter with cf={:.2f}Hz is not strictly positive'
            warnings.warn(msg.format(cf))
        # the first sample_factor # of rows in filts will be lowpass filters
        filts[(freqs_ref > l) & (freqs_ref < h), i_offset] = make_cosine_filter(freqs_ref, l, h, convert_to_erb=False)
    # add lowpass and highpass filters (there will be sample_factor number of each)
    for i in range(sample_factor):
        # account for the fact that the first sample_factor # of filts are lowpass
        i_offset = i + sample_factor
        lp_h_ind = max(np.where(freqs < _ref2freq(center_freqs[i]))[0])  # lowpass filter goes up to peak of first cos filter
        # sqrt(1 - cos^2) makes the squared responses of the lowpass filter and
        # the first bandpass filter sum to one (perfect reconstruction).
        lp_filt = np.sqrt(1 - np.power(filts[:lp_h_ind+1, i_offset], 2))
        hp_l_ind = min(np.where(freqs > _ref2freq(center_freqs[-1-i]))[0])  # highpass filter goes down to peak of last cos filter
        hp_filt = np.sqrt(1 - np.power(filts[hp_l_ind:, -1-i_offset], 2))
        filts[:lp_h_ind+1, i] = lp_filt
        filts[hp_l_ind:, -1-i] = hp_filt
    # get center freqs for lowpass and highpass filters
    cfs_low = np.copy(center_freqs[:sample_factor]) - sample_factor * step_spacing
    cfs_hi = np.copy(center_freqs[-sample_factor:]) + sample_factor * step_spacing
    center_freqs = np.concatenate((cfs_low, center_freqs, cfs_hi))
    # ensure that squared freq response adds to one
    filts = filts / np.sqrt(sample_factor)
    # convert center freqs from ERB numbers to Hz
    center_freqs = _ref2freq(center_freqs)
    # rectify: center freqs that mapped below 0 Hz are clamped to 1 Hz
    center_freqs[center_freqs < 0] = 1
    # discard highpass and lowpass filters, if requested
    if include_lowpass == False:
        filts = filts[:, sample_factor:]
        center_freqs = center_freqs[sample_factor:]
    if include_highpass == False:
        filts = filts[:, :-sample_factor]
        center_freqs = center_freqs[:-sample_factor]
    # make the full filter by adding negative components
    if full_filter:
        filts = make_full_filter_set(filts, signal_length)
    return filts, center_freqs, freqs
def tflog10(x):
    """Compute log base 10 of a tensor via the change-of-base formula."""
    # tf only provides the natural log, so divide ln(x) by ln(10).
    ln_x = tf.log(x)
    ln_10 = tf.log(tf.constant(10, dtype=ln_x.dtype))
    return ln_x / ln_10
@tf.custom_gradient
def stable_power_compression_norm_grad(x):
    """Power compression (relu(x)^0.3) whose gradient is the identity.

    The compression nonlinearity is applied in the forward pass only; during
    backprop the incoming gradient is passed through unchanged, so the
    (potentially unstable) derivative of the power function never enters the
    gradient computation.
    """
    rectified = tf.nn.relu(x)  # rectify first so the forward pass avoids NaN
    compressed = tf.pow(rectified, 0.3)

    def grad(dy):
        # Identity gradient: skip the power compression's derivative entirely.
        return dy

    return compressed, grad
@tf.custom_gradient
def stable_power_compression(x):
    """Power compression (relu(x)^0.3) with a stabilized gradient.

    The backward pass clips the power-function gradient to (-1, 1) -- any
    cochleagram value below ~0.2 would otherwise produce a gradient larger
    than 1 -- and replaces any NaN gradient values with 1 as a safeguard.
    """
    e = tf.nn.relu(x)  # rectify to avoid NaN in the forward pass
    p = tf.pow(e, 0.3)

    def grad(dy):
        # d/dx x^0.3 = 0.3 * x^(-0.7). Use tf.pow (not the Python builtin)
        # for consistency with the forward pass.
        g = 0.3 * tf.pow(e, -0.7)
        is_nan_values = tf.is_nan(g)
        # Replacement value of 1 for NaN entries (dropped the redundant `*1`).
        replace_nan_values = tf.ones(tf.shape(g), dtype=tf.float32)
        return dy * tf.where(is_nan_values, replace_nan_values, tf.clip_by_value(g, -1, 1))

    return p, grad
def cochleagram_graph(nets, SIGNAL_SIZE, SR, ENV_SR=200, LOW_LIM=20, HIGH_LIM=8000, N=40, SAMPLE_FACTOR=4, compression='none', WINDOW_SIZE=1001, debug=False, subbands_ifft=False, pycoch_downsamp=False, linear_max=796.87416837456942, input_node='input_signal', mean_subtract=False, rms_normalize=False, SMOOTH_ABS = False, return_subbands_only=False, include_all_keys=False, rectify_and_lowpass_subbands=False, pad_factor=None, return_coch_params=False, rFFT=False, linear_params=None, custom_filts=None, custom_compression_op=None, erb_filter_kwargs={}, reshape_kell2018=False, include_subbands_noise=False, subbands_noise_mean=0., subbands_noise_stddev=0., rate_level_kwargs={}, preprocess_kwargs={}):
    """
    Creates a tensorflow cochleagram graph using the pycochleagram erb filters to create the cochleagram with the tensorflow functions.

    Parameters
    ----------
    nets : dictionary
        dictionary containing parts of the cochleagram graph. At a minumum, nets['input_signal'] (or equivilant) should be defined containing a placeholder (if just constructing cochleagrams) or a variable (if optimizing over the cochleagrams), and can have a batch size>1.
    SIGNAL_SIZE : int
        the length of the audio signal used for the cochleagram graph
    SR : int
        raw sampling rate in Hz for the audio.
    ENV_SR : int
        the sampling rate for the cochleagram after downsampling
    LOW_LIM : int
        Lower frequency limits for the filters.
    HIGH_LIM : int
        Higher frequency limits for the filters.
    N : int
        Number of filters to uniquely span the frequency space
    SAMPLE_FACTOR : int
        number of times to overcomplete the filters.
    compression : string. see include_compression for compression options
        determine compression type to use in the cochleagram graph. If return_subbands is true, compress the rectified subbands
    WINDOW_SIZE : int
        the size of a window to use for the downsampling filter
    debug : boolean
        Adds more nodes to the graph for explicitly defining the real and imaginary parts of the signal when set to True (default False).
    subbands_ifft : boolean
        If true, adds the ifft of the subbands to nets
    input_node : string
        Name of the top level of nets, this is the input into the cochleagram graph.
    mean_subtract : boolean
        If true, subtracts the mean of the waveform (explicitly removes the DC offset)
    rms_normalize : Boolean # ONLY USE WHEN GENERATING COCHLEAGRAMS
        If true, divides the input signal by its RMS value, such that the RMS value of the sound going into the cochleagram generation is equal to 1. This option should be false if inverting cochleagrams, as it can cause problems with the gradients
    linear_max : float
        If default value, use 796.87416837456942, which is the 5th percentile from the speech dataset when it is rms normalized to a value of 1. This value is only used if the compression is 'linearbelow1', 'linearbelow1sqrt', 'stable_point3'
    SMOOTH_ABS : Boolean
        If True, uses a smoother version of the absolute value for the hilbert transform sqrt(10^-3 + real(env) + imag(env))
    return_subbands_only : Boolean
        If True, returns the non-envelope extracted subbands before taking the hilbert envelope as the output node of the graph
    include_all_keys : Boolean
        If True, returns all of the cochleagram and subbands processing keys in the dictionary
    rectify_and_lowpass_subbands : Boolean
        If True, rectifies and lowpasses the subbands before returning them (only works with return_subbands_only)
    pad_factor : int
        how much padding to add to the signal. Follows conventions of pycochleagram (ie pad of 2 doubles the signal length)
    return_coch_params : Boolean
        If True, returns the cochleagram generation parameters in addition to nets
    rFFT : Boolean
        If True, builds the graph using rFFT and irFFT operations whenever possible
    linear_params : list of floats
        used for the linear compression operation, [m, b] where the output of the compression is y=mx+b. m and b can be vectors of shape [1,num_filts,1] to apply different values to each frequency channel.
    custom_filts : None, or numpy array
        if not None, a numpy array containing the filters to use for the cochleagram generation. If none, uses erb.make_erb_cos_filters from pycochleagram to construct the filterbank. If using rFFT, should contain th full filters, shape [SIGNAL_SIZE, NUMBER_OF_FILTERS]
    custom_compression_op : None or tensorflow partial function
        if specified as a function, applies the tensorflow function as a custom compression operation. Should take the input node and 'name' as the arguments
    erb_filter_kwargs : dictionary
        contains additional arguments with filter parameters to use with erb.make_erb_cos_filters
    reshape_kell2018 : boolean (False)
        if true, reshapes the output cochleagram to be 256x256 as used by kell2018
    include_subbands_noise : boolean (False)
        if include_subbands_noise and return_subbands_only are both true, white noise is added to subbands after compression (this feature is currently only accessible when return_subbands_only == True)
    subbands_noise_mean : float
        sets mean of subbands white noise if include_subbands_noise == True
    subbands_noise_stddev : float
        sets standard deviation of subbands white noise if include_subbands_noise == True
    rate_level_kwargs : dictionary
        contains keyword arguments for AN_rate_level_function (used if compression == 'rate_level')
    preprocess_kwargs : dictionary
        contains keyword arguments for preprocess_input function (used to randomize input dB SPL)

    Returns
    -------
    nets : dictionary
        a dictionary containing the parts of the cochleagram graph. Top node in this graph is nets['output_tfcoch_graph']
    COCH_PARAMS : dictionary (Optional)
        a dictionary containing all of the input parameters into the function
    """
    if return_coch_params:
        # locals() must be captured before any new local variables are bound,
        # so that COCH_PARAMS contains exactly the function's input arguments.
        COCH_PARAMS = locals()
        COCH_PARAMS.pop('nets')
    # run preprocessing operations on the input (ie rms normalization, convert to complex)
    nets = preprocess_input(nets, SIGNAL_SIZE, input_node, mean_subtract, rms_normalize, rFFT, **preprocess_kwargs)
    # fft of the input
    nets = fft_of_input(nets, pad_factor,debug, rFFT)
    # Make a wrapper for the compression function so it can be applied to the cochleagram and the subbands
    compression_function = functools.partial(include_compression, compression=compression, linear_max=linear_max, linear_params=linear_params, rate_level_kwargs=rate_level_kwargs, custom_compression_op=custom_compression_op)
    # make cochlear filters and compute the cochlear subbands
    nets = extract_cochlear_subbands(nets, SIGNAL_SIZE, SR, LOW_LIM, HIGH_LIM, N, SAMPLE_FACTOR, pad_factor, debug, subbands_ifft, return_subbands_only, rectify_and_lowpass_subbands, rFFT, custom_filts, erb_filter_kwargs, include_all_keys, compression_function, include_subbands_noise, subbands_noise_mean, subbands_noise_stddev)
    # Build the rest of the graph for the downsampled cochleagram, if we are returning the cochleagram or if we want to build the whole graph anyway.
    if (not return_subbands_only) or include_all_keys:
        # hilbert transform on subband fft
        nets = hilbert_transform_from_fft(nets, SR, SIGNAL_SIZE, pad_factor, debug, rFFT)
        # absolute value of the envelopes (and expand to one channel)
        nets = abs_envelopes(nets, SMOOTH_ABS)
        # downsample and rectified nonlinearity
        nets = downsample_and_rectify(nets, SR, ENV_SR, WINDOW_SIZE, pycoch_downsamp)
        # compress cochleagram
        nets = compression_function(nets, input_node_name='cochleagram_no_compression', output_node_name='cochleagram')
    # NOTE(review): reshape_kell2018 presumes the 'cochleagram' node exists,
    # i.e. the branch above ran -- confirm callers never combine it with
    # return_subbands_only=True and include_all_keys=False.
    if reshape_kell2018:
        nets, output_node_name_coch = reshape_coch_kell_2018(nets)
    else:
        output_node_name_coch = 'cochleagram'
    if return_subbands_only:
        nets['output_tfcoch_graph'] = nets['subbands_time_processed']
    else:
        nets['output_tfcoch_graph'] = nets[output_node_name_coch]
    # return
    if return_coch_params:
        return nets, COCH_PARAMS
    else:
        return nets
def preprocess_input(nets, SIGNAL_SIZE, input_node, mean_subtract, rms_normalize, rFFT,
                     set_dBSPL=False, dBSPL_range=(60., 60.)):
    """
    Does preprocessing on the input (rms and converting to complex number)

    Parameters
    ----------
    nets : dictionary
        dictionary containing parts of the cochleagram graph. should already contain input_node
    SIGNAL_SIZE : int
        the length of the audio signal used for the cochleagram graph
        (must be even when rFFT=True)
    input_node : string
        Name of the top level of nets, this is the input into the cochleagram graph.
    mean_subtract : boolean
        If true, subtracts the mean of the waveform (explicitly removes the DC offset)
    rms_normalize : Boolean # TODO: incorporate stable gradient code for RMS
        If true, divides the input signal by its RMS value, such that the RMS value of the sound going
    rFFT : Boolean
        If true, preprocess input for using the rFFT operations
    set_dBSPL : Boolean
        If true, re-scale input waveform to dB SPL sampled uniformly from dBSPL_range
    dBSPL_range : sequence of two floats
        Range of sound presentation levels in units of dB re 20e-6 Pa ([minval, maxval])

    Returns
    -------
    nets : dictionary
        updated dictionary containing parts of the cochleagram graph.

    Raises
    ------
    ValueError
        If rFFT is True and SIGNAL_SIZE is odd. (Previously this case printed
        a message and silently returned None, which crashed later in the
        pipeline with an unrelated error.)
    """
    if rFFT and SIGNAL_SIZE % 2 != 0:
        # Fail fast with a clear error rather than returning None.
        raise ValueError('rFFT is only tested with even length signals. '
                         'Change your input length (got SIGNAL_SIZE=%d).' % SIGNAL_SIZE)
    processed_input_node = input_node
    if mean_subtract:
        # Remove the DC offset from each waveform in the batch.
        processed_input_node = processed_input_node + '_mean_subtract'
        nets[processed_input_node] = nets[input_node] - tf.reshape(tf.reduce_mean(nets[input_node],1),(-1,1))
        input_node = processed_input_node
    if rms_normalize:  # TODO: incoporate stable RMS normalization
        # Scale each waveform so its RMS is 1.
        processed_input_node = processed_input_node + '_rms_normalized'
        nets['rms_input'] = tf.sqrt(tf.reduce_mean(tf.square(nets[input_node]), 1))
        nets[processed_input_node] = tf.identity(nets[input_node]/tf.reshape(nets['rms_input'],(-1,1)),'rms_normalized_input')
        input_node = processed_input_node
    if set_dBSPL:  # NOTE: unstable if RMS of input is zero
        processed_input_node = processed_input_node + '_set_dBSPL'
        assert rms_normalize == False, "rms_normalize must be False if set_dBSPL=True"
        assert len(dBSPL_range) == 2, "dBSPL_range must be specified as [minval, maxval]"
        # Sample one presentation level per batch element, convert it to a
        # target RMS (dB re 20e-6 Pa), and rescale the waveform to match.
        nets['dBSPL_set'] = tf.random.uniform([tf.shape(nets[input_node])[0], 1],
                                              minval=dBSPL_range[0], maxval=dBSPL_range[1],
                                              dtype=nets[input_node].dtype, name='sample_dBSPL_set')
        nets['rms_set'] = 20e-6 * tf.math.pow(10., nets['dBSPL_set'] / 20.)
        nets['rms_input'] = tf.sqrt(tf.reduce_mean(tf.square(nets[input_node]), axis=1, keepdims=True))
        nets[processed_input_node] = tf.math.multiply(nets['rms_set'] / nets['rms_input'], nets[input_node],
                                                      name='scale_input_to_dBSPL_set')
        input_node = processed_input_node
    if not rFFT:
        # Build an explicitly complex signal (zero imaginary part) for tf.fft.
        nets['input_signal_i'] = nets[input_node]*0.0
        nets['input_signal_complex'] = tf.complex(nets[input_node], nets['input_signal_i'], name='input_complex')
    else:
        nets['input_real'] = nets[input_node]
    return nets
def fft_of_input(nets, pad_factor, debug, rFFT):
    """
    Computes the fft of the signal and adds appropriate padding

    Parameters
    ----------
    nets : dictionary
        dictionary containing parts of the cochleagram graph; the fft of the
        preprocessed input is added under the key 'fft_input'
    pad_factor : int
        how much padding to add to the signal. Follows conventions of
        pycochleagram (ie pad of 2 doubles the signal length)
    debug : boolean
        Adds more nodes to the graph for explicitly defining the real and
        imaginary parts of the signal when set to True.
    rFFT : Boolean
        If true, cochleagram graph is constructed using rFFT wherever possible

    Returns
    -------
    nets : dictionary
        updated dictionary containing parts of the cochleagram graph with the
        (r)FFT of the input
    """
    if not rFFT:
        if pad_factor is not None:
            # Zero-pad the complex signal to pad_factor * original length.
            batch_size = nets['input_signal_complex'].get_shape()[0]
            sig_len = nets['input_signal_complex'].get_shape()[1]
            zero_pad = tf.zeros([batch_size, sig_len * (pad_factor - 1)], dtype=tf.complex64)
            nets['input_signal_complex'] = tf.concat([nets['input_signal_complex'], zero_pad], axis=1)
        nets['fft_input'] = tf.fft(nets['input_signal_complex'], name='fft_of_input')
    else:
        # The DFT of a real signal is Hermitian-symmetric, so RFFT returns only
        # the fft_length / 2 + 1 unique components: the zero-frequency term
        # followed by the fft_length / 2 positive-frequency terms.
        nets['fft_input'] = tf.spectral.rfft(nets['input_real'], name='fft_of_input')
        nets['fft_input'] = tf.expand_dims(nets['fft_input'], 1, name='exd_fft_of_input')
    if debug:
        # Expose the real and imaginary parts of the fft as separate nodes.
        nets['fft_input_r'] = tf.real(nets['fft_input'])
        nets['fft_input_i'] = tf.imag(nets['fft_input'])
    return nets
def extract_cochlear_subbands(nets, SIGNAL_SIZE, SR, LOW_LIM, HIGH_LIM, N, SAMPLE_FACTOR, pad_factor, debug, subbands_ifft, return_subbands_only, rectify_and_lowpass_subbands, rFFT, custom_filts, erb_filter_kwargs, include_all_keys, compression_function, include_subbands_noise, subbands_noise_mean, subbands_noise_stddev):
    """
    Computes the cochlear subbands from the fft of the input signal by
    multiplying the spectrum with an ERB filterbank, optionally inverting back
    to the time domain, rectifying/lowpassing, compressing, and adding noise.
    Parameters
    ----------
    nets : dictionary
        dictionary containing parts of the cochleagram graph. 'fft_input' is multiplied by the cochlear filters
    SIGNAL_SIZE : int
        the length of the audio signal used for the cochleagram graph
    SR : int
        raw sampling rate in Hz for the audio.
    LOW_LIM : int
        Lower frequency limits for the filters.
    HIGH_LIM : int
        Higher frequency limits for the filters.
    N : int
        Number of filters to uniquely span the frequency space
    SAMPLE_FACTOR : int
        number of times to overcomplete the filters.
    pad_factor : int
        how much padding to add to the signal. Follows conventions of pycochleagram (ie pad of 2 doubles the signal length)
    debug : boolean
        Adds more nodes to the graph for explicitly defining the real and imaginary parts of the signal
    subbands_ifft : boolean
        If true, adds the ifft of the subbands to nets
    return_subbands_only : Boolean
        If True, returns the non-envelope extracted subbands before taking the hilbert envelope as the output node of the graph
    rectify_and_lowpass_subbands : Boolean
        If True, rectifies and lowpasses the subbands before returning them (only works with return_subbands_only)
    rFFT : Boolean
        If true, cochleagram graph is constructed using rFFT wherever possible
    custom_filts : None, or numpy array
        if not None, a numpy array containing the filters to use for the cochleagram generation. If none, uses erb.make_erb_cos_filters from pycochleagram to construct the filterbank.
    erb_filter_kwargs : dictionary
        contains additional arguments with filter parameters to use with erb.make_erb_cos_filters
    include_all_keys : Boolean
        If True, includes the time subbands and the cochleagram in the dictionary keys
    compression_function : function
        A partial function that takes in nets and the input and output names to apply compression
    include_subbands_noise : boolean (False)
        if include_subbands_noise and return_subbands_only are both true, white noise is added to subbands after compression (this feature is currently only accessible when return_subbands_only == True)
    subbands_noise_mean : float
        sets mean of subbands white noise if include_subbands_noise == True
    subbands_noise_stddev : float
        sets standard deviation of subbands white noise if include_subbands_noise == True
    Returns
    -------
    nets : dictionary
        updated dictionary containing parts of the cochleagram graph.
    """
    # make the erb filters tensor
    nets['filts_tensor'] = make_filts_tensor(SIGNAL_SIZE, SR, LOW_LIM, HIGH_LIM, N, SAMPLE_FACTOR, use_rFFT=rFFT, pad_factor=pad_factor, custom_filts=custom_filts, erb_filter_kwargs=erb_filter_kwargs)
    # make subbands by multiplying filts with fft of input
    nets['subbands'] = tf.multiply(nets['filts_tensor'],nets['fft_input'],name='mul_subbands')
    if debug: # return the real and imaginary parts of the subbands separately -- use if matching to their output
        nets['subbands_r'] = tf.real(nets['subbands'])
        nets['subbands_i'] = tf.imag(nets['subbands'])
    # TODO: the 'subbands_ifft' key is redundant with 'subbands_time' below.
    # make the time subband operations if we are returning the subbands or if we want to include all of the keys in the graph
    if subbands_ifft or return_subbands_only or include_all_keys:
        if not rFFT:
            nets['subbands_ifft'] = tf.real(tf.ifft(nets['subbands'],name='ifft_subbands'),name='ifft_subbands_r')
        else:
            nets['subbands_ifft'] = tf.spectral.irfft(nets['subbands'],name='ifft_subbands')
        if return_subbands_only or include_all_keys:
            nets['subbands_time'] = nets['subbands_ifft']
            if rectify_and_lowpass_subbands: # TODO: the subband operations (relu + hanning pooling with fixed window) are hard coded in?
                nets['subbands_time_relu'] = tf.nn.relu(nets['subbands_time'], name='rectified_subbands')
                nets['subbands_time_lowpassed'] = hanning_pooling_1d_no_depthwise(nets['subbands_time_relu'], downsample=2, length_of_window=2*4, make_plots=False, data_format='NCW', normalize=True, sqrt_window=False)
    # TODO: noise is only added in the case when we are calculating the time subbands, but we might want something similar for the cochleagram
    if return_subbands_only or include_all_keys:
        # Compress subbands if specified and add noise.
        nets = compression_function(nets, input_node_name='subbands_time_lowpassed', output_node_name='subbands_time_lowpassed_compressed')
        if include_subbands_noise:
            nets = add_neural_noise(nets, subbands_noise_mean, subbands_noise_stddev, input_node_name='subbands_time_lowpassed_compressed', output_node_name='subbands_time_lowpassed_compressed_with_noise')
            nets['subbands_time_lowpassed_compressed_with_noise'] = tf.expand_dims(nets['subbands_time_lowpassed_compressed_with_noise'],-1)
            nets['subbands_time_processed'] = nets['subbands_time_lowpassed_compressed_with_noise']
        else:
            nets['subbands_time_lowpassed_compressed'] = tf.expand_dims(nets['subbands_time_lowpassed_compressed'],-1)
            nets['subbands_time_processed'] = nets['subbands_time_lowpassed_compressed']
    return nets
def hilbert_transform_from_fft(nets, SR, SIGNAL_SIZE, pad_factor, debug, rFFT):
    """
    Performs the hilbert transform from the subband FFT -- gets ifft using only the real parts of the signal
    Parameters
    ----------
    nets : dictionary
        dictionary containing parts of the cochleagram graph. 'subbands' are used for the hilbert transform
    SR : int
        raw sampling rate in Hz for the audio.
    SIGNAL_SIZE : int
        the length of the audio signal used for the cochleagram graph
    pad_factor : int
        how much padding to add to the signal. Follows conventions of pycochleagram (ie pad of 2 doubles the signal length)
    debug : boolean
        Adds more nodes to the graph for explicitly defining the real and imaginary parts of the signal when set to True.
    rFFT : Boolean
        If true, cochleagram graph is constructed using rFFT wherever possible
    Returns
    -------
    nets : dictionary
        updated dictionary containing 'envelopes_freq' (analytic signal in the
        frequency domain) and 'envelopes_time' (its complex-valued ifft)
    """
    if not rFFT:
        # make the step tensor for the hilbert transform (only keep the real components)
        if pad_factor is not None:
            freq_signal = np.fft.fftfreq(SIGNAL_SIZE*pad_factor, 1./SR)
        else:
            freq_signal = np.fft.fftfreq(SIGNAL_SIZE, 1./SR)
        nets['step_tensor'] = make_step_tensor(freq_signal)
        # envelopes in frequency domain -- hilbert transform of the subbands
        nets['envelopes_freq'] = tf.multiply(nets['subbands'], nets['step_tensor'], name='env_freq')
    else:
        # The rFFT output already contains only the non-negative frequencies,
        # so zero-padding it back up to the full FFT length acts as the
        # hilbert step function.
        num_filts = nets['filts_tensor'].get_shape().as_list()[1]
        num_batch = tf.shape(nets['subbands'])[0]  # dynamic batch size
        # TODO: this also might be a problem when we have pad_factor > 1
        nets['hilbert_padding'] = tf.zeros([num_batch, num_filts, int(SIGNAL_SIZE/2)-1], tf.complex64)
        nets['envelopes_freq'] = tf.concat([nets['subbands'], nets['hilbert_padding']], 2, name='env_freq')
    if debug:  # return real and imaginary parts separately
        nets['envelopes_freq_r'] = tf.real(nets['envelopes_freq'])
        nets['envelopes_freq_i'] = tf.imag(nets['envelopes_freq'])
    # ifft of the envelopes back into the time domain.
    nets['envelopes_time'] = tf.ifft(nets['envelopes_freq'], name='ifft_envelopes')
    if not rFFT:  # TODO: was this a bug in pycochleagram where the pad factor doesn't actually work?
        if pad_factor is not None:
            nets['envelopes_time'] = nets['envelopes_time'][:, :, :SIGNAL_SIZE]
    if debug:  # return real and imaginary parts separately
        nets['envelopes_time_r'] = tf.real(nets['envelopes_time'])
        nets['envelopes_time_i'] = tf.imag(nets['envelopes_time'])
    return nets
def abs_envelopes(nets, SMOOTH_ABS):
    """
    Absolute value of the envelopes (and expand to one channel), i.e. the
    magnitude of the analytic hilbert signal
    Parameters
    ----------
    nets : dictionary
        dictionary containing the cochleagram graph. The absolute value is applied to 'envelopes_time'
    SMOOTH_ABS : Boolean
        If True, uses a smoother version of the absolute value for the hilbert transform: sqrt(1e-10 + real(env)^2 + imag(env)^2)
    Returns
    -------
    nets : dictionary
        dictionary containing the updated cochleagram graph
    """
    if SMOOTH_ABS:
        # epsilon inside the sqrt keeps the gradient finite at zero magnitude
        nets['envelopes_abs'] = tf.sqrt(1e-10 + tf.square(tf.real(nets['envelopes_time'])) + tf.square(tf.imag(nets['envelopes_time'])))
    else:
        nets['envelopes_abs'] = tf.abs(nets['envelopes_time'], name='complex_abs_envelopes')
    # add a trailing channel axis (NHWC-style) for the downsampling conv2d
    nets['envelopes_abs'] = tf.expand_dims(nets['envelopes_abs'],3, name='exd_abs_real_envelopes')
    return nets
def downsample_and_rectify(nets, SR, ENV_SR, WINDOW_SIZE, pycoch_downsamp):
    """
    Downsamples the cochleagram and then performs rectification on the output (in case the downsampling results in small negative numbers)
    Parameters
    ----------
    nets : dictionary
        dictionary containing the cochleagram graph. Downsampling will be applied to 'envelopes_abs'
    SR : int
        raw sampling rate of the audio signal. Must be an integer multiple of ENV_SR when downsampling.
    ENV_SR : int
        end sampling rate of the envelopes
    WINDOW_SIZE : int
        the size of the downsampling window (should be large enough to go to zero on the edges).
    pycoch_downsamp : Boolean
        if true, uses a slightly different downsampling function
    Returns
    -------
    nets : dictionary
        dictionary containing parts of the cochleagram graph with added nodes for the downsampled subbands
    Raises
    ------
    ValueError
        if downsampling is requested but SR is not an integer multiple of ENV_SR
    """
    if not ENV_SR == SR:
        # tf.nn.conv2d strides must be ints; under python3 `SR/ENV_SR` is a
        # float, so use floor division and verify the rates divide evenly.
        if SR % ENV_SR != 0:
            raise ValueError('SR (%s) must be an integer multiple of ENV_SR (%s)' % (SR, ENV_SR))
        DOWNSAMPLE = SR // ENV_SR
        # make the downsample tensor
        nets['downsample_filt_tensor'] = make_downsample_filt_tensor(SR, ENV_SR, WINDOW_SIZE, pycoch_downsamp=pycoch_downsamp)
        nets['cochleagram_preRELU'] = tf.nn.conv2d(nets['envelopes_abs'], nets['downsample_filt_tensor'], [1, 1, DOWNSAMPLE, 1], 'SAME', name='conv2d_cochleagram_raw')
    else:
        nets['cochleagram_preRELU'] = nets['envelopes_abs']
    # relu in case the sinc filtering produced small negative values
    nets['cochleagram_no_compression'] = tf.nn.relu(nets['cochleagram_preRELU'], name='coch_no_compression')
    return nets
def include_compression(nets, compression='none', linear_max=796.87416837456942, input_node_name='cochleagram_no_compression', output_node_name='cochleagram', linear_params=None, rate_level_kwargs={}, custom_compression_op=None):
    """
    Choose compression operation to use and adds appropriate nodes to nets
    Parameters
    ----------
    nets : dictionary
        dictionary containing parts of the cochleagram graph. Compression will be applied to input_node_name
    compression : string
        type of compression to perform
    linear_max : float
        used for the linearbelow compression operations (compression is linear below a value and compressed above it)
    input_node_name : string
        name in nets to apply the compression
    output_node_name : string
        name in nets that will be used for the following operation (default is cochleagram, but if returning subbands than it can be chaged)
    linear_params : list of floats
        used for the linear compression operation, [m, b] where the output of the compression is y=mx+b. m and b can be vectors of shape [1,num_filts,1] to apply different values to each frequency channel.
    rate_level_kwargs : dictionary
        additional keyword arguments forwarded to AN_rate_level_function for the 'rate_level' compression
    custom_compression_op : None or tensorflow partial function
        if specified as a function, applies the tensorflow function as a custom compression operation. Should take the input node and 'name' as the arguments
    Returns
    -------
    nets : dictionary
        dictionary containing parts of the cochleagram graph with added nodes for the compressed cochleagram
    """
    # compression of the cochleagram
    if compression=='quarter':
        nets[output_node_name] = tf.sqrt(tf.sqrt(nets[input_node_name], name=output_node_name))
    elif compression=='quarter_plus':
        nets[output_node_name] = tf.sqrt(tf.sqrt(nets[input_node_name]+1e-01, name=output_node_name))
    elif compression=='point3':
        nets[output_node_name] = tf.pow(nets[input_node_name],0.3, name=output_node_name)
    elif compression=='stable_point3':
        # 'stable' variants use a custom op whose gradients do not blow up near zero
        nets[output_node_name] = tf.identity(stable_power_compression(nets[input_node_name]*linear_max),name=output_node_name)
    elif compression=='stable_point3_norm_grads':
        nets[output_node_name] = tf.identity(stable_power_compression_norm_grad(nets[input_node_name]*linear_max),name=output_node_name)
    elif compression=='linearbelow1':
        # linear below linear_max*x == 1, power-compressed (x^0.3) above
        nets[output_node_name] = tf.where((nets[input_node_name]*linear_max)<1, nets[input_node_name]*linear_max, tf.pow(nets[input_node_name]*linear_max,0.3), name=output_node_name)
    elif compression=='stable_linearbelow1':
        nets['stable_power_compressed_%s'%output_node_name] = tf.identity(stable_power_compression(nets[input_node_name]*linear_max),name='stable_power_compressed_%s'%output_node_name)
        nets[output_node_name] = tf.where((nets[input_node_name]*linear_max)<1, nets[input_node_name]*linear_max, nets['stable_power_compressed_%s'%output_node_name], name=output_node_name)
    elif compression=='linearbelow1sqrt':
        nets[output_node_name] = tf.where((nets[input_node_name]*linear_max)<1, nets[input_node_name]*linear_max, tf.sqrt(nets[input_node_name]*linear_max), name=output_node_name)
    elif compression=='quarter_clipped':
        nets[output_node_name] = tf.sqrt(tf.sqrt(tf.maximum(nets[input_node_name],1e-01), name=output_node_name))
    elif compression=='none':
        nets[output_node_name] = nets[input_node_name]
    elif compression=='sqrt':
        nets[output_node_name] = tf.sqrt(nets[input_node_name], name=output_node_name)
    elif compression=='dB': # NOTE: this compression does not work well for the backwards pass, results in nans
        # NOTE(review): output is normalized by reduce_max of the input node, so it depends on the batch content
        nets[output_node_name + '_noclipped'] = 20 * tflog10(nets[input_node_name])/tf.reduce_max(nets[input_node_name])
        nets[output_node_name] = tf.maximum(nets[output_node_name + '_noclipped'], -60)
    elif compression=='dB_plus': # NOTE: this compression does not work well for the backwards pass, results in nans
        nets[output_node_name + '_noclipped'] = 20 * tflog10(nets[input_node_name]+1)/tf.reduce_max(nets[input_node_name]+1)
        nets[output_node_name] = tf.maximum(nets[output_node_name + '_noclipped'], -60, name=output_node_name)
    elif compression=='linear':
        assert (type(linear_params)==list) and len(linear_params)==2, "Specifying linear compression but not specifying the compression parameters in linear_params=[m, b]"
        nets[output_node_name] = linear_params[0]*nets[input_node_name] + linear_params[1]
    elif compression=='rate_level':
        nets[output_node_name] = AN_rate_level_function(nets[input_node_name], name=output_node_name, **rate_level_kwargs)
    elif compression=='custom':
        nets[output_node_name] = custom_compression_op(nets[input_node_name], name=output_node_name)
    return nets
def make_step_tensor(freq_signal):
    """
    Make step tensor for calculating the analytic envelopes.
    Parameters
    __________
    freq_signal : array
        numpy array containing the frequencies of the audio signal (as calculated by np.fft.fftfreq).
    Returns
    -------
    step_tensor : tensorflow tensor
        tensorflow tensor with dimensions [1, 1, len(freq_signal)] as a step function where frequencies > 0 are 2, and frequencies <= 0 are 0.
    """
    # Positive frequencies get a factor of 2 -- the analytic signal doubles the
    # positive-frequency components (https://en.wikipedia.org/wiki/Analytic_signal).
    # np.int64 replaces the deprecated np.int alias (removed in numpy >= 1.24).
    step_func = (freq_signal >= 0).astype(np.int64) * 2
    step_func[freq_signal == 0] = 0  # zero out the DC term (shouldn't actually matter)
    step_tensor = tf.constant(step_func, dtype=tf.complex64)
    # add leading batch and filter axes so the step broadcasts over subbands
    step_tensor = tf.expand_dims(step_tensor, 0)
    step_tensor = tf.expand_dims(step_tensor, 1)
    return step_tensor
def make_filts_tensor(SIGNAL_SIZE, SR=16000, LOW_LIM=20, HIGH_LIM=8000, N=40, SAMPLE_FACTOR=4, use_rFFT=False, pad_factor=None, custom_filts=None, erb_filter_kwargs={}):
    """
    Use pycochleagram to make the filters using the specified parameters (make_erb_cos_filters_nx). Then input them into a tensorflow tensor to be used in the tensorflow cochleagram graph.
    Parameters
    ----------
    SIGNAL_SIZE: int
        length of the audio signal to convert, and the size of cochleagram filters to make.
    SR : int
        raw sampling rate in Hz for the audio.
    LOW_LIM : int
        Lower frequency limits for the filters.
    HIGH_LIM : int
        Higher frequency limits for the filters.
    N : int
        Number of filters to uniquely span the frequency space
    SAMPLE_FACTOR : int
        number of times to overcomplete the filters.
    use_rFFT : Boolean
        if True, the only returns the first half of the filters, corresponding to the positive component.
    pad_factor : int or None
        if set, the filters are built for a signal padded to pad_factor*SIGNAL_SIZE (pycochleagram convention)
    custom_filts : None, or numpy array
        if not None, a numpy array containing the filters to use for the cochleagram generation. If none, uses erb.make_erb_cos_filters from pycochleagram to construct the filterbank. Should contain the full filters with frequency along axis 1, shape [NUMBER_OF_FILTERS, SIGNAL_SIZE] (the assert below checks shape[1]==SIGNAL_SIZE).
    erb_filter_kwargs : dictionary
        contains additional arguments with filter parameters to use with erb.make_erb_cos_filters
    Returns
    -------
    filts_tensor : tensorflow tensor, complex
        tensorflow tensor with a leading batch axis containing the erb filters created from make_erb_cos_filters_nx in pycochleagram (half spectrum only if use_rFFT)
    """
    if pad_factor:
        padding_size = (pad_factor-1)*SIGNAL_SIZE
    else:
        padding_size=None
    if custom_filts is None:
        # make the filters
        filts, hz_cutoffs, freqs = make_erb_cos_filters_nx(SIGNAL_SIZE, SR, N, LOW_LIM, HIGH_LIM, SAMPLE_FACTOR, padding_size=padding_size, **erb_filter_kwargs) #TODO: decide if we want to change the pad_factor and full_filter arguments.
    else: # TODO: ADD CHECKS TO MAKE SURE THAT THESE MATCH UP WITH THE INPUT SIGNAL
        assert custom_filts.shape[1] == SIGNAL_SIZE, "CUSTOM FILTER SHAPE DOES NOT MATCH THE INPUT AUDIO SHAPE"
        filts = custom_filts
    if not use_rFFT:
        filts_tensor = tf.constant(filts, tf.complex64)
    else: # TODO I believe that this is where the padd factor problem comes in! We are only using part of the signal here.
        # keep only the non-negative frequency half (SIGNAL_SIZE/2 + 1 bins) to match tf.spectral.rfft output
        filts_tensor = tf.constant(filts[:,0:(int(SIGNAL_SIZE/2)+1)], tf.complex64)
    filts_tensor = tf.expand_dims(filts_tensor, 0)
    return filts_tensor
def make_downsample_filt_tensor(SR=16000, ENV_SR=200, WINDOW_SIZE=1001, pycoch_downsamp=False):
    """
    Make the sinc filter that will be used to downsample the cochleagram
    Parameters
    ----------
    SR : int
        raw sampling rate of the audio signal
    ENV_SR : int
        end sampling rate of the envelopes
    WINDOW_SIZE : int
        the size of the downsampling window (should be large enough to go to zero on the edges).
    pycoch_downsamp : Boolean
        if true, uses a slightly different downsampling function
    Returns
    -------
    downsample_filt_tensor : tensorflow tensor, tf.float32
        a tensor of shape [1, filter_length, 1, 1] containing the kaiser-windowed sinc lowpass filter that is applied while downsampling the cochleagram
    """
    DOWNSAMPLE = SR/ENV_SR
    if not pycoch_downsamp:
        # kaiser-windowed sinc lowpass at the new Nyquist, scaled by the rate change
        downsample_filter_times = np.arange(-WINDOW_SIZE/2,int(WINDOW_SIZE/2))
        downsample_filter_response_orig = np.sinc(downsample_filter_times/DOWNSAMPLE)/DOWNSAMPLE
        # signal.windows.kaiser: the top-level signal.kaiser alias was removed in scipy 1.13
        downsample_filter_window = signal.windows.kaiser(WINDOW_SIZE, 5)
        downsample_filter_response = downsample_filter_window * downsample_filter_response_orig
    else:
        # mirrors the filter design used by scipy.signal.resample_poly
        max_rate = DOWNSAMPLE
        f_c = 1. / max_rate  # cutoff of FIR filter (rel. to Nyquist)
        # int() because firwin's tap count and the index below must be integers
        # (DOWNSAMPLE is a float under python3 true division)
        half_len = int(10 * max_rate)  # reasonable cutoff for our sinc-like function
        if max_rate != 1:
            downsample_filter_response = signal.firwin(2 * half_len + 1, f_c, window=('kaiser', 5.0))
        else:  # just in case we aren't downsampling -- pass the signal through with a unit impulse
            # np.zeros: the original bare `zeros` was a NameError
            downsample_filter_response = np.zeros(2 * half_len + 1)
            # NOTE(review): the center tap of a (2*half_len+1)-tap filter is index
            # half_len; half_len+1 shifts the output by one sample -- preserved
            # from the original code, confirm intent.
            downsample_filter_response[half_len + 1] = 1
        # NOTE: scipy.signal.resample_poly additionally zero-pads the filter to
        # center the output samples; rarely needed for our filter lengths.
    downsample_filt_tensor = tf.constant(downsample_filter_response, tf.float32)
    # reshape to conv2d filter layout [1, filter_length, in_channels=1, out_channels=1]
    downsample_filt_tensor = tf.expand_dims(downsample_filt_tensor, 0)
    downsample_filt_tensor = tf.expand_dims(downsample_filt_tensor, 2)
    downsample_filt_tensor = tf.expand_dims(downsample_filt_tensor, 3)
    return downsample_filt_tensor
def add_neural_noise(nets, subbands_noise_mean, subbands_noise_stddev, input_node_name='subbands_time_lowpassed_compressed', output_node_name='subbands_time_lowpassed_compressed_with_noise'):
    """
    Adds Gaussian white noise to a node of the cochleagram graph and rectifies the result.
    Parameters
    ----------
    nets : dictionary
        dictionary containing parts of the cochleagram graph. Noise is added to nets[input_node_name]
    subbands_noise_mean : float
        mean of the Gaussian noise
    subbands_noise_stddev : float
        standard deviation of the Gaussian noise
    input_node_name : string
        name in nets of the node the noise is added to
    output_node_name : string
        name in nets for the noisy, rectified output node
    Returns
    -------
    nets : dictionary
        updated dictionary with 'neural_noise' and the output node added
    """
    # Add white noise variable with the same size to the rectified and compressed subbands
    nets['neural_noise'] = tf.random.normal(tf.shape(nets[input_node_name]), mean=subbands_noise_mean,
                                          stddev=subbands_noise_stddev, dtype=nets[input_node_name].dtype)
    # relu keeps the noisy responses non-negative
    nets[output_node_name] = tf.nn.relu(tf.math.add(nets[input_node_name], nets['neural_noise']))
    return nets