Skip to content

Commit

Permalink
Update structure predictor to handle 768-dimensional features from se…
Browse files Browse the repository at this point in the history
…quence analyzer
  • Loading branch information
devin-ai-integration[bot] committed Nov 14, 2024
1 parent d8e9f26 commit ad2f9b3
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions models/analysis/structure_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ def __init__(self, config: Dict = None):
"""Initialize the structure predictor"""
super().__init__()
self.config = config or {}
self.hidden_size = 320 # Match ESM2's output dimension
self.hidden_size = 768 # Match sequence analyzer's output dimension

# Feature processing networks
self.feature_processor = nn.Sequential(
nn.Linear(320, 320),
nn.Linear(768, 768),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(320, 320)
nn.Linear(768, 768)
)

# Initialize contact map predictor
Expand All @@ -32,8 +32,8 @@ def __init__(self, config: Dict = None):
# Initialize structure refiner
self.structure_refiner = StructureRefiner(
config={
'input_dim': 320,
'hidden_dim': 320,
'input_dim': 768,
'hidden_dim': 768,
'refinement_steps': 100,
'refinement_lr': 0.01
}
Expand All @@ -42,10 +42,10 @@ def __init__(self, config: Dict = None):
def forward(self, sequence_features: torch.Tensor) -> Dict[str, torch.Tensor]:
"""Forward pass for structure prediction"""
# Predict backbone features
backbone_features = self.feature_processor(sequence_features) # [batch, seq_len, 320]
backbone_features = self.feature_processor(sequence_features) # [batch, seq_len, 768]

# Predict side chain features
side_chain_features = self.feature_processor(backbone_features) # [batch, seq_len, 320]
side_chain_features = self.feature_processor(backbone_features) # [batch, seq_len, 768]

# Predict contact map
contact_map = self.contact_predictor(sequence_features) # [batch, seq_len, seq_len]
Expand All @@ -71,11 +71,11 @@ def predict_structure(self, sequence_features: torch.Tensor) -> Dict[str, torch.
class ContactMapPredictor(nn.Module):
def __init__(self):
super().__init__()
self.attention = nn.MultiheadAttention(320, num_heads=8) # Match ESM2 dimensions
self.attention = nn.MultiheadAttention(768, num_heads=8) # Match sequence analyzer dimensions
self.mlp = nn.Sequential(
nn.Linear(320, 160), # Input from attention
nn.Linear(768, 384), # Input from attention
nn.ReLU(),
nn.Linear(160, 1) # Output single contact probability
nn.Linear(384, 1) # Output single contact probability
)

def forward(self, features: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -103,27 +103,27 @@ def __init__(self, config: Dict[str, Any] = None):

# Initialize feature processors
self.backbone_processor = nn.Sequential(
nn.Linear(320, 160),
nn.Linear(768, 384),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(160, 4) # phi, psi, omega angles
nn.Linear(384, 4) # phi, psi, omega angles
)

self.side_chain_processor = nn.Sequential(
nn.Linear(320, 160),
nn.Linear(768, 384),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(160, 4) # chi1, chi2, chi3, chi4 angles
nn.Linear(384, 4) # chi1, chi2, chi3, chi4 angles
)

# Initialize position predictor for 3D coordinates
self.position_predictor = nn.Sequential(
nn.Linear(640, 320), # Combined backbone and side chain features
nn.Linear(1536, 768), # Combined backbone and side chain features
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(320, 160),
nn.Linear(768, 384),
nn.ReLU(),
nn.Linear(160, 3) # x, y, z coordinates
nn.Linear(384, 3) # x, y, z coordinates
)

def forward(self, backbone_features: torch.Tensor,
Expand Down

0 comments on commit ad2f9b3

Please sign in to comment.