From d97446d6ad886b40e7e76b7b4bc3d461166d18e6 Mon Sep 17 00:00:00 2001 From: "devin-ai-integration[bot]" <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Thu, 14 Nov 2024 11:15:07 +0000 Subject: [PATCH] Fix syntax error and implement global average pooling for confidence estimation --- models/analysis/multimodal_integrator.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/models/analysis/multimodal_integrator.py b/models/analysis/multimodal_integrator.py index fde949e..55967af 100644 --- a/models/analysis/multimodal_integrator.py +++ b/models/analysis/multimodal_integrator.py @@ -157,6 +157,7 @@ def __init__(self, hidden_size: int): nn.Linear(1536, 768) # Final output dimension ) + # Global average pooling followed by confidence estimation self.confidence_estimator = nn.Sequential( nn.Linear(768, 384), nn.ReLU(), @@ -210,8 +211,13 @@ def forward( unified_features = self.integration_network(combined_features) print(f"Unified features shape: {unified_features.shape}") - # Estimate confidence - confidence = self.confidence_estimator(unified_features) + # Global average pooling for confidence estimation + pooled_features = torch.mean(unified_features, dim=1) # Average across sequence length + print(f"Pooled features shape: {pooled_features.shape}") + + # Estimate confidence (single value per protein) + confidence = self.confidence_estimator(pooled_features) + print(f"Confidence shape: {confidence.shape}") return { 'unified_features': unified_features,