Skip to content

Commit

Permalink
Simplify pattern recognition and fix dimension handling
Browse files (browse the repository at this point in the history)
  • Loading branch information
devin-ai-integration[bot] committed Nov 14, 2024
1 parent b7902d7 commit b702d33
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions models/analysis/enhanced_sequence_analyzer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -15,7 +15,7 @@ class EnhancedSequenceAnalyzer(nn.Module):
def __init__(self, config: Dict):
super().__init__()
self.config = config
self.hidden_size = config.get('hidden_size', 768)
self.hidden_size = config.get('hidden_size', 320) # Changed to match ESM2 dimensions

# Initialize protein language model
self.tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
Expand All @@ -26,49 +26,49 @@ def __init__(self, config: Dict):
nn.Linear(self.hidden_size, self.hidden_size),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(self.hidden_size, self.hidden_size // 2)
nn.Linear(self.hidden_size, self.hidden_size // 2),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(self.hidden_size // 2, self.hidden_size // 4)
)

# Pattern recognition module
self.pattern_recognizer = nn.Sequential(
nn.Conv1d(self.hidden_size // 2, self.hidden_size // 4, kernel_size=3, padding=1),
self.pattern_recognition = nn.Sequential(
nn.Linear(self.hidden_size // 4, self.hidden_size // 8),
nn.ReLU(),
nn.MaxPool1d(2),
nn.Conv1d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=3, padding=1)
nn.Linear(self.hidden_size // 8, self.hidden_size // 16)
)

# Conservation analysis module
self.conservation_analyzer = ConservationAnalyzer()

# Motif identification module
self.motif_identifier = MotifIdentifier(self.hidden_size // 8)
self.motif_identifier = MotifIdentifier(self.hidden_size // 16)

def forward(self, sequences: List[str]) -> Dict[str, torch.Tensor]:
# Tokenize sequences
encoded = self.tokenizer(sequences, return_tensors="pt", padding=True)

# Get protein embeddings
with torch.no_grad():
protein_features = self.protein_model(**encoded).last_hidden_state
protein_features = self.protein_model(**encoded).last_hidden_state

# Extract sequence features
# Extract features
features = self.feature_extractor(protein_features)

# Analyze patterns
patterns = self.pattern_recognizer(features.transpose(1, 2)).transpose(1, 2)
# Pattern recognition
pattern_features = self.pattern_recognition(features)

# Analyze conservation
# Conservation analysis
conservation_scores = self.conservation_analyzer(sequences)

# Identify motifs
motifs = self.motif_identifier(patterns)
# Motif identification
motif_features = self.motif_identifier(pattern_features)

return {
'embeddings': protein_features,
'features': features,
'patterns': patterns,
'embeddings': features,
'patterns': pattern_features,
'conservation': conservation_scores,
'motifs': motifs
'motifs': motif_features
}

def analyze_sequence(self, sequence: str) -> Dict[str, torch.Tensor]:
Expand Down

0 comments on commit b702d33

Please sign in to comment.