Skip to content

Commit

Permalink
Update sequence analyzer to maintain 768-dimensional features through…
Browse files Browse the repository at this point in the history
…out pipeline
  • Loading branch information
devin-ai-integration[bot] committed Nov 14, 2024
1 parent c74f720 commit 4e563d7
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions models/analysis/enhanced_sequence_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,35 +15,34 @@ class EnhancedSequenceAnalyzer(nn.Module):
def __init__(self, config: Dict):
super().__init__()
self.config = config
self.hidden_size = config.get('hidden_size', 320) # Changed to match ESM2 dimensions
self.hidden_size = 768 # ESM2's output dimension

# Initialize protein language model
self.tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
self.protein_model = AutoModel.from_pretrained('facebook/esm2_t6_8M_UR50D')

# Feature extraction layers
# Feature extraction layers - maintain 768 dimensions
self.feature_extractor = nn.Sequential(
nn.Linear(320, 320), # Match ESM2 output dimension
nn.Linear(768, 768),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(320, 160), # Half the dimension
nn.Linear(768, 768),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(160, 80) # Quarter the dimension
nn.Dropout(0.1)
)

# Pattern recognition module
# Pattern recognition module - maintain 768 dimensions
self.pattern_recognition = nn.Sequential(
nn.Linear(80, 40), # Input from feature extractor output
nn.Linear(768, 768),
nn.ReLU(),
nn.Linear(40, 20) # Reduced dimension for motif identification
nn.Linear(768, 768)
)

# Conservation analysis module
self.conservation_analyzer = ConservationAnalyzer()

# Motif identification module
self.motif_identifier = MotifIdentifier(20)
# Motif identification module - updated input size
self.motif_identifier = MotifIdentifier(768)

def forward(self, sequences: List[str]) -> Dict[str, torch.Tensor]:
# Tokenize sequences
Expand Down Expand Up @@ -102,10 +101,13 @@ class MotifIdentifier(nn.Module):
def __init__(self, input_size: int):
super().__init__()
self.motif_detector = nn.Sequential(
nn.Linear(input_size, input_size * 2),
nn.Linear(input_size, input_size),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(input_size, input_size),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(input_size * 2, input_size),
nn.Linear(input_size, input_size),
nn.Sigmoid()
)

Expand Down

0 comments on commit 4e563d7

Please sign in to comment.