diff --git a/python/varspark/core.py b/python/varspark/core.py index 6c353282..5cf22d59 100644 --- a/python/varspark/core.py +++ b/python/varspark/core.py @@ -56,15 +56,26 @@ def __init__(self, ss, silent=False): " /_/ \n" ) - @params(self=object, vcf_file_path=str, min_partitions=int) - def import_vcf(self, vcf_file_path, min_partitions=0): - """Import features from a VCF file.""" + @params(self=object, vcf_file_path=str, imputation_strategy=Nullable(str)) + def import_vcf(self, vcf_file_path, imputation_strategy="none"): + """Import features from a VCF file. + + :param vcf_file_path String: The file path for the vcf file to import + :param imputation_strategy String: + The imputation strategy to use. Options for imputation include: + + - none: No imputation will be performed. Missing values will be replaced with -1 (not recommended unless there are no missing values) + - mode: Missing values will be replaced with the most commonly occurring value among that feature. Recommended option + - zeros: Missing values will be replaced with zeros. Faster than mode imputation + """ + if imputation_strategy == "none": + print("WARNING: Imputation strategy is set to none - please ensure that there are no missing values in the data.") return FeatureSource( self._jvm, self._vs_api, self._jsql, self.sql, - self._jvsc.importVCF(vcf_file_path, min_partitions), + self._jvsc.importVCF(vcf_file_path, imputation_strategy), ) @params( @@ -76,7 +87,7 @@ def import_vcf(self, vcf_file_path, min_partitions=0): def import_covariates(self, cov_file_path, cov_types=None, transposed=False): """Import covariates from a CSV file. - :param cov_file_path: The file path for covariate csv file + :param cov_file_path String: The file path for covariate csv file :param cov_types Dict[String]: A dictionary specifying types for each covariate, where the key is the variable name and the value is the type.
The value can be one of the following: diff --git a/src/main/scala/au/csiro/variantspark/api/VSContext.scala b/src/main/scala/au/csiro/variantspark/api/VSContext.scala index 2672e583..d5203ecd 100644 --- a/src/main/scala/au/csiro/variantspark/api/VSContext.scala +++ b/src/main/scala/au/csiro/variantspark/api/VSContext.scala @@ -38,11 +38,11 @@ class VSContext(val spark: SparkSession) extends SqlContextHolder { * @param inputFile path to file or directory with VCF files to load * @return FeatureSource loaded from the VCF file */ - def importVCF(inputFile: String, sparkPar: Int = 0): FeatureSource = { + def importVCF(inputFile: String, imputationStrategy: String = "none"): FeatureSource = { val vcfSource = VCFSource(sc, inputFile) // VCFSource(sc.textFile(inputFile, if (sparkPar > 0) sparkPar else sc.defaultParallelism)) - VCFFeatureSource(vcfSource) + VCFFeatureSource(vcfSource, imputationStrategy) } /** Import features from a CSV file diff --git a/src/main/scala/au/csiro/variantspark/cli/CochranArmanCmd.scala b/src/main/scala/au/csiro/variantspark/cli/CochranArmanCmd.scala index fd067adb..4199ef05 100644 --- a/src/main/scala/au/csiro/variantspark/cli/CochranArmanCmd.scala +++ b/src/main/scala/au/csiro/variantspark/cli/CochranArmanCmd.scala @@ -89,7 +89,7 @@ class CochranArmanCmd extends ArgsApp with SparkApp with Echoable with Logging w VCFSource(sc.textFile(inputFile, if (sparkPar > 0) sparkPar else sc.defaultParallelism)) verbose(s"VCF Version: ${vcfSource.version}") verbose(s"VCF Header: ${vcfSource.header}") - VCFFeatureSource(vcfSource) + VCFFeatureSource(vcfSource, imputationStrategy = "none") } def loadCSV(): CsvFeatureSource = { diff --git a/src/main/scala/au/csiro/variantspark/cli/FilterCmd.scala b/src/main/scala/au/csiro/variantspark/cli/FilterCmd.scala index 2a511a39..c714de9b 100644 --- a/src/main/scala/au/csiro/variantspark/cli/FilterCmd.scala +++ b/src/main/scala/au/csiro/variantspark/cli/FilterCmd.scala @@ -30,7 +30,7 @@ class FilterCmd extends 
ArgsApp with TestArgs with SparkApp { logDebug(s"Running with filesystem: ${fs}, home: ${fs.getHomeDirectory}") val vcfSource = VCFSource(sc.textFile(inputFile)) - val source = VCFFeatureSource(vcfSource) + val source = VCFFeatureSource(vcfSource, imputationStrategy = "none") val features = source.features.zipWithIndex().cache() val featureCount = features.count() println(s"No features: ${featureCount}") diff --git a/src/main/scala/au/csiro/variantspark/cli/VcfToLabels.scala b/src/main/scala/au/csiro/variantspark/cli/VcfToLabels.scala index 20266407..e575eb5d 100644 --- a/src/main/scala/au/csiro/variantspark/cli/VcfToLabels.scala +++ b/src/main/scala/au/csiro/variantspark/cli/VcfToLabels.scala @@ -27,7 +27,7 @@ class VcfToLabels extends ArgsApp with SparkApp { val version = vcfSource.version println(header) println(version) - val source = VCFFeatureSource(vcfSource) + val source = VCFFeatureSource(vcfSource, imputationStrategy = "none") val columns = source.features.take(limit) CSVUtils.withFile(new File(outputFile)) { writer => writer.writeRow("" :: columns.map(_.label).toList) diff --git a/src/main/scala/au/csiro/variantspark/cli/args/FeatureSourceArgs.scala b/src/main/scala/au/csiro/variantspark/cli/args/FeatureSourceArgs.scala index 40d70924..85f27d8b 100644 --- a/src/main/scala/au/csiro/variantspark/cli/args/FeatureSourceArgs.scala +++ b/src/main/scala/au/csiro/variantspark/cli/args/FeatureSourceArgs.scala @@ -26,8 +26,8 @@ object VCFFeatureSourceFactory { val DEF_SEPARATOR: String = "_" } -case class VCFFeatureSourceFactory(inputFile: String, isBiallelic: Option[Boolean], - separator: Option[String]) +case class VCFFeatureSourceFactory(inputFile: String, imputationStrategy: Option[String], + isBiallelic: Option[Boolean], separator: Option[String]) extends FeatureSourceFactory with Echoable { def createSource(sparkArgs: SparkArgs): FeatureSource = { echo(s"Loading header from VCF file: ${inputFile}") @@ -36,8 +36,8 @@ case class 
VCFFeatureSourceFactory(inputFile: String, isBiallelic: Option[Boolea verbose(s"VCF Header: ${vcfSource.header}") import VCFFeatureSourceFactory._ - VCFFeatureSource(vcfSource, isBiallelic.getOrElse(DEF_IS_BIALLELIC), - separator.getOrElse(DEF_SEPARATOR)) + VCFFeatureSource(vcfSource, imputationStrategy.getOrElse("none"), + isBiallelic.getOrElse(DEF_IS_BIALLELIC), separator.getOrElse(DEF_SEPARATOR)) } } diff --git a/src/main/scala/au/csiro/variantspark/input/VCFFeatureSource.scala b/src/main/scala/au/csiro/variantspark/input/VCFFeatureSource.scala index da647071..c556fedf 100644 --- a/src/main/scala/au/csiro/variantspark/input/VCFFeatureSource.scala +++ b/src/main/scala/au/csiro/variantspark/input/VCFFeatureSource.scala @@ -10,6 +10,8 @@ import au.csiro.variantspark.data.StdFeature trait VariantToFeatureConverter { def convert(vc: VariantContext): Feature + def convertModeImputed(vc: VariantContext): Feature + def convertZeroImputed(vc: VariantContext): Feature } case class DefVariantToFeatureConverter(biallelic: Boolean = false, separator: String = "_") @@ -20,6 +22,18 @@ case class DefVariantToFeatureConverter(biallelic: Boolean = false, separator: S StdFeature.from(convertLabel(vc), BoundedOrdinalVariable(3), gts) } + def convertModeImputed(vc: VariantContext): Feature = { + val gts = vc.getGenotypes.iterator().asScala.map(convertGenotype).toArray + val modeImputedGts = ModeImputationStrategy(noLevels = 3).impute(gts) + StdFeature.from(convertLabel(vc), BoundedOrdinalVariable(3), modeImputedGts) + } + + def convertZeroImputed(vc: VariantContext): Feature = { + val gts = vc.getGenotypes.iterator().asScala.map(convertGenotype).toArray + val zeroImputedGts = ZeroImputationStrategy.impute(gts) + StdFeature.from(convertLabel(vc), BoundedOrdinalVariable(3), zeroImputedGts) + } + def convertLabel(vc: VariantContext): String = { if (biallelic && !vc.isBiallelic) { @@ -44,23 +58,34 @@ case class DefVariantToFeatureConverter(biallelic: Boolean = false, separator: S } def 
convertGenotype(gt: Genotype): Byte = { - if (!gt.isCalled || gt.isHomRef) 0 else if (gt.isHomVar || gt.isHetNonRef) 2 else 1 + if (!gt.isCalled) Missing.BYTE_NA_VALUE + else if (gt.isHomRef) 0 + else if (gt.isHomVar || gt.isHetNonRef) 2 + else 1 } } -class VCFFeatureSource(vcfSource: VCFSource, converter: VariantToFeatureConverter) +class VCFFeatureSource(vcfSource: VCFSource, converter: VariantToFeatureConverter, + imputationStrategy: String) extends FeatureSource { override lazy val sampleNames: List[String] = vcfSource.header.getGenotypeSamples.asScala.toList override def features: RDD[Feature] = { val converterRef = converter - vcfSource.genotypes().map(converterRef.convert) + imputationStrategy match { + case "none" => vcfSource.genotypes().map(converterRef.convert) + case "mode" => vcfSource.genotypes().map(converterRef.convertModeImputed) + case "zeros" => vcfSource.genotypes().map(converterRef.convertZeroImputed) + case _ => + throw new IllegalArgumentException(s"Unknown imputation strategy: $imputationStrategy") + } } } object VCFFeatureSource { - def apply(vcfSource: VCFSource, biallelic: Boolean = false, + def apply(vcfSource: VCFSource, imputationStrategy: String, biallelic: Boolean = false, separator: String = "_"): VCFFeatureSource = { - new VCFFeatureSource(vcfSource, DefVariantToFeatureConverter(biallelic, separator)) + new VCFFeatureSource(vcfSource, DefVariantToFeatureConverter(biallelic, separator), + imputationStrategy) } } diff --git a/src/test/scala/au/csiro/variantspark/input/DefVariantToFeatureConverterTest.scala b/src/test/scala/au/csiro/variantspark/input/DefVariantToFeatureConverterTest.scala index 00faf2ad..b062fb92 100644 --- a/src/test/scala/au/csiro/variantspark/input/DefVariantToFeatureConverterTest.scala +++ b/src/test/scala/au/csiro/variantspark/input/DefVariantToFeatureConverterTest.scala @@ -37,7 +37,7 @@ class DefVariantToFeatureConverterTest { @Test def testConvertsBialleicVariantCorrctly() { val converter = 
DefVariantToFeatureConverter(true, ":") - val result = converter.convert(bialellicVC) + val result = converter.convertZeroImputed(bialellicVC) assertEquals("chr1:10:T:A", result.label) assertArrayEquals(expectedEncodedGenotype, result.valueAsByteArray) } @@ -45,7 +45,7 @@ class DefVariantToFeatureConverterTest { @Test def testConvertsMultialleicVariantCorrctly() { val converter = DefVariantToFeatureConverter(false) - val result = converter.convert(multialleciVC) + val result = converter.convertZeroImputed(multialleciVC) assertEquals("chr1_10_T_A|G", result.label) assertArrayEquals(expectedEncodedGenotype, result.valueAsByteArray) } diff --git a/src/test/scala/au/csiro/variantspark/misc/CovariateReproducibilityTest.scala b/src/test/scala/au/csiro/variantspark/misc/CovariateReproducibilityTest.scala index 8a95afd0..1175dd2f 100644 --- a/src/test/scala/au/csiro/variantspark/misc/CovariateReproducibilityTest.scala +++ b/src/test/scala/au/csiro/variantspark/misc/CovariateReproducibilityTest.scala @@ -28,7 +28,7 @@ class CovariateReproducibilityTest extends SparkTest { def testCovariateReproducibleResults() { implicit val vsContext = VSContext(spark) implicit val sqlContext = spark.sqlContext - val genotypes = vsContext.importVCF("data/chr22_1000.vcf", 3) + val genotypes = vsContext.importVCF("data/chr22_1000.vcf") val optVariableTypes = new ArrayList[String](Arrays.asList("CONTINUOUS", "ORDINAL(2)", "CONTINUOUS", "CONTINUOUS", "CONTINUOUS", "CONTINUOUS")) val covariates = diff --git a/src/test/scala/au/csiro/variantspark/misc/ReproducibilityTest.scala b/src/test/scala/au/csiro/variantspark/misc/ReproducibilityTest.scala index 066f6e33..6fb581a6 100644 --- a/src/test/scala/au/csiro/variantspark/misc/ReproducibilityTest.scala +++ b/src/test/scala/au/csiro/variantspark/misc/ReproducibilityTest.scala @@ -25,7 +25,7 @@ class ReproducibilityTest extends SparkTest { def testReproducibleResults() { implicit val vsContext = VSContext(spark) implicit val sqlContext = 
spark.sqlContext - val features = vsContext.importVCF("data/chr22_1000.vcf", 3) + val features = vsContext.importVCF("data/chr22_1000.vcf") val label = vsContext.loadLabel("data/chr22-labels.csv", "22_16051249") val params = RandomForestParams(seed = 13L) val rfModel1 = RFModelTrainer.trainModel(features, label, params, 40, 20)