Terribly slow training caused by ComplexBatchNormalization() #30
Comments
Indeed, the complex batch norm is not optimized, and optimizing it is not planned in the short term. I am sorry for the trouble caused. The reason is similar to what happens with ComplexPyTorch.
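For context on why the complex formulation is more expensive: the complex batch norm of Trabelsi et al. does not just subtract a mean and divide by a standard deviation; it whitens each channel by the inverse square root of the 2x2 covariance matrix of the real and imaginary parts. A minimal NumPy sketch of that whitening step (illustrative only, treating the whole batch as one channel; this is my own simplification, not the cvnn implementation):

```python
import numpy as np

def complex_whiten(z: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Whiten complex values with the 2x2 covariance of (real, imag),
    as in the complex BN of Trabelsi et al. (single-channel sketch,
    no learned scale/shift, no moving averages)."""
    x = np.stack([z.real, z.imag], axis=-1)      # shape (..., 2)
    x = x - x.mean(axis=0)                       # center both parts
    flat = x.reshape(-1, 2)
    cov = flat.T @ flat / flat.shape[0] + eps * np.eye(2)
    # Inverse square root of the 2x2 covariance via eigendecomposition.
    w, v = np.linalg.eigh(cov)
    inv_sqrt = v @ np.diag(1.0 / np.sqrt(w)) @ v.T
    y = (flat @ inv_sqrt).reshape(x.shape)
    return y[..., 0] + 1j * y[..., 1]

z = np.random.randn(1024) + 1j * np.random.randn(1024) * 3.0
out = complex_whiten(z)
```

The matrix inverse square root per channel (plus the moving averages of the full covariance at training time) is the part that makes this noticeably slower than an ordinary BatchNormalization.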
I was having the same problem and came up with this simple solution. According to the authors of ComplexPyTorch, performing batch normalization in a 'naive' way, i.e. separately on the real and imaginary parts, does not have a significant impact when compared to the complex formulation of Trabelsi et al. Here's a TF version of their NaiveComplexBatchNorm layer, which can be used with the Keras functional API:

import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization

def naive_complex_batch_normalization(inputs: tf.Tensor) -> tf.Tensor:
    # Normalize the real and imaginary parts independently.
    real = tf.cast(tf.math.real(inputs), tf.float32)
    imag = tf.cast(tf.math.imag(inputs), tf.float32)
    real_bn = BatchNormalization()(real)
    imag_bn = BatchNormalization()(imag)
    return tf.cast(tf.complex(real_bn, imag_bn), tf.complex64)

@NEGU93, would you be interested in a PR implementing this as a proper layer?
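The effect of this naive approach can be illustrated without TensorFlow. A minimal NumPy sketch (names mine, inference-style: no learned gamma/beta, no moving averages) of per-part normalization:

```python
import numpy as np

def naive_complex_bn(z: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Normalize real and imaginary parts independently, mirroring
    what the NaiveComplexBatchNorm layer above does per batch."""
    real = (z.real - z.real.mean()) / np.sqrt(z.real.var() + eps)
    imag = (z.imag - z.imag.mean()) / np.sqrt(z.imag.var() + eps)
    return real + 1j * imag

# Real part with large variance, imaginary part with small variance:
z = np.random.randn(4096) * 5.0 + 1j * (np.random.randn(4096) * 0.1 + 2.0)
out = naive_complex_bn(z)
```

Note that when the two parts have different variances they are rescaled by different factors, so the phase of each sample changes; this is the trade-off against the joint whitening of the complex formulation.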
Sure. I am not sure what they base that guarantee on; from my point of view, a naive implementation may have a very negative impact on the phase, which is a crucial aspect of CVNN merits (Ref). Please submit your PR, and thank you for the contribution!
Here is an implementation of a small 1D CNN as an example, until that PR is integrated into the cvnn package:

def get_model(input_len=1000, activation_func='crelu'):
Hi there, @NEGU93. Thanks for the great effort in making this library. It really accelerates my research on a signal recognition task, and this TF 2.0 version indeed helps me deploy on an edge device with the help of TFLite. However, I found that ComplexBatchNormalization() terribly slows down the training process. To give one example to reproduce: it costs me almost 10 minutes to train one epoch, but when I substitute ComplexBatchNormalization() with BatchNormalization(), it only costs half a minute. Any ideas?