Merge pull request #142 from VishwamAI/quantum-restructure
Restructure quantum modules and update tests
kasinadhsarma authored Oct 19, 2024
2 parents 6f5f38c + 02dda38 commit cbcf28d
Showing 17 changed files with 217 additions and 181 deletions.
4 changes: 2 additions & 2 deletions .flake8
@@ -1,4 +1,4 @@
[flake8]
-ignore = E8, E402, W503 # Add other codes you want to ignore
-max-line-length = 88 # Adjust line length as needed
+ignore = E801,E802,E803,E402,W503
+max-line-length = 88
exclude = .git,__pycache__,dist,build # Exclude directories as needed
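
For reference, the merged .flake8 now reads as follows (reconstructed from the hunk above; the exclude line is unchanged):

[flake8]
ignore = E801,E802,E803,E402,W503
max-line-length = 88
exclude = .git,__pycache__,dist,build # Exclude directories as needed
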
135 changes: 71 additions & 64 deletions NeuroFlex/advanced_models/multi_modal_learning.py
@@ -154,30 +154,33 @@ def fuse_modalities(self, encoded_modalities: Dict[str, torch.Tensor]) -> torch.
def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Forward pass through the multi-modal learning model."""
logger.debug(f"Input types: {[(name, type(tensor)) for name, tensor in inputs.items()]}")
logger.debug(f"Input shapes: {[(name, tensor.shape) for name, tensor in inputs.items()]}")
if not inputs:
raise ValueError("Input dictionary is empty")

# Check if only a single modality is provided
if len(inputs) == 1:
raise ValueError("At least two modalities are required for fusion")

# Ensure all inputs are tensors
# Ensure all inputs are tensors and have correct dtype
for name, tensor in inputs.items():
if not isinstance(tensor, torch.Tensor):
inputs[name] = torch.tensor(tensor, dtype=torch.float32)
inputs[name] = inputs[name].float() # Ensure all inputs are float tensors
logger.debug(f"Input {name} shape: {inputs[name].shape}, type: {type(inputs[name])}")
inputs[name] = torch.tensor(tensor)
if name == 'text':
inputs[name] = inputs[name].long()
else:
inputs[name] = inputs[name].float()
logger.debug(f"Input {name} shape: {inputs[name].shape}, type: {inputs[name].dtype}")

# Check for batch size consistency across all input modalities
batch_sizes = [tensor.size(0) for tensor in inputs.values()]
if len(set(batch_sizes)) > 1:
raise ValueError(f"Inconsistent batch sizes across modalities: {dict(zip(inputs.keys(), batch_sizes))}")

# Handle individual modality inputs
if set(inputs.keys()) != set(self.modalities.keys()):
missing_modalities = set(self.modalities.keys()) - set(inputs.keys())
for modality in missing_modalities:
inputs[modality] = torch.zeros((batch_sizes[0],) + self.modalities[modality]['input_shape'], dtype=torch.float32)
# Handle missing modalities
for modality in set(self.modalities.keys()) - set(inputs.keys()):
inputs[modality] = torch.zeros((batch_sizes[0],) + self.modalities[modality]['input_shape'], dtype=torch.float32)
logger.debug(f"Created zero tensor for missing modality {modality}: shape {inputs[modality].shape}")

max_batch_size = batch_sizes[0]
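
As a standalone illustration of the coercion rule the new code applies (text indices kept as long for the embedding lookup, every other modality cast to float), here is a minimal sketch; the modality names and shapes are made up for the example:

import torch

def coerce_inputs(inputs):
    # Non-tensors are wrapped; 'text' stays integer (long) for nn.Embedding, the rest become float.
    for name, value in inputs.items():
        tensor = value if isinstance(value, torch.Tensor) else torch.tensor(value)
        inputs[name] = tensor.long() if name == 'text' else tensor.float()
    return inputs

batch = coerce_inputs({'text': [[1, 2, 3]], 'image': torch.zeros(1, 3, 32, 32)})
print(batch['text'].dtype, batch['image'].dtype)  # torch.int64 torch.float32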

@@ -193,73 +196,66 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
if inputs[name].shape[1:] != modality['input_shape']:
raise ValueError(f"Input shape for {name} {inputs[name].shape} does not match the defined shape (batch_size, {modality['input_shape']})")

if name == 'image':
# For image modality, preserve the 4D structure
encoded_modalities[name] = modality['encoder'](inputs[name])
logger.debug(f"Encoded image shape: {encoded_modalities[name].shape}")
elif name == 'text':
# For text modality, ensure long type for embedding and float type for LSTM
text_input = inputs[name].long().clamp(0, 29999) # Clamp to valid range
logger.debug(f"Text input shape: {text_input.shape}, type: {type(text_input)}")
embedded = modality['encoder'][0](text_input)
logger.debug(f"Embedded shape: {embedded.shape}, type: {type(embedded)}")
lstm_out, _ = modality['encoder'][1](embedded.float()) # Unpack LSTM output
logger.debug(f"Raw LSTM output shape: {lstm_out.shape}, type: {type(lstm_out)}")
lstm_out = lstm_out[:, -1, :] # Use last time step output
logger.debug(f"LSTM output shape: {lstm_out.shape}, type: {type(lstm_out)}")
lstm_out = lstm_out.contiguous().view(lstm_out.size(0), -1) # Ensure correct shape
logger.debug(f"Reshaped LSTM output shape: {lstm_out.shape}, type: {type(lstm_out)}")
encoded_modalities[name] = modality['encoder'][2](lstm_out)
logger.debug(f"Encoded text shape: {encoded_modalities[name].shape}, type: {type(encoded_modalities[name])}")
elif name == 'time_series':
# For time series, ensure 3D input (batch_size, channels, sequence_length)
if inputs[name].dim() == 2:
inputs[name] = inputs[name].unsqueeze(1)
logger.debug(f"Time series input shape: {inputs[name].shape}, type: {type(inputs[name])}")
encoded_modalities[name] = modality['encoder'](inputs[name])
logger.debug(f"Encoded time series shape: {encoded_modalities[name].shape}")
elif name == 'tabular':
# For tabular data, ensure 2D input (batch_size, features)
logger.debug(f"Tabular input shape: {inputs[name].shape}, type: {type(inputs[name])}")
encoded_modalities[name] = modality['encoder'](inputs[name].view(inputs[name].size(0), -1))
logger.debug(f"Encoded tabular shape: {encoded_modalities[name].shape}")
else:
# For other modalities, flatten the input
logger.debug(f"Other modality input shape: {inputs[name].shape}, type: {type(inputs[name])}")
encoded_modalities[name] = modality['encoder'](inputs[name].view(inputs[name].size(0), -1))
logger.debug(f"Encoded other modality shape: {encoded_modalities[name].shape}")
try:
if name == 'image':
# For image modality, preserve the 4D structure
encoded_modalities[name] = modality['encoder'](inputs[name])
elif name == 'text':
# For text modality, ensure long type for embedding and float type for LSTM
text_input = inputs[name].long().clamp(0, 29999) # Clamp to valid range
embedded = modality['encoder'][0](text_input)
lstm_out, _ = modality['encoder'][1](embedded.float())
lstm_out = lstm_out[:, -1, :] # Use last time step output
encoded_modalities[name] = modality['encoder'][2](lstm_out)
elif name == 'time_series':
# For time series, ensure 3D input (batch_size, channels, sequence_length)
if inputs[name].dim() == 2:
inputs[name] = inputs[name].unsqueeze(1)
encoded_modalities[name] = modality['encoder'](inputs[name])
elif name == 'tabular':
# For tabular data, ensure 2D input (batch_size, features)
encoded_modalities[name] = modality['encoder'](inputs[name].view(inputs[name].size(0), -1))
else:
# For other modalities, flatten the input
encoded_modalities[name] = modality['encoder'](inputs[name].view(inputs[name].size(0), -1))

logger.debug(f"Encoded {name} shape: {encoded_modalities[name].shape}, type: {type(encoded_modalities[name])}")
logger.debug(f"Encoded {name} shape: {encoded_modalities[name].shape}, type: {encoded_modalities[name].dtype}")
except Exception as e:
logger.error(f"Error processing modality {name}: {str(e)}")
raise
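
The text branch above (Embedding, then LSTM, then a Linear layer on the last time step) can be reproduced in isolation roughly as follows; the layer sizes and the batch_first setting are assumptions for the sketch, not the module's actual configuration:

import torch
import torch.nn as nn

vocab_size, embed_dim, hidden_dim, out_dim = 30000, 64, 128, 32  # assumed sizes
embedding = nn.Embedding(vocab_size, embed_dim)
lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
proj = nn.Linear(hidden_dim, out_dim)

tokens = torch.randint(0, vocab_size, (4, 10))          # (batch, seq_len) token ids
embedded = embedding(tokens.clamp(0, vocab_size - 1))   # clamp to the valid index range, as in the diff
lstm_out, _ = lstm(embedded.float())                    # lstm_out: (batch, seq_len, hidden_dim)
encoded = proj(lstm_out[:, -1, :])                      # keep only the last time step
print(encoded.shape)                                    # torch.Size([4, 32])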

# Ensure all encoded modalities have the same batch size and are 2D tensors
encoded_modalities = {name: tensor.view(max_batch_size, -1) for name, tensor in encoded_modalities.items()}
logger.debug(f"Encoded modalities shapes after reshaping: {[(name, tensor.shape) for name, tensor in encoded_modalities.items()]}")

if self.fusion_method == 'concatenation':
fused = torch.cat(list(encoded_modalities.values()), dim=1)
elif self.fusion_method == 'attention':
fused = self.fuse_modalities(encoded_modalities)
else:
raise ValueError(f"Unsupported fusion method: {self.fusion_method}")
try:
if self.fusion_method == 'concatenation':
fused = torch.cat(list(encoded_modalities.values()), dim=1)
elif self.fusion_method == 'attention':
fused = self.fuse_modalities(encoded_modalities)
else:
raise ValueError(f"Unsupported fusion method: {self.fusion_method}")

logger.debug(f"Fused tensor shape: {fused.shape}, type: {type(fused)}")
logger.debug(f"Fused tensor shape: {fused.shape}, type: {fused.dtype}")

# Ensure fused tensor is 2D and matches the classifier's input size
if fused.dim() != 2 or fused.size(1) != self.classifier.in_features:
fused = fused.view(max_batch_size, -1)
fused = nn.functional.adaptive_avg_pool1d(fused.unsqueeze(1), self.classifier.in_features).squeeze(1)
# Ensure fused tensor is 2D and matches the classifier's input size
if fused.dim() != 2 or fused.size(1) != self.classifier.in_features:
fused = fused.view(max_batch_size, -1)
fused = nn.functional.adaptive_avg_pool1d(fused.unsqueeze(1), self.classifier.in_features).squeeze(1)

# Ensure fused tensor is a valid input for the classifier
fused = fused.float() # Convert to float if not already
# Ensure fused tensor is a valid input for the classifier
fused = fused.float() # Convert to float if not already

logger.debug(f"Final fused tensor shape: {fused.shape}, type: {type(fused)}")
logger.debug(f"Final fused tensor shape: {fused.shape}, type: {fused.dtype}")
logger.debug(f"Classifier input shape: {fused.shape}")

# Ensure input to classifier is a tensor
if not isinstance(fused, torch.Tensor):
fused = torch.tensor(fused, dtype=torch.float32)
output = self.classifier(fused)
logger.debug(f"Final output shape: {output.shape}")

logger.debug(f"Classifier input shape: {fused.shape}, type: {type(fused)}")
return self.classifier(fused)
return output
except Exception as e:
logger.error(f"Error during fusion or classification: {str(e)}")
raise
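
A toy version of the concatenation path plus the adaptive_avg_pool1d fallback used when the fused width differs from the classifier's in_features (all sizes here are invented for illustration):

import torch
import torch.nn as nn

encoded = {'image': torch.randn(4, 32), 'text': torch.randn(4, 32)}
fused = torch.cat(list(encoded.values()), dim=1)   # (4, 64)

classifier = nn.Linear(48, 10)                     # in_features deliberately != 64 to trigger the resize
if fused.dim() != 2 or fused.size(1) != classifier.in_features:
    fused = nn.functional.adaptive_avg_pool1d(fused.unsqueeze(1), classifier.in_features).squeeze(1)

print(classifier(fused.float()).shape)             # torch.Size([4, 10])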

def fit(self, data: Dict[str, torch.Tensor], labels: torch.Tensor, val_data: Dict[str, torch.Tensor] = None, val_labels: torch.Tensor = None, epochs: int = 10, lr: float = 0.001, patience: int = 5, batch_size: int = 32):
"""Train the multi-modal learning model."""
@@ -415,6 +411,10 @@ def _train_epoch(self, data: Dict[str, torch.Tensor], labels: torch.Tensor, opti
correct_predictions = 0
total_samples = 0

logger.debug(f"_train_epoch input - data types: {[(k, type(v)) for k, v in data.items()]}")
logger.debug(f"_train_epoch input - data shapes: {[(k, v.shape) for k, v in data.items()]}")
logger.debug(f"_train_epoch input - labels type: {type(labels)}, shape: {labels.shape}")

if not data or not labels.numel():
logger.warning("Empty data or labels provided for training epoch.")
return 0.0, 0.0
@@ -425,14 +425,20 @@
num_batches = (num_samples + batch_size - 1) // batch_size

for i, (batch_data, batch_labels) in enumerate(self._batch_data(data, labels, batch_size)):
logger.debug(f"Batch {i+1}/{num_batches} - data types: {[(k, type(v)) for k, v in batch_data.items()]}")
logger.debug(f"Batch {i+1}/{num_batches} - data shapes: {[(k, v.shape) for k, v in batch_data.items()]}")
logger.debug(f"Batch {i+1}/{num_batches} - labels type: {type(batch_labels)}, shape: {batch_labels.shape}")

if not batch_data or not batch_labels.numel():
logger.warning(f"Empty batch encountered at iteration {i+1}/{num_batches}. Skipping.")
continue

print(f"\rTraining: {i+1}/{num_batches}", end="", flush=True)
optimizer.zero_grad()
outputs = self.forward(batch_data)
logger.debug(f"Batch {i+1}/{num_batches} - outputs type: {type(outputs)}, shape: {outputs.shape}")
loss = criterion(outputs, batch_labels)
logger.debug(f"Batch {i+1}/{num_batches} - loss type: {type(loss)}, value: {loss.item()}")
loss.backward()

# Gradient clipping
Expand Down Expand Up @@ -465,6 +471,7 @@ def _train_epoch(self, data: Dict[str, torch.Tensor], labels: torch.Tensor, opti

except Exception as e:
logger.error(f"Error in _train_epoch: {str(e)}")
logger.exception("Traceback:")
return 0.0, 0.0
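
For context on the logging added around the training loop, a compressed version of one batch step (zero_grad, forward, loss, backward, clip, step) looks roughly like this, with a stand-in linear model, an assumed CrossEntropyLoss criterion, and an assumed clipping value:

import torch
import torch.nn as nn

model = nn.Linear(8, 3)                                   # stand-in for the multi-modal model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()                         # assumed criterion
batch_data, batch_labels = torch.randn(32, 8), torch.randint(0, 3, (32,))

optimizer.zero_grad()
outputs = model(batch_data)
loss = criterion(outputs, batch_labels)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # gradient clipping; max_norm assumed
optimizer.step()
print(f"loss: {loss.item():.4f}")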

def _validate(self, data: Dict[str, torch.Tensor], labels: torch.Tensor, criterion: nn.Module, batch_size: int = 32) -> Tuple[float, float]: