Skip to content

Commit

Permalink
Add detailed logging to diagnose dimension mismatch
Browse files Browse the repository at this point in the history
  • Loading branch information
devin-ai-integration[bot] committed Jul 9, 2024
1 parent 4077e05 commit aa546af
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions scripts/deep_learning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,14 @@ def call(self, inputs):

# Reshape state_vector to match the expected shape for matrix multiplication
reshaped_state_vector = tf.reshape(self.state_vector, [self.num_particles, 3]) # Reshape to (num_particles, 3)
tf.print("Shape of reshaped_state_vector:", tf.shape(reshaped_state_vector))

# Compute particle weights based on the EEG measurement model
predicted_measurements = tf.cast(eeg_measurement_model(reshaped_state_vector, self.forward_matrix), tf.float32)
tf.print("Shape of predicted_measurements:", tf.shape(predicted_measurements))

# Print shapes for debugging
tf.print("Shape of inputs:", tf.shape(inputs))
tf.print("Shape of reshaped_state_vector:", tf.shape(reshaped_state_vector))
tf.print("Shape of predicted_measurements:", tf.shape(predicted_measurements))
tf.print("Inputs:", inputs)
tf.print("Predicted measurements:", predicted_measurements)

Expand All @@ -156,6 +156,7 @@ def call(self, inputs):

# Ensure the total number of elements in inputs matches predicted_measurements
if tf.reduce_prod(input_shape[1:]) != tf.reduce_prod(predicted_shape):
tf.print("Dimension mismatch: reshaped inputs shape", input_shape[1:], "does not match predicted_measurements shape", predicted_shape)
raise ValueError(f"Dimension mismatch: reshaped inputs shape {input_shape[1:]} does not match predicted_measurements shape {predicted_shape}")

# Reshape inputs to match the shape of predicted_measurements
Expand Down

0 comments on commit aa546af

Please sign in to comment.