feat: Implement advanced sampling techniques
- Add confidence-guided sampler with dynamic noise scheduling
- Add energy-based sampler with structure validation
- Add attention-based sampler with structure-aware attention
- Add graph-based sampler with message passing
- Add comprehensive test suite for all samplers

This implementation integrates cutting-edge sampling techniques for improved protein generation, including:

- Dynamic confidence estimation
- Energy-based refinement
- Structure-aware attention routing
- Graph-based message passing
- Local structure preservation
1 parent 62b3e7b · commit 67d9252 · 9 changed files with 1,223 additions and 0 deletions.
@@ -0,0 +1,8 @@
"""
Sampling module for ProteinFlex.
Implements advanced protein generation sampling techniques.
"""

from .confidence_guided_sampler import ConfidenceGuidedSampler

__all__ = ['ConfidenceGuidedSampler']
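For orientation, a minimal generation sketch using the exported sampler follows. The import path is a placeholder assumption (the diff does not show where the sampling package lives), and the sizes simply follow the defaults defined in ConfidenceGuidedSampler below.

import torch
from models.sampling import ConfidenceGuidedSampler  # placeholder path; not shown in this diff

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sampler = ConfidenceGuidedSampler(feature_dim=768, hidden_dim=512, num_steps=1000).to(device)

# Generate features for 2 proteins of length 128; lower temperature injects less noise
with torch.no_grad():
    features = sampler.sample(batch_size=2, seq_len=128, device=device, temperature=0.9)
print(features.shape)  # torch.Size([2, 128, 768])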
@@ -0,0 +1,197 @@
"""
Attention-Based Sampling implementation for ProteinFlex.
Implements structure-aware attention routing for protein generation.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple, Optional

class StructureAwareAttention(nn.Module):
    def __init__(
        self,
        feature_dim: int,
        num_heads: int = 8,
        dropout: float = 0.1
    ):
        """Initialize structure-aware attention."""
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = feature_dim // num_heads
        assert self.head_dim * num_heads == feature_dim, "feature_dim must be divisible by num_heads"

        self.qkv = nn.Linear(feature_dim, 3 * feature_dim)
        self.structure_proj = nn.Linear(feature_dim, feature_dim)
        self.output_proj = nn.Linear(feature_dim, feature_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        structure_bias: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Forward pass with optional structure bias."""
        B, L, D = x.shape
        qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Compute scaled dot-product attention scores [B, num_heads, L, L]
        attn = (q @ k.transpose(-2, -1)) * self.head_dim ** -0.5

        # Add structure bias if provided: project the per-residue structure
        # features and form a pairwise [B, L, L] bias that is broadcast
        # across all attention heads
        if structure_bias is not None:
            structure_weights = self.structure_proj(structure_bias)  # [B, L, D]
            pair_bias = torch.einsum('bid,bjd->bij', structure_weights, structure_weights)
            attn = attn + pair_bias.unsqueeze(1) * self.head_dim ** -0.5

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        # Apply attention to values
        x = (attn @ v).transpose(1, 2).reshape(B, L, D)
        x = self.output_proj(x)

        return x

class AttentionBasedSampler(nn.Module):
    def __init__(
        self,
        feature_dim: int = 768,
        hidden_dim: int = 512,
        num_layers: int = 6,
        num_heads: int = 8,
        dropout: float = 0.1
    ):
        """
        Initialize Attention-Based Sampler.
        Args:
            feature_dim: Dimension of protein features
            hidden_dim: Hidden dimension for feed-forward
            num_layers: Number of transformer layers
            num_heads: Number of attention heads
            dropout: Dropout rate
        """
        super().__init__()
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim

        # Structure encoder
        self.structure_encoder = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, feature_dim)
        )

        # Transformer layers
        self.layers = nn.ModuleList([
            nn.ModuleDict({
                'attention': StructureAwareAttention(feature_dim, num_heads, dropout),
                'norm1': nn.LayerNorm(feature_dim),
                'ff': nn.Sequential(
                    nn.Linear(feature_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim, feature_dim)
                ),
                'norm2': nn.LayerNorm(feature_dim)
            }) for _ in range(num_layers)
        ])

        # Output projection
        self.output_proj = nn.Linear(feature_dim, feature_dim)

    def forward(
        self,
        x: torch.Tensor,
        structure_info: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward pass for training.
        Args:
            x: Input protein features [batch_size, seq_len, feature_dim]
            structure_info: Optional structure information
        Returns:
            Processed features
        """
        # Process structure information if provided
        structure_bias = None
        if structure_info is not None:
            structure_bias = self.structure_encoder(structure_info)

        # Apply transformer layers
        for layer in self.layers:
            # Attention with structure bias
            attn_out = layer['attention'](
                layer['norm1'](x),
                structure_bias
            )
            x = x + attn_out

            # Feed-forward
            ff_out = layer['ff'](layer['norm2'](x))
            x = x + ff_out

        return self.output_proj(x)

    def sample(
        self,
        batch_size: int,
        seq_len: int,
        device: torch.device,
        structure_info: Optional[torch.Tensor] = None,
        temperature: float = 1.0
    ) -> torch.Tensor:
        """
        Generate protein features using attention-based sampling.
        Args:
            batch_size: Number of samples to generate
            seq_len: Sequence length
            device: Device to generate on
            structure_info: Optional structure information
            temperature: Sampling temperature
        Returns:
            Generated protein features
        """
        # Initialize from random
        x = torch.randn(batch_size, seq_len, self.feature_dim, device=device)

        # Apply temperature scaling
        x = x * temperature

        # Generate features with structure guidance
        return self.forward(x, structure_info)

    def compute_loss(
        self,
        pred_features: torch.Tensor,
        target_features: torch.Tensor,
        structure_info: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute training loss.
        Args:
            pred_features: Predicted protein features
            target_features: Target protein features
            structure_info: Optional structure information
        Returns:
            Loss value
        """
        # Feature reconstruction loss
        recon_loss = F.mse_loss(pred_features, target_features)

        # Structure-aware loss if structure info provided
        if structure_info is not None:
            structure_pred = self.structure_encoder(pred_features)
            structure_loss = F.mse_loss(structure_pred, structure_info)
            return recon_loss + structure_loss

        return recon_loss
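A minimal usage sketch for AttentionBasedSampler follows. It is illustrative only: the import path and the random tensors standing in for protein features and structure information are assumptions, while the constructor and method calls mirror the class defined above.

import torch
from attention_based_sampler import AttentionBasedSampler  # hypothetical import path

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AttentionBasedSampler(feature_dim=768, hidden_dim=512, num_layers=6, num_heads=8).to(device)

# Toy tensors standing in for encoded protein features and structure information
target = torch.randn(2, 64, 768, device=device)
structure_info = torch.randn(2, 64, 768, device=device)

# One training step: reconstruct the target from a perturbed copy under structure guidance
noisy = target + 0.1 * torch.randn_like(target)
pred = model(noisy, structure_info)
loss = model.compute_loss(pred, target, structure_info)
loss.backward()

# Sampling new features with the same structure guidance
with torch.no_grad():
    generated = model.sample(batch_size=2, seq_len=64, device=device,
                             structure_info=structure_info, temperature=0.8)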
@@ -0,0 +1,182 @@
"""
Confidence-Guided Sampling implementation for ProteinFlex.
Based on recent advances in protein structure generation.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, Tuple, Optional

class ConfidenceGuidedSampler(nn.Module):
    def __init__(
        self,
        feature_dim: int = 768,
        hidden_dim: int = 512,
        num_steps: int = 1000,
        min_beta: float = 1e-4,
        max_beta: float = 0.02
    ):
        """
        Initialize the Confidence-Guided Sampler.
        Args:
            feature_dim: Dimension of protein features
            hidden_dim: Hidden dimension for confidence network
            num_steps: Number of diffusion steps
            min_beta: Minimum noise schedule value
            max_beta: Maximum noise schedule value
        """
        super().__init__()
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim

        # Confidence estimation network
        self.confidence_net = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

        # Noise prediction network
        self.noise_pred_net = nn.Sequential(
            nn.Linear(feature_dim + hidden_dim, hidden_dim),  # +hidden_dim for time embedding
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, feature_dim)
        )

        # Setup noise schedule
        self.num_steps = num_steps
        self.register_buffer('betas', torch.linspace(min_beta, max_beta, num_steps))
        alphas = 1 - self.betas
        self.register_buffer('alphas_cumprod', torch.cumprod(alphas, dim=0))

    def get_time_embedding(self, t: torch.Tensor) -> torch.Tensor:
        """Generate time embeddings."""
        half_dim = self.hidden_dim // 2
        embeddings = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=t.device) * -embeddings)
        embeddings = t[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

    def forward(
        self,
        x: torch.Tensor,
        noise: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for training.
        Args:
            x: Input protein features [batch_size, seq_len, feature_dim]
            noise: Optional pre-generated noise
        Returns:
            Tuple of (noisy features, predicted noise)
        """
        batch_size = x.shape[0]

        # Generate noise if not provided
        if noise is None:
            noise = torch.randn_like(x)

        # Sample timestep
        t = torch.randint(0, self.num_steps, (batch_size,), device=x.device)

        # Get noise scaling
        a = self.alphas_cumprod[t]
        a = a.view(-1, 1, 1)

        # Add noise to input
        noisy_x = torch.sqrt(a) * x + torch.sqrt(1 - a) * noise

        # Predict noise
        time_emb = self.get_time_embedding(t)
        time_emb = time_emb.view(batch_size, 1, -1).expand(-1, x.shape[1], -1)
        pred_input = torch.cat([noisy_x, time_emb], dim=-1)
        pred_noise = self.noise_pred_net(pred_input)

        return noisy_x, pred_noise

    def sample(
        self,
        batch_size: int,
        seq_len: int,
        device: torch.device,
        temperature: float = 1.0
    ) -> torch.Tensor:
        """
        Generate protein features using confidence-guided sampling.
        Args:
            batch_size: Number of samples to generate
            seq_len: Sequence length
            device: Device to generate on
            temperature: Sampling temperature
        Returns:
            Generated protein features
        """
        # Start from random noise
        x = torch.randn(batch_size, seq_len, self.feature_dim, device=device)

        # Iterative refinement
        for t in reversed(range(self.num_steps)):
            # Get confidence score
            confidence = self.confidence_net(x)

            # Predict and remove noise
            time_emb = self.get_time_embedding(torch.tensor([t], device=device))
            time_emb = time_emb.expand(batch_size, seq_len, -1)
            pred_input = torch.cat([x, time_emb], dim=-1)
            pred_noise = self.noise_pred_net(pred_input)

            # Update features based on confidence
            alpha = self.alphas_cumprod[t]
            alpha_prev = self.alphas_cumprod[t - 1] if t > 0 else torch.tensor(1.0, device=device)
            beta = 1 - alpha / alpha_prev

            # Apply confidence-guided update
            mean = (x - beta * pred_noise) / torch.sqrt(1 - beta)
            var = beta * temperature * (1 - confidence)
            x = mean + torch.sqrt(var) * torch.randn_like(x)

        return x

    def compute_loss(
        self,
        x: torch.Tensor,
        noise: torch.Tensor,
        pred_noise: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute training loss.
        Args:
            x: Input protein features
            noise: Target noise
            pred_noise: Predicted noise
        Returns:
            Combined loss value
        """
        # MSE loss for noise prediction
        noise_loss = F.mse_loss(pred_noise, noise)

        # Confidence loss: train the confidence head to track how accurate the
        # noise prediction actually is at each sequence position
        confidence = self.confidence_net(x).squeeze(-1)  # [batch_size, seq_len]
        per_position_error = F.mse_loss(pred_noise, noise, reduction='none').mean(-1)
        confidence_target = torch.exp(-per_position_error).detach()  # no gradient into the noise predictor
        confidence_loss = F.binary_cross_entropy(confidence, confidence_target)

        return noise_loss + confidence_loss
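To show how the pieces above fit together during training, here is a sketch of a single optimization step. It assumes a hypothetical import path and uses random tensors in place of real protein features; passing the noised features to compute_loss (so the confidence head sees the same kind of input it receives in sample()) is a design assumption of this sketch, not something the diff prescribes.

import torch
from confidence_guided_sampler import ConfidenceGuidedSampler  # hypothetical import path

sampler = ConfidenceGuidedSampler(feature_dim=768, hidden_dim=512, num_steps=1000)
optimizer = torch.optim.Adam(sampler.parameters(), lr=1e-4)

# Toy batch standing in for encoded protein features
features = torch.randn(4, 64, 768)
noise = torch.randn_like(features)

# One diffusion training step: noise the features at a random timestep,
# predict that noise, then optimize the combined noise + confidence loss
noisy_x, pred_noise = sampler(features, noise)
loss = sampler.compute_loss(noisy_x, noise, pred_noise)

optimizer.zero_grad()
loss.backward()
optimizer.step()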