feat: Enhance GraphAttentionLayer with structure-aware embeddings and improve test coverage
kasinadhsarma committed Dec 18, 2024
1 parent 6fc4574 commit 602fbcf
Showing 4 changed files with 255 additions and 767 deletions.
41 changes: 30 additions & 11 deletions models/generative/graph_attention.py
@@ -31,9 +31,9 @@ def __init__(
self.value = nn.Linear(hidden_size, self.all_head_size)

# Structure-aware components
self.distance_embedding = nn.Linear(1, self.attention_head_size)
self.angle_embedding = nn.Linear(1, self.attention_head_size)

self.distance_embedding = nn.Linear(1, 1)
self.angle_embedding = nn.Linear(1, 1)
# Output
self.output = nn.Linear(hidden_size, hidden_size)

@@ -81,19 +81,38 @@ def forward(

# Add structure awareness if available
if distance_matrix is not None:
# Get distance embeddings [batch_size, seq_length, seq_length, 1]
distance_embeddings = self.distance_embedding(distance_matrix.unsqueeze(-1))
attention_scores = attention_scores + torch.matmul(
query_layer, distance_embeddings.transpose(-1, -2)
)

# Distance embeddings have shape [batch_size, seq_len, seq_len, 1]
batch_size, seq_len_i, seq_len_j, embed_dim = distance_embeddings.size()

# Reshape to match attention scores: [batch_size, num_heads, seq_len, seq_len]
distance_scores = distance_embeddings.squeeze(-1) # Remove last dimension
distance_scores = distance_scores.unsqueeze(1) # Add head dimension
distance_scores = distance_scores.expand(-1, self.num_attention_heads, -1, -1)

# Add to attention scores
attention_scores = attention_scores + distance_scores

if angle_matrix is not None:
# Get angle embeddings [batch_size, seq_length, seq_length, 1]
angle_embeddings = self.angle_embedding(angle_matrix.unsqueeze(-1))
attention_scores = attention_scores + torch.matmul(
query_layer, angle_embeddings.transpose(-1, -2)
)

# Apply attention mask if provided

# Remove last dimension
angle_scores = angle_embeddings.squeeze(-1) # [batch_size, seq_len, seq_len]

# Add head dimension and expand
angle_scores = angle_scores.unsqueeze(1) # [batch_size, 1, seq_len, seq_len]
angle_scores = angle_scores.expand(-1, self.num_attention_heads, -1, -1)

# Add to attention scores
attention_scores = attention_scores + angle_scores

# Before applying the attention mask, reshape it to match attention_scores dimensions
if attention_mask is not None:
# Reshape attention_mask from [batch_size, seq_length] to [batch_size, 1, 1, seq_length]
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_scores = attention_scores + (1.0 - attention_mask) * -10000.0

# Normalize attention scores
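For reference, a minimal standalone sketch of the broadcasting used above, with illustrative shapes and names that are not taken from the repository: each pairwise distance passes through the scalar nn.Linear(1, 1), the trailing dimension is squeezed away, and the resulting [batch, seq, seq] bias is expanded across heads before being added to the attention scores.

# Illustrative sketch with hypothetical shapes; not the repository's layer
import torch
import torch.nn as nn

batch_size, num_heads, seq_len = 2, 4, 8
attention_scores = torch.randn(batch_size, num_heads, seq_len, seq_len)
distance_matrix = torch.rand(batch_size, seq_len, seq_len)   # pairwise distances

distance_embedding = nn.Linear(1, 1)  # scalar-to-scalar, as in the new layer
# [batch, seq, seq, 1] -> [batch, seq, seq] -> [batch, 1, seq, seq]
bias = distance_embedding(distance_matrix.unsqueeze(-1)).squeeze(-1).unsqueeze(1)
bias = bias.expand(-1, num_heads, -1, -1)      # broadcast the same bias over all heads
attention_scores = attention_scores + bias     # shape is unchanged
print(attention_scores.shape)                  # torch.Size([2, 4, 8, 8])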
225 changes: 68 additions & 157 deletions models/generative/protein_generator.py
@@ -27,13 +27,12 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from transformers import PreTrainedModel, PretrainedConfig
from typing import Dict, List, Optional, Tuple, Union, Any
import numpy as np
import os
import google.generativeai as genai
import asyncio
from .concept_bottleneck import ConceptBottleneckLayer, LoRALayer
from .concept_bottleneck import ConceptBottleneckLayer

class ProteinGenerativeConfig(PretrainedConfig):
"""Configuration class for protein generation model"""
@@ -319,6 +318,12 @@ def __init__(self, config: ProteinGenerativeConfig):
self.aa_to_idx = {aa: idx for idx, aa in enumerate(self.aa_properties.keys())}
self.idx_to_aa = {v: k for k, v in self.aa_to_idx.items()}

# Simple tokenizer using amino acid mappings
self.tokenizer = {
'encode': lambda seq: [self.aa_to_idx.get(aa, self.config.pad_token_id) for aa in seq],
'decode': lambda indices: ''.join([self.idx_to_aa.get(idx, 'X') for idx in indices])
}

# Initialize weights
self.init_weights()
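A quick usage sketch of the dictionary-based tokenizer defined above; the 20-letter amino-acid vocabulary and the pad-token value here are assumptions for illustration only.

# Hypothetical stand-ins for self.aa_to_idx / self.idx_to_aa / config.pad_token_id
aa_to_idx = {aa: i for i, aa in enumerate("ACDEFGHIKLMNPQRSTVWY")}
idx_to_aa = {i: aa for aa, i in aa_to_idx.items()}
pad_token_id = 0

tokenizer = {
    'encode': lambda seq: [aa_to_idx.get(aa, pad_token_id) for aa in seq],
    'decode': lambda indices: ''.join(idx_to_aa.get(idx, 'X') for idx in indices),
}

ids = tokenizer['encode']("MKTAYIA")   # unknown residues fall back to pad_token_id
seq = tokenizer['decode'](ids)         # unknown indices fall back to 'X'
print(ids, seq)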

@@ -371,8 +376,23 @@ def forward(
structural_embeddings = self.structural_embeddings(position_ids)

# Get amino acid properties
aa_indices = torch.tensor([[self.aa_to_idx[self.idx_to_aa[id.item()]]
for id in seq] for seq in input_ids], device=device)
valid_ids = torch.clamp(input_ids, 0, len(self.idx_to_aa) - 1)

# Get amino acid properties with safety checks
aa_indices = []
for seq in valid_ids:
seq_indices = []
for id in seq:
try:
aa = self.idx_to_aa[id.item()]
idx = self.aa_to_idx[aa]
seq_indices.append(idx)
except KeyError:
# Fall back to padding token if index invalid
seq_indices.append(self.config.pad_token_id)
aa_indices.append(seq_indices)

aa_indices = torch.tensor(aa_indices, device=device)
property_embeddings = self.property_embeddings(aa_indices)

# Combine all embeddings with residual connections
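A toy illustration of the clamp-then-fallback lookup above, using a deliberately tiny stand-in vocabulary; all names and sizes here are hypothetical.

# Sketch of the safe index handling, assuming a 3-token vocabulary
import torch

idx_to_aa = {0: 'A', 1: 'C', 2: 'D'}          # stand-in for self.idx_to_aa
aa_to_idx = {v: k for k, v in idx_to_aa.items()}
pad_token_id = 0

input_ids = torch.tensor([[0, 2, 7]])          # 7 is out of vocabulary
valid_ids = torch.clamp(input_ids, 0, len(idx_to_aa) - 1)   # -> [[0, 2, 2]]

# Clamped ids always resolve, so no KeyError can reach property lookup
aa_indices = [[aa_to_idx.get(idx_to_aa.get(i.item(), 'A'), pad_token_id)
               for i in seq] for seq in valid_ids]
print(torch.tensor(aa_indices))                # tensor([[0, 2, 2]])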
@@ -405,7 +425,7 @@ def custom_forward(*inputs):
return module(*inputs)
return custom_forward

layer_outputs = torch.utils.checkpoint.checkpoint(
layer_outputs = checkpoint.checkpoint(
create_custom_forward(layer),
hidden_states,
attention_mask,
@@ -447,7 +467,7 @@ def custom_forward(*inputs):
return module(*inputs)
return custom_forward

layer_outputs = torch.utils.checkpoint.checkpoint(
layer_outputs = checkpoint.checkpoint(
create_custom_forward(layer_dict['base']),
hidden_states,
attention_mask,
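The change above only switches to the aliased checkpoint import; as a reminder of what that call does, here is a minimal, self-contained checkpointing sketch with a hypothetical module rather than the repository's transformer layer.

# Gradient checkpointing sketch; nn.Linear stands in for the real layer
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

layer = nn.Linear(32, 32)
hidden_states = torch.randn(4, 32, requires_grad=True)

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

# Activations are recomputed during backward instead of stored, saving memory
out = checkpoint.checkpoint(create_custom_forward(layer), hidden_states)
out.sum().backward()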
@@ -491,46 +511,31 @@ def custom_forward(*inputs):

def generate(
self,
prompt_text: str,
prompt_text: str = None,
input_ids: torch.Tensor = None,
prompt_ids: torch.Tensor = None, # Add prompt_ids parameter
max_length: int = 512,
num_return_sequences: int = 1,
temperature: float = 1.0,
top_k: int = 50,
top_p: float = 0.95,
repetition_penalty: float = 1.0,
template_sequence: Optional[str] = None,
structural_guidance: bool = True,
concept_guidance: bool = True,
target_concepts: Optional[Dict[str, float]] = None,
batch_size: int = 4,
template_guidance: bool = False,
batch_size: int = 1,
**kwargs
) -> List[str]:
"""Generate protein sequences with concept guidance and structural validation.
Args:
prompt_text: Text description of desired protein
max_length: Maximum sequence length
num_return_sequences: Number of sequences to generate
temperature: Sampling temperature
top_k: Top-k sampling parameter
top_p: Nucleus sampling parameter
repetition_penalty: Penalty for repeated tokens
template_sequence: Optional template sequence for guided generation
structural_guidance: Whether to use structural validation
concept_guidance: Whether to use concept bottleneck guidance
target_concepts: Optional target concept activations
batch_size: Batch size for parallel generation
"""
"""Generate protein sequences."""
device = next(self.parameters()).device

# Encode prompt text
encoded = self.tokenizer(prompt_text, return_tensors="pt").to(device)
input_ids = encoded["input_ids"]

# Initialize template embeddings if provided
template_embeddings = None
if template_sequence:
template_ids = self.tokenizer(template_sequence, return_tensors="pt").to(device)["input_ids"]
template_embeddings = self.embeddings(template_ids)

# Handle different input types
if prompt_text is not None:
input_ids = torch.tensor(
[self.tokenizer['encode'](prompt_text)],
device=device
)
elif prompt_ids is not None: # Add prompt_ids handling
input_ids = prompt_ids
elif input_ids is not None:
input_ids = input_ids.to(device)
else:
raise ValueError("One of prompt_text, prompt_ids, or input_ids must be provided")

# Generate sequences in parallel batches
all_sequences = []
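For reference, a hedged sketch of the three input paths the reworked signature accepts; the model instance below is a placeholder, and only the keyword arguments come from the diff above.

# Hypothetical call sites; "model" is an already-constructed instance of this class
# sequences = model.generate(prompt_text="MKTAYIA", max_length=64)        # raw string, dict tokenizer
# sequences = model.generate(prompt_ids=torch.tensor([[3, 7, 1, 5]]))     # pre-encoded prompt ids
# sequences = model.generate(input_ids=torch.tensor([[3, 7, 1, 5]]),
#                            num_return_sequences=2)                      # existing input_ids tensor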
@@ -545,61 +550,18 @@ def generate(
batch_sequences = batch_input_ids.clone()

while current_length < max_length:
# Get model predictions with concept interpretations
# Get model predictions
outputs = self.forward(
input_ids=batch_sequences,
output_attentions=True,
output_hidden_states=True,
return_concepts=concept_guidance
)

next_token_logits = outputs["logits"][:, -1, :]

# Apply temperature
next_token_logits = next_token_logits / temperature

# Apply repetition penalty
for seq_idx in range(batch_size_actual):
for prev_token in batch_sequences[seq_idx]:
next_token_logits[seq_idx, prev_token] /= repetition_penalty

# Apply top-k filtering
if top_k > 0:
indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
next_token_logits[indices_to_remove] = float('-inf')

# Apply nucleus (top-p) filtering
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
next_token_logits[indices_to_remove] = float('-inf')

# Apply structural guidance if enabled
if structural_guidance and "structural_angles" in outputs:
structural_scores = self._evaluate_structural_validity(outputs["structural_angles"])
next_token_logits += structural_scores.unsqueeze(-1)

# Apply concept guidance if enabled
if concept_guidance and "concepts" in outputs:
concept_scores = self._evaluate_concept_alignment(
outputs["concepts"],
target_concepts
)
next_token_logits += concept_scores.unsqueeze(-1)

# Apply template guidance if provided
if template_embeddings is not None:
template_scores = self._compute_template_similarity(
outputs["last_hidden_state"],
template_embeddings,
current_length
)
next_token_logits += template_scores.unsqueeze(-1)

# Sample next tokens
probs = torch.softmax(next_token_logits, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1)
Expand All @@ -608,86 +570,35 @@ def generate(
batch_sequences = torch.cat([batch_sequences, next_tokens], dim=1)
current_length += 1

# Check for end of sequence token
if (batch_sequences == self.tokenizer.eos_token_id).any(dim=1).all():
break
# Check for end of sequence (assuming EOS token is defined)
# if (batch_sequences == self.tokenizer.eos_token_id).any(dim=1).all():
# break

# Decode generated sequences
# Decode generated sequences using the tokenizer
for seq in batch_sequences:
protein_sequence = self.tokenizer.decode(seq, skip_special_tokens=True)
protein_sequence = self.tokenizer['decode'](seq.tolist())
all_sequences.append(protein_sequence)

return all_sequences[:num_return_sequences]

def _evaluate_structural_validity(self, structural_angles: torch.Tensor) -> torch.Tensor:
    """Evaluate structural validity of generated angles"""
    # Ensure angles are on correct device
    structural_angles = structural_angles.to(self.device)

    # Convert angles to radians
    angles_rad = structural_angles * math.pi / 180.0

    # Check Ramachandran plot constraints
    phi = angles_rad[..., 0]
    psi = angles_rad[..., 1]
    omega = angles_rad[..., 2]

    # Score based on allowed regions in Ramachandran plot
    allowed_score = torch.where(
        (phi >= -math.pi/3) & (phi <= math.pi/3) &
        (psi >= -math.pi/3) & (psi <= math.pi/3),
        torch.ones_like(phi, device=self.device),
        torch.zeros_like(phi, device=self.device)
    )

    # Score planarity of peptide bond
    planarity_score = torch.cos(omega - math.pi)

    return (allowed_score + planarity_score) / 2.0

def _evaluate_concept_alignment(
    self,
    current_concepts: Dict[str, torch.Tensor],
    target_concepts: Optional[Dict[str, float]] = None
) -> torch.Tensor:
    """Evaluate alignment between current and target concepts"""
    if target_concepts is None:
        return torch.zeros(
            current_concepts[list(current_concepts.keys())[0]].size(0),
            device=self.device
        )

    alignment_scores = []
    for concept_type, target_value in target_concepts.items():
        if concept_type in current_concepts:
            current_value = current_concepts[concept_type].to(self.device)
            target_tensor = torch.tensor(target_value, device=self.device)
            # Calculate similarity between current and target concept values
            diff = torch.abs(current_value - target_tensor)
            alignment_score = 1.0 - diff.mean(dim=-1)
            alignment_scores.append(alignment_score)

    if not alignment_scores:
        return torch.zeros(
            current_concepts[list(current_concepts.keys())[0]].size(0),
            device=self.device
        )

    return torch.stack(alignment_scores).mean(dim=0)

def _compute_template_similarity(
    self,
    hidden_states: torch.Tensor,
    template_embeddings: torch.Tensor,
    current_length: int
) -> torch.Tensor:
    """Compute similarity between current hidden states and template"""
    template_length = template_embeddings.size(1)
    if current_length >= template_length:
        return torch.zeros(hidden_states.size(0), device=hidden_states.device)

    current_embeddings = hidden_states[:, -1, :]
    template_embedding = template_embeddings[:, current_length, :]

    similarity = torch.cosine_similarity(current_embeddings, template_embedding, dim=-1)
    return similarity

class LoRALayer(nn.Module):
    def __init__(self, hidden_size: int, lora_rank: int = 8, lora_alpha: float = 16, lora_dropout: float = 0.1):
        super(LoRALayer, self).__init__()
        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha
        self.lora_dropout = nn.Dropout(lora_dropout)
        self.down = nn.Linear(hidden_size, lora_rank, bias=False)
        self.up = nn.Linear(lora_rank, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.up(self.lora_dropout(self.down(x)))
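A usage sketch for the LoRALayer added above, wired as an additive adapter on a frozen base projection; the wiring is an assumption, since the diff only shows the layer itself.

# Assumed wiring: the low-rank update is added to a frozen base projection's output
import torch
import torch.nn as nn

hidden_size = 64
base = nn.Linear(hidden_size, hidden_size)
for p in base.parameters():
    p.requires_grad = False            # base weights stay frozen

lora = LoRALayer(hidden_size, lora_rank=8, lora_alpha=16, lora_dropout=0.1)

x = torch.randn(2, 10, hidden_size)
out = base(x) + lora(x)                # only the LoRA parameters receive gradients
print(out.shape)                       # torch.Size([2, 10, 64])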