
Casting Error complex -> float during learning only. #52

Open
rjdw opened this issue Jun 19, 2024 · 2 comments


rjdw commented Jun 19, 2024

I am using the example from the documentation (Documentation Example).

After initialization, the model correctly takes complex input and produces complex output.

However, during training I get the warning: "You are casting an input of type complex64 to an incompatible dtype float32. This will discard the imaginary part and may not be what you intended."

During training, the model is unable to output a complex value; the output contains the real part only.

I am unable to determine the origin of this casting error.
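For what it's worth, this exact warning is what TensorFlow emits when a complex tensor is passed through tf.cast to a real dtype, so something in the training path appears to be doing this:

import tensorflow as tf

# Casting complex64 -> float32 discards the imaginary part and
# reproduces the warning seen during training
x = tf.constant(1.0 + 2.0j, dtype=tf.complex64)
y = tf.cast(x, tf.float32)  # emits the "incompatible dtype" warning
print(y)  # tf.Tensor(1.0, shape=(), dtype=float32)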

As per TF 2.16 Incompatibility I am using:

tensorflow                   2.15.0
cvnn                         1.2.22

MWE:

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from keras.initializers import HeNormal, GlorotUniform
import cvnn.layers as complex_layers
import cvnn.initializers as complex_initializers
## Custom Loss Function
def custom_loss(k):
    def loss(y_true, y_pred):
        # Extract real (a) and imaginary (b) parts
        a = tf.math.real(y_pred)
        b = tf.math.imag(y_pred)
        
        # Main loss: (y - a)^2 + b^2
        mse = tf.reduce_mean(tf.square(y_true - a) + tf.square(b))
        
        # Regularization term: (e^(-k * a) - 1) if a < 0, else 0
        condition = tf.less(a, 0)
        regularization_term = tf.where(condition, tf.exp(-k * a) - 1, tf.zeros_like(a))
        regularization_term = tf.reduce_mean(regularization_term)
        
        # Total loss
        total_loss = mse + regularization_term
        return total_loss
    return loss
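As a sanity check (my addition, not from the docs), the loss can be evaluated eagerly on a complex prediction to confirm it returns a real float32 scalar:

# Eager sanity check: the loss itself handles complex y_pred correctly
y_pred = tf.constant([[1.0 + 0.5j]], dtype=tf.complex64)
y_true = tf.constant([[0.8]], dtype=tf.float32)
print(custom_loss(0.001)(y_true, y_pred))  # ~0.29, dtype=float32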


# Custom MAE metric
# abs(a - y_true)
def custom_mae(y_true, y_pred):
    a = tf.math.real(y_pred)
    mae = tf.reduce_mean(tf.abs(y_true - a))
    return mae

# Printing b_value metric
def b_value(y_true, y_pred):
    b = tf.math.imag(y_pred)
    b_val = tf.reduce_mean(tf.abs(b))
    return b_val
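Note (my observation): if y_pred ever reaches the metric as a real tensor, tf.math.imag returns all zeros, so b_value reports 0, exactly as seen during training:

# If y_pred has already been cast to float, the imaginary part is gone
print(b_value(tf.zeros(1), tf.constant([1.0 + 2.0j], dtype=tf.complex64)))  # 2.0
print(b_value(tf.zeros(1), tf.constant([1.0], dtype=tf.float32)))           # 0.0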
# Example from Docs
num_features = (8, 148)  # shape of each complex-valued sample
test_model = Sequential()
test_model.add(complex_layers.ComplexInput(input_shape=num_features, dtype=np.complex64))
test_model.add(complex_layers.ComplexFlatten())
test_model.add(complex_layers.ComplexDense(32, activation='cart_relu', dtype=np.complex64))
test_model.add(complex_layers.ComplexDense(1, dtype=np.complex64))
print(test_model.output_shape)
# Showing that the model correctly takes complex input and outputs complex value
x = tf.cast(tf.random.normal((1, 8 ,148)), tf.complex64)
x.shape, x.dtype

(TensorShape([1, 8, 148]), tf.complex64)

out = test_model(x)
out.dtype, out

(tf.complex64,
<tf.Tensor: shape=(1, 1), dtype=complex64, numpy=array([[1.0031033+0.26049298j]], dtype=complex64)>)
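The same call with training=True (my addition) also stays complex outside of fit:

out_train = test_model(x, training=True)
print(out_train.dtype)  # complex64 -- the layers themselves preserve the dtype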

k = 0.001

# Compile the model
test_model.compile(optimizer='adam', loss=custom_loss(k), metrics=[custom_mae, b_value])
# Sim Data

# Generate 100 complex-valued 2D arrays of shape (8, 148)
data = np.random.randn(100, 8, 148) + 1j * np.random.randn(100, 8, 148)
data = data.astype(np.complex64)

# Generate 100 labels of dtype float32
labels = np.random.randn(100).astype(np.float32)

# Split data and labels into training (80%) and validation (20%) sets
train_size = int(0.8 * len(data))

t_d = data[:train_size]
t_l = labels[:train_size]

v_d = data[train_size:]
v_l = labels[train_size:]

t_d.shape, t_l.shape, v_d.shape, v_l.shape
# Data is correctly typed
for sample in t_d:
    if sample.dtype != np.complex64:
        print(sample.dtype)
        break
t_l[0].dtype, t_d[0].real.dtype

(dtype('float32'), dtype('float32'))

from keras.callbacks import Callback

# Callback for seeing the b value in the complex output (a + ib)
class PrintBValueCallback(Callback):
    def on_train_batch_end(self, batch, logs=None):
        b_value = logs.get('b_value')
        print(f'Batch {batch}, b_value: {b_value}')
# Train the model
history = test_model.fit(t_d, t_l, 
                    epochs=32, batch_size=16, 
                    shuffle=True, 
                    callbacks=[PrintBValueCallback()],
                    validation_data=(v_d, v_l))
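To narrow down where the cast happens, a manual training step (my sketch, not part of the original report) bypasses fit's input pipeline; if out.dtype is still complex64 here, the cast is introduced inside model.fit rather than by the layers:

# Diagnostic: one forward/backward pass without model.fit
xb = tf.constant(t_d[:16])  # complex64 batch
yb = tf.constant(t_l[:16])  # float32 labels
with tf.GradientTape() as tape:
    out = test_model(xb, training=True)
    loss_val = custom_loss(k)(yb, out)
print(out.dtype, loss_val.dtype)  # expect complex64, float32 outside fit
grads = tape.gradient(loss_val, test_model.trainable_variables)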

[screenshot "complex_error_cast": training output showing the casting warning]

On the first training batch the phase is already 0: no valid complex value is being output. A float is being cast back into complex, so the complex output has no imaginary part.

Edit: changed the dtype from complex128 to complex64; the same casting error occurs.

NEGU93 (Owner) commented Jun 20, 2024

I think the default dtype is complex64, not complex128. Can you try using complex64 everywhere and see if it works?
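One way to verify this (a generic Keras check, not from the original comment) is to print each layer's dtype and compute dtype:

for layer in test_model.layers:
    print(layer.name, layer.dtype, layer.compute_dtype)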

rjdw (Author) commented Jun 24, 2024

Hi, sorry for the late response.

I tried using complex64 everywhere, but the same casting error occurs.
I've edited the MWE above to reflect this.
