feat: Implement advanced sampling techniques
- Add confidence-guided sampler with dynamic noise scheduling
- Add energy-based sampler with structure validation
- Add attention-based sampler with structure-aware attention
- Add graph-based sampler with message passing
- Add comprehensive test suite for all samplers

This implementation integrates cutting-edge sampling techniques
for improved protein generation, including:
- Dynamic confidence estimation
- Energy-based refinement
- Structure-aware attention routing
- Graph-based message passing
- Local structure preservation
devin-ai-integration[bot] committed Nov 14, 2024
1 parent 62b3e7b commit 67d9252
Showing 9 changed files with 1,223 additions and 0 deletions.
8 changes: 8 additions & 0 deletions models/sampling/__init__.py
@@ -0,0 +1,8 @@
"""
Sampling module for ProteinFlex.
Implements advanced protein generation sampling techniques.
"""

from .confidence_guided_sampler import ConfidenceGuidedSampler

__all__ = ['ConfidenceGuidedSampler']
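
For reference, a minimal import sketch for the new package (assuming the repository root is on PYTHONPATH; the constructor defaults mirror those defined in confidence_guided_sampler.py below):

from models.sampling import ConfidenceGuidedSampler

# Uses the default feature dimension and diffusion step count from the constructor
sampler = ConfidenceGuidedSampler(feature_dim=768, num_steps=1000)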
197 changes: 197 additions & 0 deletions models/sampling/attention_based_sampler.py
@@ -0,0 +1,197 @@
"""
Attention-Based Sampling implementation for ProteinFlex.
Implements structure-aware attention routing for protein generation.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple, Optional

class StructureAwareAttention(nn.Module):
def __init__(
self,
feature_dim: int,
num_heads: int = 8,
dropout: float = 0.1
):
"""Initialize structure-aware attention."""
super().__init__()
self.num_heads = num_heads
self.head_dim = feature_dim // num_heads
assert self.head_dim * num_heads == feature_dim, "feature_dim must be divisible by num_heads"

self.qkv = nn.Linear(feature_dim, 3 * feature_dim)
self.structure_proj = nn.Linear(feature_dim, feature_dim)
self.output_proj = nn.Linear(feature_dim, feature_dim)
self.dropout = nn.Dropout(dropout)

def forward(
self,
x: torch.Tensor,
structure_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Forward pass with optional structure bias."""
B, L, D = x.shape
qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]

# Compute attention scores
attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)

# Add structure bias if provided
if structure_bias is not None:
# Project per-residue structure features and combine them into a pairwise [B, 1, L, L] bias
structure_feats = self.structure_proj(structure_bias)
structure_weights = torch.einsum('bid,bjd->bij', structure_feats, structure_feats) * (D ** -0.5)
attn = attn + structure_weights.unsqueeze(1)

attn = F.softmax(attn, dim=-1)
attn = self.dropout(attn)

# Apply attention to values
x = (attn @ v).transpose(1, 2).reshape(B, L, D)
x = self.output_proj(x)

return x

class AttentionBasedSampler(nn.Module):
def __init__(
self,
feature_dim: int = 768,
hidden_dim: int = 512,
num_layers: int = 6,
num_heads: int = 8,
dropout: float = 0.1
):
"""
Initialize Attention-Based Sampler.
Args:
feature_dim: Dimension of protein features
hidden_dim: Hidden dimension for feed-forward
num_layers: Number of transformer layers
num_heads: Number of attention heads
dropout: Dropout rate
"""
super().__init__()
self.feature_dim = feature_dim
self.hidden_dim = hidden_dim

# Structure encoder
self.structure_encoder = nn.Sequential(
nn.Linear(feature_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, feature_dim)
)

# Transformer layers
self.layers = nn.ModuleList([
nn.ModuleDict({
'attention': StructureAwareAttention(feature_dim, num_heads, dropout),
'norm1': nn.LayerNorm(feature_dim),
'ff': nn.Sequential(
nn.Linear(feature_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, feature_dim)
),
'norm2': nn.LayerNorm(feature_dim)
}) for _ in range(num_layers)
])

# Output projection
self.output_proj = nn.Linear(feature_dim, feature_dim)

def forward(
self,
x: torch.Tensor,
structure_info: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Forward pass for training.
Args:
x: Input protein features [batch_size, seq_len, feature_dim]
structure_info: Optional structure information
Returns:
Processed features
"""
# Process structure information if provided
structure_bias = None
if structure_info is not None:
structure_bias = self.structure_encoder(structure_info)

# Apply transformer layers
for layer in self.layers:
# Attention with structure bias
attn_out = layer['attention'](
layer['norm1'](x),
structure_bias
)
x = x + attn_out

# Feed-forward
ff_out = layer['ff'](layer['norm2'](x))
x = x + ff_out

return self.output_proj(x)

def sample(
self,
batch_size: int,
seq_len: int,
device: torch.device,
structure_info: Optional[torch.Tensor] = None,
temperature: float = 1.0
) -> torch.Tensor:
"""
Generate protein features using attention-based sampling.
Args:
batch_size: Number of samples to generate
seq_len: Sequence length
device: Device to generate on
structure_info: Optional structure information
temperature: Sampling temperature
Returns:
Generated protein features
"""
# Initialize from random
x = torch.randn(batch_size, seq_len, self.feature_dim, device=device)

# Apply temperature scaling
x = x * temperature

# Generate features with structure guidance
return self.forward(x, structure_info)

def compute_loss(
self,
pred_features: torch.Tensor,
target_features: torch.Tensor,
structure_info: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Compute training loss.
Args:
pred_features: Predicted protein features
target_features: Target protein features
structure_info: Optional structure information
Returns:
Loss value
"""
# Feature reconstruction loss
recon_loss = F.mse_loss(pred_features, target_features)

# Structure-aware loss if structure info provided
if structure_info is not None:
structure_pred = self.structure_encoder(pred_features)
structure_loss = F.mse_loss(structure_pred, structure_info)
return recon_loss + structure_loss

return recon_loss
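
A minimal usage sketch for the sampler above; the batch size, sequence length, and the random per-residue structure tensor are illustrative assumptions, not part of this commit:

import torch
from models.sampling.attention_based_sampler import AttentionBasedSampler

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sampler = AttentionBasedSampler(feature_dim=768, num_heads=8).to(device)

# Optional structure guidance: one feature vector per residue, [batch, seq_len, feature_dim]
structure_info = torch.randn(2, 128, 768, device=device)

# Generate features for 2 proteins of length 128
features = sampler.sample(batch_size=2, seq_len=128, device=device, structure_info=structure_info, temperature=0.8)
print(features.shape)  # torch.Size([2, 128, 768])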
182 changes: 182 additions & 0 deletions models/sampling/confidence_guided_sampler.py
@@ -0,0 +1,182 @@
"""
Confidence-Guided Sampling implementation for ProteinFlex.
Based on recent advances in protein structure generation.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, Tuple, Optional

class ConfidenceGuidedSampler(nn.Module):
def __init__(
self,
feature_dim: int = 768,
hidden_dim: int = 512,
num_steps: int = 1000,
min_beta: float = 1e-4,
max_beta: float = 0.02
):
"""
Initialize the Confidence-Guided Sampler.
Args:
feature_dim: Dimension of protein features
hidden_dim: Hidden dimension for confidence network
num_steps: Number of diffusion steps
min_beta: Minimum noise schedule value
max_beta: Maximum noise schedule value
"""
super().__init__()
self.feature_dim = feature_dim
self.hidden_dim = hidden_dim

# Confidence estimation network
self.confidence_net = nn.Sequential(
nn.Linear(feature_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.LayerNorm(hidden_dim // 2),
nn.ReLU(),
nn.Linear(hidden_dim // 2, 1),
nn.Sigmoid()
)

# Noise prediction network
self.noise_pred_net = nn.Sequential(
nn.Linear(feature_dim + hidden_dim, hidden_dim), # +hidden_dim for time embedding
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, feature_dim)
)

# Setup noise schedule
self.num_steps = num_steps
self.register_buffer('betas', torch.linspace(min_beta, max_beta, num_steps))
alphas = 1 - self.betas
self.register_buffer('alphas_cumprod', torch.cumprod(alphas, dim=0))

def get_time_embedding(self, t: torch.Tensor) -> torch.Tensor:
"""Generate time embeddings."""
half_dim = self.hidden_dim // 2
embeddings = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
embeddings = torch.exp(torch.arange(half_dim, device=t.device) * -embeddings)
embeddings = t[:, None] * embeddings[None, :]
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
return embeddings

def forward(
self,
x: torch.Tensor,
noise: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass for training.
Args:
x: Input protein features [batch_size, seq_len, feature_dim]
noise: Optional pre-generated noise
Returns:
Tuple of (noisy features, predicted noise)
"""
batch_size = x.shape[0]

# Generate noise if not provided
if noise is None:
noise = torch.randn_like(x)

# Sample timestep
t = torch.randint(0, self.num_steps, (batch_size,), device=x.device)

# Get noise scaling
a = self.alphas_cumprod[t]
a = a.view(-1, 1, 1)

# Add noise to input
noisy_x = torch.sqrt(a) * x + torch.sqrt(1 - a) * noise

# Predict noise
time_emb = self.get_time_embedding(t)
time_emb = time_emb.view(batch_size, 1, -1).expand(-1, x.shape[1], -1)
pred_input = torch.cat([noisy_x, time_emb], dim=-1)
pred_noise = self.noise_pred_net(pred_input)

return noisy_x, pred_noise

def sample(
self,
batch_size: int,
seq_len: int,
device: torch.device,
temperature: float = 1.0
) -> torch.Tensor:
"""
Generate protein features using confidence-guided sampling.
Args:
batch_size: Number of samples to generate
seq_len: Sequence length
device: Device to generate on
temperature: Sampling temperature
Returns:
Generated protein features
"""
# Start from random noise
x = torch.randn(batch_size, seq_len, self.feature_dim, device=device)

# Iterative refinement
for t in reversed(range(self.num_steps)):
# Get confidence score
confidence = self.confidence_net(x)

# Predict and remove noise
time_emb = self.get_time_embedding(torch.tensor([t], device=device))
time_emb = time_emb.expand(batch_size, seq_len, -1)
pred_input = torch.cat([x, time_emb], dim=-1)
pred_noise = self.noise_pred_net(pred_input)

# Recover the per-step noise level from the cumulative products
alpha = self.alphas_cumprod[t]
alpha_prev = self.alphas_cumprod[t-1] if t > 0 else torch.tensor(1.0, device=device)
beta = 1 - alpha / alpha_prev

# Confidence-guided update: DDPM posterior mean, stochastic term scaled by (1 - confidence)
mean = (x - beta / torch.sqrt(1 - alpha) * pred_noise) / torch.sqrt(1 - beta)
var = beta * temperature * (1 - confidence)
step_noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
x = mean + torch.sqrt(var) * step_noise

return x

def compute_loss(
self,
x: torch.Tensor,
noise: torch.Tensor,
pred_noise: torch.Tensor
) -> torch.Tensor:
"""
Compute training loss.
Args:
x: Input protein features
noise: Target noise
pred_noise: Predicted noise
Returns:
Combined loss value
"""
# MSE loss for noise prediction
noise_loss = F.mse_loss(pred_noise, noise)

# Confidence loss to encourage accurate confidence estimation
confidence = self.confidence_net(x).squeeze(-1)
confidence_target = torch.exp(-F.mse_loss(pred_noise, noise, reduction='none').mean(dim=-1)).detach()
confidence_loss = F.binary_cross_entropy(confidence, confidence_target)

return noise_loss + confidence_loss
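
For orientation, a minimal training-step and sampling sketch for the diffusion-style sampler above; the optimizer choice, learning rate, and placeholder feature tensors are illustrative assumptions rather than part of this commit:

import torch
from models.sampling.confidence_guided_sampler import ConfidenceGuidedSampler

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sampler = ConfidenceGuidedSampler(feature_dim=768).to(device)
optimizer = torch.optim.Adam(sampler.parameters(), lr=1e-4)

# One training step on a batch of placeholder protein features [batch, seq_len, feature_dim]
features = torch.randn(4, 64, 768, device=device)
noise = torch.randn_like(features)
noisy_x, pred_noise = sampler(features, noise)
loss = sampler.compute_loss(features, noise, pred_noise)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Generation: iterative denoising over num_steps (1000 by default)
with torch.no_grad():
    generated = sampler.sample(batch_size=4, seq_len=64, device=device, temperature=1.0)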
