Skip to content

Commit

Permalink
Update multimodal integration with proper dimension handling and geometric constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
devin-ai-integration[bot] committed Nov 14, 2024
1 parent 9e36c6c commit c74f720
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 60 deletions.
16 changes: 8 additions & 8 deletions models/analysis/multimodal_integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@ def analyze_protein(self, sequence: str) -> Dict[str, torch.Tensor]:
class CrossModalAttention(nn.Module):
# Builds two 8-head nn.MultiheadAttention modules (sequence_attention and
# structure_attention) and a feature_combiner MLP: Linear -> ReLU -> Dropout(0.1) -> Linear.
# NOTE(review): this span is a rendered diff. Each 80/160-wide line appears to be the
# removed version of the 768/1536-wide line that follows it; only one of each pair
# exists in the real file.
# NOTE(review): `hidden_size` is accepted but never read in the lines shown; every
# layer width is a literal.
# NOTE(review): neither attention module is given batch_first, so PyTorch's default
# (seq, batch, embed) layout applies - confirm forward() feeds tensors in that layout.
def __init__(self, hidden_size: int):
super().__init__()
self.sequence_attention = nn.MultiheadAttention(80, num_heads=8) # Match feature dimensions
self.structure_attention = nn.MultiheadAttention(80, num_heads=8) # Match feature dimensions
self.sequence_attention = nn.MultiheadAttention(768, num_heads=8) # Match ESM2 dimensions
self.structure_attention = nn.MultiheadAttention(768, num_heads=8) # Match ESM2 dimensions

# The combiner's input width is twice the attention width (160 = 2 * 80, 1536 = 2 * 768).
self.feature_combiner = nn.Sequential(
nn.Linear(160, 80), # Combine 80-dim features
nn.Linear(1536, 768), # Combine 768-dim features
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(80, 80) # Output 80-dim features
nn.Linear(768, 768) # Output 768-dim features
)

def forward(
Expand Down Expand Up @@ -111,16 +111,16 @@ class UnifiedPredictor(nn.Module):
def __init__(self, hidden_size: int):
super().__init__()
self.integration_network = nn.Sequential(
nn.Linear(240, 160), # 3 * 80-dim features
nn.Linear(2304, 1536), # 3 * 768-dim features
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(160, 80) # Output 80-dim features
nn.Linear(1536, 768) # Output 768-dim features
)

self.confidence_estimator = nn.Sequential(
nn.Linear(80, 40),
nn.Linear(768, 384),
nn.ReLU(),
nn.Linear(40, 1),
nn.Linear(384, 1),
nn.Sigmoid()
)

Expand Down
156 changes: 104 additions & 52 deletions models/analysis/structure_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,55 +11,56 @@
from transformers import AutoModel

class StructurePredictor(nn.Module):
# Constructor: reads 'hidden_size' from an optional config dict and builds the backbone
# network, the side-chain network, a ContactMapPredictor and a StructureRefiner.
# NOTE(review): this span is a rendered diff, so removed and added lines are interleaved
# (two `def __init__` signatures, `backbone_predictor` vs `backbone_network`, etc.).
# NOTE(review): it is not visible here whether `self.config = config` survives the change;
# if it does, it runs before the `config is None` guard and would store None - confirm.
# NOTE(review): `hidden_size` is configurable here, but ContactMapPredictor() and
# StructureRefiner() are built without arguments and use literal 768-wide layers, so any
# hidden_size other than 768 would not match them.
def __init__(self, config: Dict):
def __init__(self, config: Dict = None):
super().__init__()
self.config = config
self.hidden_size = config.get('hidden_size', 320)
if config is None:
config = {}

# Backbone prediction network
self.backbone_predictor = nn.Sequential(
nn.Linear(80, 160), # Input from feature extractor (80-dim)
hidden_size = config.get('hidden_size', 768) # Match ESM2 dimensions

# Initialize backbone prediction network
self.backbone_network = nn.Sequential(
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(160, 3) # (phi, psi, omega) angles
nn.Linear(hidden_size, hidden_size)
)

# Side chain optimization network
self.side_chain_optimizer = nn.Sequential(
nn.Linear(83, 160), # 80-dim features + 3 backbone angles
nn.ReLU(),
nn.Linear(160, 80),
# Initialize side chain optimization network
# (hidden_size -> hidden_size // 2 -> hidden_size in the added lines)
self.side_chain_network = nn.Sequential(
nn.Linear(hidden_size, hidden_size // 2),
nn.ReLU(),
nn.Linear(80, 4) # chi angles
nn.Dropout(0.1),
nn.Linear(hidden_size // 2, hidden_size)
)

# Contact map prediction
self.contact_predictor = ContactMapPredictor(80)
# Initialize contact map predictor
self.contact_predictor = ContactMapPredictor()

# Structure refinement module
# Initialize structure refinement
self.structure_refiner = StructureRefiner()

# Runs the sub-networks in order and returns their outputs in a dict.
# NOTE(review): this span is a rendered diff. The `*_angles` lines appear to be the
# removed version and the `*_features` lines the added one, so the returned keys change
# from 'backbone_angles'/'side_chain_angles' to 'backbone_features'/'side_chain_features';
# check callers that index the old keys.
def forward(self, sequence_features: torch.Tensor) -> Dict[str, torch.Tensor]:
# Predict backbone angles
backbone_angles = self.backbone_predictor(sequence_features)
"""Forward pass for structure prediction"""
# Predict backbone features
backbone_features = self.backbone_network(sequence_features) # [batch, seq_len, 768]

# Predict side chain angles
side_chain_input = torch.cat([sequence_features, backbone_angles], dim=-1)
side_chain_angles = self.side_chain_optimizer(side_chain_input)
# Predict side chain features
# (the added line feeds the backbone output, not the raw sequence features)
side_chain_features = self.side_chain_network(backbone_features) # [batch, seq_len, 768]

# Predict contact map
contact_map = self.contact_predictor(sequence_features)
contact_map = self.contact_predictor(sequence_features) # [batch, seq_len, seq_len]

# Refine structure
refined_structure = self.structure_refiner(
backbone_angles,
side_chain_angles,
contact_map
backbone_features=backbone_features,
side_chain_features=side_chain_features,
contact_map=contact_map
)

return {
'backbone_angles': backbone_angles,
'side_chain_angles': side_chain_angles,
'backbone_features': backbone_features,
'side_chain_features': side_chain_features,
'contact_map': contact_map,
'refined_structure': refined_structure
}
Expand All @@ -69,13 +70,13 @@ def predict_structure(self, sequence_features: torch.Tensor) -> Dict[str, torch.
return self.forward(sequence_features)

class ContactMapPredictor(nn.Module):
# Builds an 8-head nn.MultiheadAttention and a two-layer MLP that ends in a single output.
# NOTE(review): this span is a rendered diff. The signature with `hidden_size` and the
# 80/40-wide layers appear to be the removed lines; the argument-free signature and the
# 768/384-wide layers appear to be the added ones.
# NOTE(review): the added signature drops `hidden_size`, so any remaining call that still
# passes a size (e.g. ContactMapPredictor(80)) would raise TypeError.
# NOTE(review): batch_first is not set, so the attention module uses PyTorch's default
# (seq, batch, embed) layout - confirm forward() matches.
def __init__(self, hidden_size: int):
def __init__(self):
super().__init__()
self.attention = nn.MultiheadAttention(80, num_heads=8) # Updated to match feature dim
self.attention = nn.MultiheadAttention(768, num_heads=8) # Match ESM2 dimensions
self.mlp = nn.Sequential(
nn.Linear(80, 40),
nn.Linear(768, 384), # Input from attention
nn.ReLU(),
nn.Linear(40, 1)
nn.Linear(384, 1) # Output single contact probability
)

def forward(self, features: torch.Tensor) -> torch.Tensor:
Expand All @@ -98,40 +99,91 @@ def forward(self, features: torch.Tensor) -> torch.Tensor:
class StructureRefiner(nn.Module):
    """Convert per-residue features into angles and a refined 3D trace.

    Two small MLPs map 768-wide backbone / side-chain features to
    (phi, psi, omega) and four chi angles.  The backbone angles seed a
    chain of points spaced ``CA_CA_DISTANCE`` apart, which is then nudged by
    a short Adam optimisation so that pairs weighted by the contact map sit
    near ``CONTACT_DISTANCE`` while consecutive points stay near
    ``CA_CA_DISTANCE``.

    The returned coordinates are detached: the refinement is an inner
    optimisation and does not propagate gradients to the feature networks.
    """

    CA_CA_DISTANCE = 3.8    # approximate CA-CA distance used for each chain step
    CONTACT_DISTANCE = 8.0  # target distance for contacting pairs (8 Angstrom threshold)
    REFINEMENT_STEPS = 50   # number of Adam iterations
    REFINEMENT_LR = 0.01    # Adam learning rate

    def __init__(self):
        super().__init__()
        # Feature processing networks
        self.backbone_processor = nn.Sequential(
            nn.Linear(768, 384),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(384, 3),  # Output (phi, psi, omega) angles
        )
        self.side_chain_processor = nn.Sequential(
            nn.Linear(768, 384),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(384, 4),  # Output chi angles
        )

    def forward(self, backbone_features: torch.Tensor,
                side_chain_features: torch.Tensor,
                contact_map: torch.Tensor) -> torch.Tensor:
        """Refine protein structure using predicted features.

        Args:
            backbone_features: tensor whose last dimension is 768.
            side_chain_features: tensor whose last dimension is 768.
            contact_map: pairwise weights broadcastable against the
                ``[batch, seq_len, seq_len]`` distance matrix.

        Returns:
            Detached ``[batch, seq_len, 3]`` coordinates.
        """
        backbone_angles = self.backbone_processor(backbone_features)
        side_chain_angles = self.side_chain_processor(side_chain_features)
        return self._apply_contact_constraints(
            backbone_angles, side_chain_angles, contact_map
        )

    def _initial_trace(self, backbone_angles: torch.Tensor) -> torch.Tensor:
        """Build the starting chain: residue 0 at the origin, unit steps scaled to CA-CA."""
        phi = backbone_angles[..., 0]
        psi = backbone_angles[..., 1]
        steps = torch.stack(
            [
                torch.cos(phi) * torch.cos(psi),
                torch.sin(phi) * torch.cos(psi),
                torch.sin(psi),
            ],
            dim=-1,
        ) * self.CA_CA_DISTANCE
        # Residue i sits at the sum of steps 1..i; the first residue contributes no step.
        steps = torch.cat([torch.zeros_like(steps[:, :1]), steps[:, 1:]], dim=1)
        return steps.cumsum(dim=1)

    def _apply_contact_constraints(
        self,
        backbone_angles: torch.Tensor,
        side_chain_angles: torch.Tensor,
        contact_map: torch.Tensor
    ) -> torch.Tensor:
        """Apply contact map constraints to refine the predicted structure.

        ``side_chain_angles`` is accepted for interface compatibility; the
        simplified geometry used here only reads phi and psi.

        Returns:
            Detached ``[batch, seq_len, 3]`` coordinates.
        """
        batch_size, seq_len, _ = backbone_angles.shape

        # Detach the inputs: the inner optimiser must own a leaf tensor, and its
        # backward passes must not touch (or repeatedly traverse) the model graph.
        angles = backbone_angles.detach().to(torch.get_default_dtype())
        weights = contact_map.detach()

        structure = self._initial_trace(angles)
        if batch_size == 0 or seq_len == 0:
            return structure

        # enable_grad lets the refinement run when the caller is inside torch.no_grad().
        with torch.enable_grad():
            structure = structure.clone().requires_grad_(True)
            optimizer = torch.optim.Adam([structure], lr=self.REFINEMENT_LR)

            for _ in range(self.REFINEMENT_STEPS):
                optimizer.zero_grad()

                # Pairwise distances, [batch, seq_len, seq_len]
                dists = torch.cdist(structure, structure)

                # Contact map loss
                loss = torch.mean((dists - self.CONTACT_DISTANCE).abs() * weights)

                # Chain connectivity loss: the first super-diagonal holds the
                # distances between consecutive residues (the main diagonal is
                # every residue's distance to itself, i.e. always zero).
                if seq_len > 1:
                    neighbour_dists = dists.diagonal(offset=1, dim1=1, dim2=2)
                    loss = loss + torch.mean(
                        (neighbour_dists - self.CA_CA_DISTANCE).abs()
                    )

                loss.backward()
                optimizer.step()

        return structure.detach()


def create_structure_predictor(config: Dict) -> StructurePredictor:
Expand Down

0 comments on commit c74f720

Please sign in to comment.