Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[samples] add inference sample for simple snp trainer #128

Open
wants to merge 6 commits into
base: dev-v0.1.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion samples/simple_consensus_caller/rnn_consensus_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""A sample program highlighting usage of VariantWorks SDK to write a simple SNP variant caller using a CNN."""
"""A sample program highlighting usage of VariantWorks SDK to write a simple consensus training tool."""

import argparse

Expand Down
78 changes: 78 additions & 0 deletions samples/simple_snp_trainer/cnn_snp_infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#!/usr/bin/env python
#
# Copyright 2020 NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""A sample program highlighting usage of VariantWorks SDK to run inference with a simple CNN-based SNP variant caller."""

import argparse

import nemo

from variantworks.dataloader import HDFDataLoader
from variantworks.networks import AlexNet
from variantworks.neural_types import ReadPileupNeuralType, VariantZygosityNeuralType


def create_model():
    """Build the neural network used by this sample.

    Returns:
        An ``AlexNet`` constructed with 2 input channels and 3 output logits.
    """
    return AlexNet(num_input_channels=2, num_output_logits=3)


def infer(parsed_args):
    """Run inference with a sample model.

    Creates a NeMo neural module factory, connects the test data loader to
    the model and runs inference using ``parsed_args.model_dir`` as the
    checkpoint directory.

    Args:
        parsed_args: Parsed command-line namespace providing ``test_hdf``,
            ``threads`` and ``model_dir`` (see ``build_parser``).
    """
    # Create neural factory as per NeMo requirements.
    nf = nemo.core.NeuralModuleFactory(
        placement=nemo.core.neural_factory.DeviceType.GPU,
        checkpoint_dir=parsed_args.model_dir)

    model = create_model()

    # Create test DAG.
    # NOTE(review): PR feedback from a contributor -- HDFDataLoader may not be
    # usable for this inference script. ReadPileupDataLoader would give access
    # to the Variant entries needed to serialize outputs to a VCF; it also has
    # to be created with shuffle=False so the inference order is the same as
    # the VCF entries order.
    #
    # Read every option from ``parsed_args`` (not a module-level ``args``) so
    # the function also works when imported and called directly. Shuffling is
    # disabled so outputs follow the order of the input examples.
    test_dataset = HDFDataLoader(parsed_args.test_hdf, batch_size=32,
                                 shuffle=False, num_workers=parsed_args.threads,
                                 tensor_keys=["encodings", "labels"],
                                 tensor_dims=[('B', 'C', 'H', 'W'), ('B',)],
                                 tensor_neural_types=[ReadPileupNeuralType(), VariantZygosityNeuralType()])
    # The loader also yields labels; they are not used for inference.
    encoding, _ = test_dataset()

    vz = model(encoding=encoding)

    nf.infer([vz], checkpoint_dir=parsed_args.model_dir, verbose=True)


def build_parser():
    """Create the command-line argument parser for the inference sample.

    Returns:
        An ``argparse.ArgumentParser`` accepting ``--test-hdf`` (required),
        ``-t/--threads`` (defaults to the machine's CPU count) and
        ``--model-dir`` (defaults to ``./models``).
    """
    import multiprocessing

    default_threads = multiprocessing.cpu_count()

    arg_parser = argparse.ArgumentParser(
        description="Simple model inference SNP caller based on VariantWorks.",
    )
    arg_parser.add_argument(
        "--test-hdf",
        required=True,
        help="HDF with examples for testing.",
    )
    arg_parser.add_argument(
        "-t", "--threads",
        type=int,
        required=False,
        default=default_threads,
        help="Threads to use for parallel loading.",
    )
    arg_parser.add_argument(
        "--model-dir",
        type=str,
        required=False,
        default="./models",
        help="Directory for loading saved trained model checkpoints.",
    )
    return arg_parser


if __name__ == "__main__":
    # Script entry point: parse the command line and run inference.
    args = build_parser().parse_args()
    infer(args)
2 changes: 1 addition & 1 deletion samples/simple_snp_trainer/cnn_snp_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def train(args):
def build_parser():
"""Build parser object with options for sample."""
parser = argparse.ArgumentParser(
description="Simple SNP caller based on VariantWorks.")
description="Simple model training for SNP caller based on VariantWorks.")

parser.add_argument("--train-hdf",
help="HDF with examples for training.",
Expand Down