Model Training
Aayush Grover edited this page May 13, 2025
An example script for training a model with asap is provided in tutorials/train.py.
asap.training_datasets(signal_file, genome, train_chroms, val_chroms, generated, blacklist_file=None, unmap_file=None)

Creates training and validation datasets using genomic sequence as input and ATAC-seq signal as output.
Args:
- signal_file (str): Path to the ATAC-seq signal file.
- genome (str): Path to the genome file.
- train_chroms (List[int]): List of chromosomes for training.
- val_chroms (List[int]): List of chromosomes for validation.
- generated (str): Path to save the processed data.
- blacklist_file (List[str]): List of paths to blacklist files (including SNVs).
- unmap_file (str): Path to the unmapped regions file.
Returns:
- train_dataset (asap.dataloader.WGDataset): Training dataset.
- val_dataset (asap.dataloader.WGDataset): Validation dataset.
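The split above is chromosome-level: regions from train_chroms go to the training set and regions from val_chroms are held out for validation. A minimal plain-Python illustration of that partitioning (made-up region records, not asap internals):

```python
# Conceptual sketch of a chromosome-level train/validation split,
# like the one training_datasets performs. Regions are placeholders.

def split_by_chromosome(regions, train_chroms, val_chroms):
    """Partition genomic regions into train/val sets by chromosome number."""
    train = [r for r in regions if r["chrom"] in train_chroms]
    val = [r for r in regions if r["chrom"] in val_chroms]
    return train, val

# Placeholder regions: chromosome, start, end
regions = [
    {"chrom": 1, "start": 0, "end": 1000},
    {"chrom": 2, "start": 0, "end": 1000},
    {"chrom": 5, "start": 0, "end": 1000},
]
train, val = split_by_chromosome(regions, train_chroms=[1, 2], val_chroms=[5])
```

Holding out whole chromosomes (rather than random windows) avoids leakage between overlapping training and validation regions.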
asap.train_model(experiment_name, model, train_dataset, val_dataset, logs_dir, n_gpus=0, max_epochs=70, learning_rate=1e-3, batch_size=64, use_map=False)

Trains the selected model using the training and validation datasets.
Args:
- experiment_name (str): The name of the experiment. This will be used to save model checkpoints.
- model (str): The name of the model to train. Choose from [cnn, lstm, dcnn, convnext_cnn, convnext_lstm, convnext_dcnn, convnext_transformer].
- train_dataset (asap.dataloader.WGDataset): The training dataset.
- val_dataset (asap.dataloader.WGDataset): The validation dataset.
- logs_dir (str): The directory where the model is saved.
- n_gpus (int): The number of GPUs to use for training. Set to 0 for CPU.
- max_epochs (int): The maximum number of epochs to train the model.
- learning_rate (float): The learning rate for the optimizer.
- batch_size (int): The batch size for training.
- use_map (bool): Whether to additionally use mappability information for training.
Returns:
- None
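An end-to-end sketch of the two calls documented above (all file paths, chromosome lists, and the experiment name are placeholders, not values from this page; the snippet is guarded so it is a no-op where asap is not installed):

```python
# Illustrative sketch of the documented asap workflow.
# All paths and names below are placeholders.
try:
    import asap
except ImportError:
    asap = None  # asap not installed; skip the workflow

if asap is not None:
    # Build datasets: genomic sequence as input, ATAC-seq signal as output.
    train_dataset, val_dataset = asap.training_datasets(
        signal_file="data/signal.bw",          # placeholder path
        genome="data/genome.fa",               # placeholder path
        train_chroms=[1, 2, 3, 4],             # placeholder chromosome lists
        val_chroms=[5],
        generated="data/processed/",
        blacklist_file=["data/blacklist.bed"],
        unmap_file="data/unmap.bed",
    )
    # Train a CNN; checkpoints are saved under logs_dir.
    asap.train_model(
        experiment_name="example_run",         # placeholder name
        model="cnn",
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        logs_dir="logs/",
        n_gpus=0,                              # CPU training
        max_epochs=70,
        learning_rate=1e-3,
        batch_size=64,
        use_map=False,
    )
```

Compare with tutorials/train.py for the actual invocation used in the repository.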
Note: The Pearson's R for the trained model on the validation dataset is expected to be ~0.7.
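Pearson's R here measures the linear correlation between predicted and observed ATAC-seq signal on the validation chromosomes. A minimal pure-Python implementation of the metric (illustrative only; asap's own evaluation code may differ):

```python
import math

def pearson_r(x, y):
    """Sample Pearson correlation coefficient between two equal-length sequences."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)

# Placeholder predictions and targets, not real model output.
preds = [0.1, 0.4, 0.35, 0.8]
targets = [0.0, 0.5, 0.3, 0.9]
r = pearson_r(preds, targets)
```

A value of 1.0 means perfect linear agreement; ~0.7 indicates a strong but imperfect correlation between predicted and measured signal.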