diff --git a/models/analysis/multimodal_integrator.py b/models/analysis/multimodal_integrator.py index d3080fe..a1d5477 100644 --- a/models/analysis/multimodal_integrator.py +++ b/models/analysis/multimodal_integrator.py @@ -142,10 +142,11 @@ def __init__(self, hidden_size: int): ) self.integration_network = nn.Sequential( - nn.Linear(2304, 1536), # 3 * 768-dim features + nn.Linear(768 * 3, 1536), # Concatenated features from all three modalities + nn.LayerNorm(1536), # Normalize combined features nn.ReLU(), nn.Dropout(0.1), - nn.Linear(1536, 768) # Match ESM2 dimensions + nn.Linear(1536, 768) # Final output dimension ) self.confidence_estimator = nn.Sequential( @@ -164,6 +165,13 @@ def forward( # Transform function features to match dimensions function_features = self.function_encoder(function_results['go_terms']) + # Ensure all features have the same dimensions before combining + batch_size = sequence_features.size(0) + seq_len = sequence_features.size(1) + + # Reshape function features if needed + function_features = function_features.view(batch_size, seq_len, -1) + # Combine all features combined_features = torch.cat([ sequence_features,