diff --git a/models/analysis/multimodal_integrator.py b/models/analysis/multimodal_integrator.py index a1d5477..465a0f6 100644 --- a/models/analysis/multimodal_integrator.py +++ b/models/analysis/multimodal_integrator.py @@ -162,27 +162,43 @@ def forward( structure_features: torch.Tensor, function_results: Dict[str, torch.Tensor] ) -> Dict[str, torch.Tensor]: + # Add debug prints for input shapes + print(f"Sequence features shape: {sequence_features.shape}") + print(f"Structure features shape: {structure_features.shape}") + print(f"Function results GO terms shape: {function_results['go_terms'].shape}") + # Transform function features to match dimensions function_features = self.function_encoder(function_results['go_terms']) + print(f"Transformed function features shape: {function_features.shape}") # Ensure all features have the same dimensions before combining batch_size = sequence_features.size(0) seq_len = sequence_features.size(1) + feature_dim = sequence_features.size(2) # Should be 768 + + # Reshape features if needed + sequence_features = sequence_features.view(batch_size, seq_len, feature_dim) + structure_features = structure_features.view(batch_size, seq_len, feature_dim) + function_features = function_features.view(batch_size, seq_len, feature_dim) - # Reshape function features if needed - function_features = function_features.view(batch_size, seq_len, -1) + print(f"Reshaped sequence features: {sequence_features.shape}") + print(f"Reshaped structure features: {structure_features.shape}") + print(f"Reshaped function features: {function_features.shape}") # Combine all features combined_features = torch.cat([ sequence_features, structure_features, function_features - ], dim=-1) + ], dim=-1) # Concatenate along feature dimension + + print(f"Combined features shape: {combined_features.shape}") - # Generate unified representation + # Process through integration network unified_features = self.integration_network(combined_features) + print(f"Unified features shape: {unified_features.shape}") - # Estimate prediction confidence + # Estimate confidence confidence = self.confidence_estimator(unified_features) return {