From bde87550bcd64c64613574e87de7e6814bdb6865 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Mon, 16 Dec 2024 13:46:21 -0500 Subject: [PATCH] [query] Expose references via `ExecuteContext` --- hail/hail/src/is/hail/HailFeatureFlags.scala | 3 + hail/hail/src/is/hail/backend/Backend.scala | 8 +- .../src/is/hail/backend/ExecuteContext.scala | 9 +- .../is/hail/backend/local/LocalBackend.scala | 2 + .../backend/py4j/Py4JBackendExtensions.scala | 38 ++-- .../hail/backend/service/ServiceBackend.scala | 28 ++- .../is/hail/backend/spark/SparkBackend.scala | 3 + .../is/hail/expr/ir/EmitClassBuilder.scala | 2 +- .../hail/expr/ir/ExtractIntervalFilters.scala | 4 +- .../src/is/hail/expr/ir/MatrixValue.scala | 2 +- .../hail/src/is/hail/io/plink/LoadPlink.scala | 2 +- .../src/is/hail/io/reference/LiftOver.scala | 69 ------ .../src/is/hail/io/reference/package.scala | 164 ++++++++++++++ hail/hail/src/is/hail/io/vcf/LoadVCF.scala | 4 +- .../src/is/hail/variant/ReferenceGenome.scala | 204 ++++-------------- hail/hail/test/src/is/hail/HailSuite.scala | 4 - .../src/is/hail/annotations/UnsafeSuite.scala | 2 +- .../expr/ir/ExtractIntervalFiltersSuite.scala | 2 +- .../test/src/is/hail/expr/ir/IRSuite.scala | 6 +- .../is/hail/expr/ir/LocusFunctionsSuite.scala | 6 +- .../is/hail/expr/ir/RequirednessSuite.scala | 3 +- .../is/hail/expr/ir/StagedMinHeapSuite.scala | 97 ++++----- .../lowering/LowerDistributedSortSuite.scala | 10 +- .../is/hail/utils/FlipbookIteratorSuite.scala | 27 ++- .../is/hail/variant/LocusIntervalSuite.scala | 6 +- .../hail/variant/ReferenceGenomeSuite.scala | 16 +- 26 files changed, 343 insertions(+), 378 deletions(-) delete mode 100644 hail/hail/src/is/hail/io/reference/LiftOver.scala create mode 100644 hail/hail/src/is/hail/io/reference/package.scala diff --git a/hail/hail/src/is/hail/HailFeatureFlags.scala b/hail/hail/src/is/hail/HailFeatureFlags.scala index 48bb22bb390..49eff3139ec 100644 --- a/hail/hail/src/is/hail/HailFeatureFlags.scala +++ b/hail/hail/src/is/hail/HailFeatureFlags.scala @@ -68,6 +68,9 @@ class HailFeatureFlags private ( flags.update(flag, value) } + def +(feature: (String, String)): HailFeatureFlags = + new HailFeatureFlags(flags + (feature._1 -> feature._2)) + def get(flag: String): String = flags(flag) def lookup(flag: String): Option[String] = diff --git a/hail/hail/src/is/hail/backend/Backend.scala b/hail/hail/src/is/hail/backend/Backend.scala index c9e97038b7b..772bfb193f8 100644 --- a/hail/hail/src/is/hail/backend/Backend.scala +++ b/hail/hail/src/is/hail/backend/Backend.scala @@ -105,8 +105,6 @@ abstract class Backend extends Closeable { def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T]) : CompiledFunction[T] - def references: mutable.Map[String, ReferenceGenome] - def lowerDistributedSort( ctx: ExecuteContext, stage: TableStage, @@ -181,10 +179,8 @@ abstract class Backend extends Closeable { ): Array[Byte] = withExecuteContext { ctx => jsonToBytes { - Extraction.decompose { - ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile, - xContigs, yContigs, mtContigs, parInput).toJSON - }(defaultJSONFormats) + ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile, + xContigs, yContigs, mtContigs, parInput).toJSON } } diff --git a/hail/hail/src/is/hail/backend/ExecuteContext.scala b/hail/hail/src/is/hail/backend/ExecuteContext.scala index ce477a3232f..da0acc04459 100644 --- a/hail/hail/src/is/hail/backend/ExecuteContext.scala +++ b/hail/hail/src/is/hail/backend/ExecuteContext.scala @@ -63,6 +63,7 @@ object ExecuteContext { tmpdir: String, localTmpdir: String, backend: Backend, + references: Map[String, ReferenceGenome], fs: FS, timer: ExecutionTimer, tempFileManager: TempFileManager, @@ -79,6 +80,7 @@ object ExecuteContext { tmpdir, localTmpdir, backend, + references, fs, region, timer, @@ -107,6 +109,7 @@ class ExecuteContext( val tmpdir: String, val localTmpdir: String, val backend: Backend, + val references: Map[String, ReferenceGenome], val fs: FS, val r: Region, val timer: ExecutionTimer, @@ -128,7 +131,7 @@ class ExecuteContext( ) } - def stateManager = HailStateManager(backend.references.toMap) + val stateManager = HailStateManager(references) val tempFileManager: TempFileManager = if (_tempFileManager != null) _tempFileManager else new OwningTempFileManager(fs) @@ -154,8 +157,6 @@ class ExecuteContext( def getFlag(name: String): String = flags.get(name) - def getReference(name: String): ReferenceGenome = backend.references(name) - def shouldWriteIRFiles(): Boolean = getFlag("write_ir_files") != null def shouldNotLogIR(): Boolean = flags.get("no_ir_logging") != null @@ -174,6 +175,7 @@ class ExecuteContext( tmpdir: String = this.tmpdir, localTmpdir: String = this.localTmpdir, backend: Backend = this.backend, + references: Map[String, ReferenceGenome] = this.references, fs: FS = this.fs, r: Region = this.r, timer: ExecutionTimer = this.timer, @@ -189,6 +191,7 @@ class ExecuteContext( tmpdir, localTmpdir, backend, + references, fs, r, timer, diff --git a/hail/hail/src/is/hail/backend/local/LocalBackend.scala b/hail/hail/src/is/hail/backend/local/LocalBackend.scala index 5c9eb29a5ce..c7c79e1caf5 100644 --- a/hail/hail/src/is/hail/backend/local/LocalBackend.scala +++ b/hail/hail/src/is/hail/backend/local/LocalBackend.scala @@ -87,6 +87,7 @@ class LocalBackend( override val references: mutable.Map[String, ReferenceGenome], ) extends Backend with BackendWithCodeCache with Py4JBackendExtensions { + override def backend: Backend = this override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv() override def longLifeTempFileManager: TempFileManager = null @@ -102,6 +103,7 @@ class LocalBackend( tmpdir, tmpdir, this, + references.toMap, fs, timer, null, diff --git a/hail/hail/src/is/hail/backend/py4j/Py4JBackendExtensions.scala b/hail/hail/src/is/hail/backend/py4j/Py4JBackendExtensions.scala index 5f4fe8eeb5a..3cb0f436607 100644 --- a/hail/hail/src/is/hail/backend/py4j/Py4JBackendExtensions.scala +++ b/hail/hail/src/is/hail/backend/py4j/Py4JBackendExtensions.scala @@ -10,12 +10,14 @@ import is.hail.expr.ir.{ import is.hail.expr.ir.IRParser.parseType import is.hail.expr.ir.defs.{EncodedLiteral, GetFieldByIdx} import is.hail.expr.ir.functions.IRFunctionRegistry +import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} import is.hail.linalg.RowMatrix import is.hail.types.physical.PStruct import is.hail.types.virtual.{TArray, TInterval} import is.hail.utils.{defaultJSONFormats, log, toRichIterable, FastSeq, HailException, Interval} import is.hail.variant.ReferenceGenome +import scala.collection.mutable import scala.jdk.CollectionConverters.{ asScalaBufferConverter, mapAsScalaMapConverter, seqAsJavaListConverter, } @@ -29,7 +31,10 @@ import org.json4s.Formats import org.json4s.jackson.{JsonMethods, Serialization} import sourcecode.Enclosing -trait Py4JBackendExtensions { this: Backend => +trait Py4JBackendExtensions { + def backend: Backend + def references: mutable.Map[String, ReferenceGenome] + def persistedIR: mutable.Map[Int, BaseIR] def flags: HailFeatureFlags def longLifeTempFileManager: TempFileManager @@ -59,7 +64,9 @@ trait Py4JBackendExtensions { this: Backend => persistedIR.remove(id) def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit = - withExecuteContext(ctx => references(name).addSequence(ctx, fastaFile, indexFile)) + backend.withExecuteContext { ctx => + references(name).addSequence(IndexedFastaSequenceFile(ctx.fs, fastaFile, indexFile)) + } def pyRemoveSequence(name: String): Unit = references(name).removeSequence() @@ -74,7 +81,7 @@ trait Py4JBackendExtensions { this: Backend => partitionSize: java.lang.Integer, entries: String, ): Unit = - withExecuteContext { ctx => + backend.withExecuteContext { ctx => val rm = RowMatrix.readBlockMatrix(ctx.fs, pathIn, partitionSize) entries match { case "full" => @@ -112,7 +119,7 @@ trait Py4JBackendExtensions { this: Backend => returnType: String, bodyStr: String, ): Unit = { - withExecuteContext { ctx => + backend.withExecuteContext { ctx => IRFunctionRegistry.registerIR( ctx, name, @@ -126,10 +133,10 @@ trait Py4JBackendExtensions { this: Backend => } def pyExecuteLiteral(irStr: String): Int = - withExecuteContext { ctx => + backend.withExecuteContext { ctx => val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, persistedIR.toMap)) assert(ir.typ.isRealizable) - execute(ctx, ir) match { + backend.execute(ctx, ir) match { case Left(_) => throw new HailException("Can't create literal") case Right((pt, addr)) => val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0) @@ -158,13 +165,13 @@ trait Py4JBackendExtensions { this: Backend => } def pyToDF(s: String): DataFrame = - withExecuteContext { ctx => + backend.withExecuteContext { ctx => val tir = IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) Interpret(tir, ctx).toDF() } def pyReadMultipleMatrixTables(jsonQuery: String): util.List[MatrixIR] = - withExecuteContext { ctx => + backend.withExecuteContext { ctx => log.info("pyReadMultipleMatrixTables: got query") val kvs = JsonMethods.parse(jsonQuery) match { case json4s.JObject(values) => values.toMap @@ -193,10 +200,12 @@ trait Py4JBackendExtensions { this: Backend => addReference(ReferenceGenome.fromJSON(jsonConfig)) def pyRemoveReference(name: String): Unit = - references.remove(name) + removeReference(name) def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit = - withExecuteContext(ctx => references(name).addLiftover(ctx, chainFile, destRGName)) + backend.withExecuteContext { ctx => + references(name).addLiftover(references(destRGName), LiftOver(ctx.fs, chainFile)) + } def pyRemoveLiftover(name: String, destRGName: String): Unit = references(name).removeLiftover(destRGName) @@ -204,8 +213,11 @@ trait Py4JBackendExtensions { this: Backend => private[this] def addReference(rg: ReferenceGenome): Unit = ReferenceGenome.addFatalOnCollision(references, FastSeq(rg)) + private[this] def removeReference(name: String): Unit = + references -= name + def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR = - withExecuteContext { ctx => + backend.withExecuteContext { ctx => IRParser.parse_value_ir( s, IRParserEnvironment(ctx, irMap = persistedIR.toMap), @@ -231,7 +243,7 @@ trait Py4JBackendExtensions { this: Backend => } def loadReferencesFromDataset(path: String): Array[Byte] = - withExecuteContext { ctx => + backend.withExecuteContext { ctx => val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path) ReferenceGenome.addFatalOnCollision(references, rgs) @@ -245,7 +257,7 @@ trait Py4JBackendExtensions { this: Backend => f: ExecuteContext => T )(implicit E: Enclosing ): T = - withExecuteContext { ctx => + backend.withExecuteContext { ctx => val tempFileManager = longLifeTempFileManager if (selfContainedExecution && tempFileManager != null) f(ctx) else ctx.local(tempFileManager = NonOwningTempFileManager(tempFileManager))(f) diff --git a/hail/hail/src/is/hail/backend/service/ServiceBackend.scala b/hail/hail/src/is/hail/backend/service/ServiceBackend.scala index 77f8f1d8752..e1daafc1bdf 100644 --- a/hail/hail/src/is/hail/backend/service/ServiceBackend.scala +++ b/hail/hail/src/is/hail/backend/service/ServiceBackend.scala @@ -13,6 +13,7 @@ import is.hail.expr.ir.defs.MakeTuple import is.hail.expr.ir.functions.IRFunctionRegistry import is.hail.expr.ir.lowering._ import is.hail.io.fs._ +import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} import is.hail.linalg.BlockMatrix import is.hail.services.{BatchClient, JobGroupRequest, _} import is.hail.services.JobGroupStates.{Cancelled, Failure, Running, Success} @@ -93,7 +94,16 @@ object ServiceBackend { rpcConfig.custom_references.map(ReferenceGenome.fromJSON), ) - val backend = new ServiceBackend( + rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) => + liftoversForSource.foreach { case (destGenome, chainFile) => + references(sourceGenome).addLiftover(references(destGenome), LiftOver(fs, chainFile)) + } + } + rpcConfig.sequences.foreach { case (rg, seq) => + references(rg).addSequence(IndexedFastaSequenceFile(fs, seq.fasta, seq.index)) + } + + new ServiceBackend( JarUrl(jarLocation), name, theHailClassLoader, @@ -106,19 +116,6 @@ object ServiceBackend { backendContext, scratchDir, ) - - backend.withExecuteContext { ctx => - rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) => - liftoversForSource.foreach { case (destGenome, chainFile) => - references(sourceGenome).addLiftover(ctx, chainFile, destGenome) - } - } - rpcConfig.sequences.foreach { case (rg, seq) => - references(rg).addSequence(ctx, seq.fasta, seq.index) - } - } - - backend } } @@ -126,7 +123,7 @@ class ServiceBackend( val jarSpec: JarSpec, var name: String, val theHailClassLoader: HailClassLoader, - override val references: mutable.Map[String, ReferenceGenome], + val references: mutable.Map[String, ReferenceGenome], val batchClient: BatchClient, val batchConfig: BatchConfig, val flags: HailFeatureFlags, @@ -397,6 +394,7 @@ class ServiceBackend( tmpdir, "file:///tmp", this, + references.toMap, fs, timer, null, diff --git a/hail/hail/src/is/hail/backend/spark/SparkBackend.scala b/hail/hail/src/is/hail/backend/spark/SparkBackend.scala index f35c275001e..eda20bedde7 100644 --- a/hail/hail/src/is/hail/backend/spark/SparkBackend.scala +++ b/hail/hail/src/is/hail/backend/spark/SparkBackend.scala @@ -346,6 +346,7 @@ class SparkBackend( new HadoopFS(new SerializableHadoopConfiguration(conf)) } + override def backend: Backend = this override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv() override val longLifeTempFileManager: TempFileManager = @@ -375,6 +376,7 @@ class SparkBackend( tmpdir, localTmpdir, this, + references.toMap, fs, region, timer, @@ -394,6 +396,7 @@ class SparkBackend( tmpdir, localTmpdir, this, + references.toMap, fs, timer, null, diff --git a/hail/hail/src/is/hail/expr/ir/EmitClassBuilder.scala b/hail/hail/src/is/hail/expr/ir/EmitClassBuilder.scala index f568208ed33..e008486d437 100644 --- a/hail/hail/src/is/hail/expr/ir/EmitClassBuilder.scala +++ b/hail/hail/src/is/hail/expr/ir/EmitClassBuilder.scala @@ -72,7 +72,7 @@ class EmitModuleBuilder(val ctx: ExecuteContext, val modb: ModuleBuilder) { } def referenceGenomes(): IndexedSeq[ReferenceGenome] = - rgContainers.keys.map(ctx.getReference(_)).toIndexedSeq.sortBy(_.name) + rgContainers.keys.map(ctx.references(_)).toIndexedSeq.sortBy(_.name) def referenceGenomeFields(): IndexedSeq[StaticField[ReferenceGenome]] = rgContainers.toFastSeq.sortBy(_._1).map(_._2) diff --git a/hail/hail/src/is/hail/expr/ir/ExtractIntervalFilters.scala b/hail/hail/src/is/hail/expr/ir/ExtractIntervalFilters.scala index 3a46852a293..e326a210811 100644 --- a/hail/hail/src/is/hail/expr/ir/ExtractIntervalFilters.scala +++ b/hail/hail/src/is/hail/expr/ir/ExtractIntervalFilters.scala @@ -647,7 +647,7 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { BoolValue.fromComparison(l, op).restrict(keySet) case Contig(rgStr) => // locus contig equality comparison - val b = getIntervalFromContig(l.asInstanceOf[String], ctx.getReference(rgStr)) match { + val b = getIntervalFromContig(l.asInstanceOf[String], ctx.references(rgStr)) match { case Some(i) => val b = BoolValue( KeySet(i), @@ -671,7 +671,7 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { case Position(rgStr) => // locus position comparison val posBoolValue = BoolValue.fromComparison(l, op) - val rg = ctx.getReference(rgStr) + val rg = ctx.references(rgStr) val b = BoolValue( KeySet(liftPosIntervalsToLocus(posBoolValue.trueBound, rg, ctx)), KeySet(liftPosIntervalsToLocus(posBoolValue.falseBound, rg, ctx)), diff --git a/hail/hail/src/is/hail/expr/ir/MatrixValue.scala b/hail/hail/src/is/hail/expr/ir/MatrixValue.scala index 2a7fc472e6a..aeeaa36a399 100644 --- a/hail/hail/src/is/hail/expr/ir/MatrixValue.scala +++ b/hail/hail/src/is/hail/expr/ir/MatrixValue.scala @@ -210,7 +210,7 @@ case class MatrixValue( ReferenceGenome.exportReferences( fs, refPath, - ReferenceGenome.getReferences(t).map(ctx.getReference(_)), + ReferenceGenome.getReferences(t).map(ctx.references(_)), ) } diff --git a/hail/hail/src/is/hail/io/plink/LoadPlink.scala b/hail/hail/src/is/hail/io/plink/LoadPlink.scala index 721843cc199..4b7b30007b4 100644 --- a/hail/hail/src/is/hail/io/plink/LoadPlink.scala +++ b/hail/hail/src/is/hail/io/plink/LoadPlink.scala @@ -187,7 +187,7 @@ object MatrixPLINKReader { implicit val formats: Formats = DefaultFormats val params = jv.extract[MatrixPLINKReaderParameters] - val referenceGenome = params.rg.map(ctx.getReference) + val referenceGenome = params.rg.map(ctx.references) referenceGenome.foreach(_.validateContigRemap(params.contigRecoding)) val locusType = TLocus.schemaFromRG(params.rg) diff --git a/hail/hail/src/is/hail/io/reference/LiftOver.scala b/hail/hail/src/is/hail/io/reference/LiftOver.scala deleted file mode 100644 index c24f81f6d5b..00000000000 --- a/hail/hail/src/is/hail/io/reference/LiftOver.scala +++ /dev/null @@ -1,69 +0,0 @@ -package is.hail.io.reference - -import is.hail.io.fs.FS -import is.hail.utils._ -import is.hail.variant.{Locus, ReferenceGenome} - -import scala.collection.JavaConverters._ - -object LiftOver { - def apply(fs: FS, chainFile: String): LiftOver = new LiftOver(fs, chainFile) -} - -class LiftOver(fs: FS, val chainFile: String) { - val lo = using(fs.open(chainFile))(new htsjdk.samtools.liftover.LiftOver(_, chainFile)) - - def queryInterval( - interval: is.hail.utils.Interval, - minMatch: Double = htsjdk.samtools.liftover.LiftOver.DEFAULT_LIFTOVER_MINMATCH, - ): (is.hail.utils.Interval, Boolean) = { - val start = interval.start.asInstanceOf[Locus] - val end = interval.end.asInstanceOf[Locus] - - if (start.contig != end.contig) - fatal(s"'start' and 'end' contigs must be identical. Found '$interval'.") - - val contig = start.contig - val startPos = if (interval.includesStart) start.position else start.position + 1 - val endPos = if (interval.includesEnd) end.position else end.position - 1 - - if (startPos == endPos) - fatal( - s"Cannot liftover a 0-length interval: ${interval.toString}.\nDid you mean to use 'liftover_locus'?" - ) - - val result = lo.liftOver(new htsjdk.samtools.util.Interval(contig, startPos, endPos), minMatch) - if (result != null) - ( - Interval( - Locus(result.getContig, result.getStart), - Locus(result.getContig, result.getEnd), - includesStart = true, - includesEnd = true, - ), - result.isNegativeStrand, - ) - else - null - } - - def queryLocus( - l: Locus, - minMatch: Double = htsjdk.samtools.liftover.LiftOver.DEFAULT_LIFTOVER_MINMATCH, - ): (Locus, Boolean) = { - val result = - lo.liftOver(new htsjdk.samtools.util.Interval(l.contig, l.position, l.position), minMatch) - if (result != null) - (Locus(result.getContig, result.getStart), result.isNegativeStrand) - else - null - } - - def checkChainFile(srcRG: ReferenceGenome, destRG: ReferenceGenome): Unit = { - val cMap = lo.getContigMap.asScala - cMap.foreach { case (srcContig, destContigs) => - srcRG.checkContig(srcContig) - destContigs.asScala.foreach(destRG.checkContig) - } - } -} diff --git a/hail/hail/src/is/hail/io/reference/package.scala b/hail/hail/src/is/hail/io/reference/package.scala new file mode 100644 index 00000000000..819863f2d01 --- /dev/null +++ b/hail/hail/src/is/hail/io/reference/package.scala @@ -0,0 +1,164 @@ +package is.hail.io + +import is.hail.io.fs.FS +import is.hail.io.reference.LiftOver.MinMatchDefault +import is.hail.utils.{fatal, toRichIterable, using, Interval} +import is.hail.variant.{Locus, ReferenceGenome} + +import scala.collection.convert.ImplicitConversions.{`collection AsScalaIterable`, `map AsScala`} +import scala.jdk.CollectionConverters.iterableAsScalaIterableConverter + +import htsjdk.samtools.reference.FastaSequenceIndexEntry + +package reference { + + /* ASSUMPTION: The following will not move or change for the entire duration of a hail pipeline: + * - chainFile + * - fastaFile + * - indexFile */ + + object LiftOver { + def apply(fs: FS, chainFile: String): LiftOver = { + val lo = new LiftOver(chainFile) + lo.restore(fs) + lo + } + + val MinMatchDefault: Double = htsjdk.samtools.liftover.LiftOver.DEFAULT_LIFTOVER_MINMATCH + } + + class LiftOver private (chainFile: String) extends Serializable { + + @transient var asJava: htsjdk.samtools.liftover.LiftOver = _ + + def queryInterval(interval: Interval, minMatch: Double = MinMatchDefault) + : (Interval, Boolean) = { + val start = interval.start.asInstanceOf[Locus] + val end = interval.end.asInstanceOf[Locus] + + if (start.contig != end.contig) + fatal(s"'start' and 'end' contigs must be identical. Found '$interval'.") + + val contig = start.contig + val startPos = if (interval.includesStart) start.position else start.position + 1 + val endPos = if (interval.includesEnd) end.position else end.position - 1 + + if (startPos == endPos) + fatal( + s"Cannot liftover a 0-length interval: ${interval.toString}.\nDid you mean to use 'liftover_locus'?" + ) + + val result = asJava.liftOver( + new htsjdk.samtools.util.Interval(contig, startPos, endPos), + minMatch, + ) + if (result != null) + ( + Interval( + Locus(result.getContig, result.getStart), + Locus(result.getContig, result.getEnd), + includesStart = true, + includesEnd = true, + ), + result.isNegativeStrand, + ) + else + null + } + + def queryLocus(l: Locus, minMatch: Double = MinMatchDefault): (Locus, Boolean) = { + val result = asJava.liftOver( + new htsjdk.samtools.util.Interval(l.contig, l.position, l.position), + minMatch, + ) + if (result != null) (Locus(result.getContig, result.getStart), result.isNegativeStrand) + else null + } + + def checkChainFile(srcRG: ReferenceGenome, destRG: ReferenceGenome): Unit = + asJava.getContigMap.foreach { case (srcContig, destContigs) => + srcRG.checkContig(srcContig) + destContigs.foreach(destRG.checkContig) + } + + def restore(fs: FS): Unit = { + if (!fs.isFile(chainFile)) + fatal(s"Chain file '$chainFile' does not exist, is not a file, or you do not have access.") + + using(fs.open(chainFile)) { is => + asJava = new htsjdk.samtools.liftover.LiftOver(is, chainFile) + } + } + } + + object IndexedFastaSequenceFile { + + def apply(fs: FS, fastaFile: String, indexFile: String): IndexedFastaSequenceFile = { + if (!fs.isFile(fastaFile)) + fatal(s"FASTA file '$fastaFile' does not exist, is not a file, or you do not have access.") + + new IndexedFastaSequenceFile(fastaFile, FastaSequenceIndex(fs, indexFile)) + } + + } + + class IndexedFastaSequenceFile private (val path: String, val index: FastaSequenceIndex) + extends Serializable { + + def raiseIfIncompatible(rg: ReferenceGenome): Unit = { + val jindex = index.asJava + + val missingContigs = rg.contigs.filterNot(jindex.hasIndexEntry) + if (missingContigs.nonEmpty) + fatal( + s"Contigs missing in FASTA '$path' that are present in reference genome '${rg.name}':\n " + + s"@1", + missingContigs.truncatable("\n "), + ) + + val invalidLengths = + for { + (contig, length) <- rg.lengths + fastaLength = jindex.getIndexEntry(contig).getSize + if fastaLength != length + } yield (contig, length, fastaLength) + + if (invalidLengths.nonEmpty) + fatal( + s"Contig sizes in FASTA '$path' do not match expected sizes for reference genome '${rg.name}':\n " + + s"@1", + invalidLengths.map { case (c, e, f) => s"$c\texpected:$e\tfound:$f" }.truncatable("\n "), + ) + } + + def restore(fs: FS): Unit = + index.restore(fs) + } + + object FastaSequenceIndex { + def apply(fs: FS, indexFile: String): FastaSequenceIndex = { + val index = new FastaSequenceIndex(indexFile) + index.restore(fs) + index + } + } + + class FastaSequenceIndex private (val path: String) + extends Iterable[FastaSequenceIndexEntry] with Serializable { + + @transient var asJava: htsjdk.samtools.reference.FastaSequenceIndex = _ + + def restore(fs: FS): Unit = { + if (!fs.isFile(path)) + fatal( + s"FASTA index file '$path' does not exist, is not a file, or you do not have access." + ) + + using(fs.open(path))(is => asJava = new htsjdk.samtools.reference.FastaSequenceIndex(is)) + } + + override def iterator: Iterator[FastaSequenceIndexEntry] = + asJava.asScala.iterator + } + +} diff --git a/hail/hail/src/is/hail/io/vcf/LoadVCF.scala b/hail/hail/src/is/hail/io/vcf/LoadVCF.scala index 8cb2f464bd3..15ffe4a6ec2 100644 --- a/hail/hail/src/is/hail/io/vcf/LoadVCF.scala +++ b/hail/hail/src/is/hail/io/vcf/LoadVCF.scala @@ -1779,7 +1779,7 @@ object MatrixVCFReader { val backend = ctx.backend val fs = ctx.fs - val referenceGenome = params.rg.map(ctx.getReference) + val referenceGenome = params.rg.map(ctx.references) referenceGenome.foreach(_.validateContigRemap(params.contigRecoding)) @@ -2011,7 +2011,7 @@ class MatrixVCFReader( val fs = ctx.fs val sm = ctx.stateManager - val rgBc = referenceGenome.map(_.broadcast) + val rgBc = referenceGenome.map(ctx.backend.broadcast) val localArrayElementsRequired = params.arrayElementsRequired val localContigRecoding = params.contigRecoding val localSkipInvalidLoci = params.skipInvalidLoci diff --git a/hail/hail/src/is/hail/variant/ReferenceGenome.scala b/hail/hail/src/is/hail/variant/ReferenceGenome.scala index e17f49c0a64..1aa5a14d526 100644 --- a/hail/hail/src/is/hail/variant/ReferenceGenome.scala +++ b/hail/hail/src/is/hail/variant/ReferenceGenome.scala @@ -1,15 +1,16 @@ package is.hail.variant -import is.hail.HailContext import is.hail.annotations.ExtendedOrdering -import is.hail.backend.{BroadcastValue, ExecuteContext} +import is.hail.backend.ExecuteContext import is.hail.check.Gen import is.hail.expr.{ JSONExtractContig, JSONExtractIntervalLocus, JSONExtractReferenceGenome, Parser, } import is.hail.expr.ir.RelationalSpec import is.hail.io.fs.FS -import is.hail.io.reference.{FASTAReader, FASTAReaderConfig, LiftOver} +import is.hail.io.reference.{ + FASTAReader, FASTAReaderConfig, FastaSequenceIndex, IndexedFastaSequenceFile, LiftOver, +} import is.hail.types._ import is.hail.types.virtual.{TLocus, Type} import is.hail.utils._ @@ -18,31 +19,10 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import java.io.{FileNotFoundException, InputStream} -import java.lang.ThreadLocal -import htsjdk.samtools.reference.FastaSequenceIndex -import org.apache.spark.TaskContext -import org.json4s._ +import org.json4s.{DefaultFormats, Extraction, Formats, JValue} import org.json4s.jackson.{JsonMethods, Serialization} -class BroadcastRG(rgParam: ReferenceGenome) extends Serializable { - @transient private[this] val rg: ReferenceGenome = rgParam - - private[this] val rgBc: BroadcastValue[ReferenceGenome] = - if (TaskContext.get != null) - null - else - rg.broadcast - - def value: ReferenceGenome = { - val t = if (rg != null) - rg - else - rgBc.value - t - } -} - case class ReferenceGenome( name: String, contigs: Array[String], @@ -53,7 +33,6 @@ case class ReferenceGenome( parInput: Array[(Locus, Locus)] = Array.empty[(Locus, Locus)], ) extends Serializable { - @transient lazy val broadcastRG: BroadcastRG = new BroadcastRG(this) val nContigs = contigs.length if (nContigs <= 0) @@ -165,9 +144,8 @@ case class ReferenceGenome( Interval(start, end, includesStart = true, includesEnd = false) } - private var fastaFilePath: String = _ - private var fastaIndexPath: String = _ - @transient private var fastaReaderCfg: FASTAReaderConfig = _ + private[this] var fastaFile: IndexedFastaSequenceFile = _ + @transient private[this] var fastaReaderCfg: FASTAReaderConfig = _ @transient lazy val contigParser = Parser.oneOfLiteral(contigs) @@ -365,50 +343,14 @@ case class ReferenceGenome( ) } - def hasSequence: Boolean = fastaFilePath != null + private def hasSequence: Boolean = fastaFile != null - def addSequence(ctx: ExecuteContext, fastaFile: String, indexFile: String): Unit = { + def addSequence(fasta: IndexedFastaSequenceFile): Unit = { if (hasSequence) fatal(s"FASTA sequence has already been loaded for reference genome '$name'.") - val tmpdir = ctx.localTmpdir - val fs = ctx.fs - if (!fs.isFile(fastaFile)) - fatal(s"FASTA file '$fastaFile' does not exist, is not a file, or you do not have access.") - if (!fs.isFile(indexFile)) - fatal( - s"FASTA index file '$indexFile' does not exist, is not a file, or you do not have access." - ) - fastaFilePath = fastaFile - fastaIndexPath = indexFile - - /* assumption, fastaFile and indexFile will not move or change for the entire duration of a hail - * pipeline */ - val index = using(fs.open(indexFile))(new FastaSequenceIndex(_)) - - val missingContigs = contigs.filterNot(index.hasIndexEntry) - if (missingContigs.nonEmpty) - fatal( - s"Contigs missing in FASTA '$fastaFile' that are present in reference genome '$name':\n " + - s"@1", - missingContigs.truncatable("\n "), - ) - - val invalidLengths = lengths.flatMap { case (c, l) => - val fastaLength = index.getIndexEntry(c).getSize - if (fastaLength != l) - Some((c, l, fastaLength)) - else - None - }.map { case (c, e, f) => s"$c\texpected:$e\tfound:$f" } - - if (invalidLengths.nonEmpty) - fatal( - s"Contig sizes in FASTA '$fastaFile' do not match expected sizes for reference genome '$name':\n " + - s"@1", - invalidLengths.truncatable("\n "), - ) - heal(tmpdir, fs) + fasta.raiseIfIncompatible(this) + fastaFile = fasta } @transient private lazy val realFastaReader: ThreadLocal[FASTAReader] = @@ -417,6 +359,7 @@ case class ReferenceGenome( private def fastaReader(): FASTAReader = { if (!hasSequence) fatal(s"FASTA file has not been loaded for reference genome '$name'.") + if (realFastaReader.get() == null) realFastaReader.set(fastaReaderCfg.reader) if (realFastaReader.get().cfg != fastaReaderCfg) @@ -436,93 +379,46 @@ case class ReferenceGenome( def removeSequence(): Unit = { if (!hasSequence) fatal(s"Reference genome '$name' does not have sequence loaded.") - fastaFilePath = null - fastaIndexPath = null + fastaFile = null fastaReaderCfg = null } - private var chainFiles: Map[String, String] = Map.empty - @transient private[this] lazy val liftoverMap: mutable.Map[String, LiftOver] = mutable.Map.empty + private[this] val liftovers: mutable.Map[String, LiftOver] = + mutable.Map.empty - def hasLiftover(destRGName: String): Boolean = chainFiles.contains(destRGName) - - def addLiftover(ctx: ExecuteContext, chainFile: String, destRGName: String): Unit = { - if (name == destRGName) + def addLiftover(destRef: ReferenceGenome, liftOver: LiftOver): Unit = { + if (name == destRef.name) fatal(s"Destination reference genome cannot have the same name as this reference '$name'") - if (hasLiftover(destRGName)) + if (liftovers.contains(destRef.name)) fatal( - s"Chain file already exists for source reference '$name' and destination reference '$destRGName'." + s"LiftOver already exists for source reference '$name' and destination reference '${destRef.name}'." ) - val tmpdir = ctx.localTmpdir - val fs = ctx.fs - - if (!fs.isFile(chainFile)) - fatal(s"Chain file '$chainFile' does not exist, is not a file, or you do not have access.") - - val chainFilePath = fs.parseUrl(chainFile).toString - val lo = LiftOver(fs, chainFilePath) - val destRG = ctx.getReference(destRGName) - lo.checkChainFile(this, destRG) - - chainFiles += destRGName -> chainFile - heal(tmpdir, fs) - } - - def getLiftover(destRGName: String): LiftOver = { - if (!hasLiftover(destRGName)) - fatal( - s"Chain file has not been loaded for source reference '$name' and destination reference '$destRGName'." - ) - liftoverMap(destRGName) + liftOver.checkChainFile(this, destRef) + liftovers += destRef.name -> liftOver } - def removeLiftover(destRGName: String): Unit = { - if (!hasLiftover(destRGName)) - fatal(s"liftover does not exist from reference genome '$name' to '$destRGName'.") - chainFiles -= destRGName - liftoverMap -= destRGName - } + def removeLiftover(destRGName: String): Unit = + liftovers -= destRGName - def liftoverLocus(destRGName: String, l: Locus, minMatch: Double): (Locus, Boolean) = { - val lo = getLiftover(destRGName) - lo.queryLocus(l, minMatch) - } + def liftoverLocus(destRGName: String, l: Locus, minMatch: Double): (Locus, Boolean) = + liftovers(destRGName).queryLocus(l, minMatch) def liftoverLocusInterval(destRGName: String, interval: Interval, minMatch: Double) - : (Interval, Boolean) = { - val lo = getLiftover(destRGName) - lo.queryInterval(interval, minMatch) - } + : (Interval, Boolean) = + liftovers(destRGName).queryInterval(interval, minMatch) def heal(tmpdir: String, fs: FS): Unit = synchronized { - // Add liftovers - // NOTE: it shouldn't be possible for the liftover map to have more elements than the chain file - // since removeLiftover updates both maps, so we don't check to see if liftoverMap has - // keys that are not in chainFiles - for ((destRGName, chainFile) <- chainFiles) { - val chainFilePath = fs.parseUrl(chainFile).toString - liftoverMap.get(destRGName) match { - case Some(lo) if lo.chainFile == chainFilePath => // do nothing - case _ => liftoverMap += destRGName -> LiftOver(fs, chainFilePath) - } - } - + liftovers.values.foreach(_.restore(fs)) // add sequence - if (fastaFilePath != null) { - val fastaPath = fs.parseUrl(fastaFilePath).toString - val indexPath = fs.parseUrl(fastaIndexPath).toString - if ( - fastaReaderCfg == null || fastaReaderCfg.fastaFile != fastaPath || fastaReaderCfg.indexFile != indexPath - ) { - fastaReaderCfg = FASTAReaderConfig(tmpdir, fs, this, fastaPath, indexPath) - } + if (fastaFile != null) { + fastaFile.restore(fs) + val fastaPath = fs.parseUrl(fastaFile.path).toString + val indexPath = fs.parseUrl(fastaFile.index.path).toString + fastaReaderCfg = FASTAReaderConfig(tmpdir, fs, this, fastaPath, indexPath) } } - @transient lazy val broadcast: BroadcastValue[ReferenceGenome] = - HailContext.backend.broadcast(this) - override def hashCode: Int = { import org.apache.commons.lang3.builder.HashCodeBuilder @@ -555,9 +451,14 @@ case class ReferenceGenome( override def toString: String = name + @transient implicit private[this] val fmts: Formats = DefaultFormats + def write(fs: is.hail.io.fs.FS, file: String): Unit = - using(fs.create(file)) { out => - val jrg = JSONExtractReferenceGenome( + using(fs.create(file))(out => Serialization.write(toJSON, out)) + + def toJSON: JValue = + Extraction.decompose( + JSONExtractReferenceGenome( name, contigs.map(contig => JSONExtractContig(contig, contigLength(contig))), xContigs, @@ -567,23 +468,8 @@ case class ReferenceGenome( JSONExtractIntervalLocus(i.start.asInstanceOf[Locus], i.end.asInstanceOf[Locus]) ), ) - implicit val formats: Formats = defaultJSONFormats - Serialization.write(jrg, out) - } + ) - def toJSON: JSONExtractReferenceGenome = JSONExtractReferenceGenome( - name, - contigs.map(contig => JSONExtractContig(contig, contigLength(contig))), - xContigs, - yContigs, - mtContigs, - par.map(i => JSONExtractIntervalLocus(i.start.asInstanceOf[Locus], i.end.asInstanceOf[Locus])), - ) - - def toJSONString: String = { - implicit val formats: Formats = defaultJSONFormats - Serialization.write(toJSON) - } } object ReferenceGenome { @@ -644,17 +530,11 @@ object ReferenceGenome { if (!fs.isFile(fastaFile)) fatal(s"FASTA file '$fastaFile' does not exist, is not a file, or you do not have access.") - if (!fs.isFile(indexFile)) - fatal( - s"FASTA index file '$indexFile' does not exist, is not a file, or you do not have access." - ) - - val index = using(fs.open(indexFile))(new FastaSequenceIndex(_)) val contigs = new BoxedArrayBuilder[String] val lengths = new BoxedArrayBuilder[(String, Int)] - index.iterator().asScala.foreach { entry => + FastaSequenceIndex(fs, indexFile).foreach { entry => val contig = entry.getContig val length = entry.getSize contigs += contig diff --git a/hail/hail/test/src/is/hail/HailSuite.scala b/hail/hail/test/src/is/hail/HailSuite.scala index f58a1a74053..e986ee5dc6d 100644 --- a/hail/hail/test/src/is/hail/HailSuite.scala +++ b/hail/hail/test/src/is/hail/HailSuite.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.Row import org.scalatestplus.testng.TestNGSuite import org.testng.ITestContext import org.testng.annotations.{AfterMethod, BeforeClass, BeforeMethod} -import sourcecode.Enclosing object HailSuite { val theHailClassLoader = TestUtils.theHailClassLoader @@ -100,9 +99,6 @@ class HailSuite extends TestNGSuite { throw new RuntimeException(s"method stopped spark context!") } - def withExecuteContext[T]()(f: ExecuteContext => T)(implicit E: Enclosing): T = - hc.sparkBackend("HailSuite.withExecuteContext").withExecuteContext(f) - def assertEvalsTo( x: IR, env: Env[(Any, Type)], diff --git a/hail/hail/test/src/is/hail/annotations/UnsafeSuite.scala b/hail/hail/test/src/is/hail/annotations/UnsafeSuite.scala index d954822f0ba..81a2a82e3a7 100644 --- a/hail/hail/test/src/is/hail/annotations/UnsafeSuite.scala +++ b/hail/hail/test/src/is/hail/annotations/UnsafeSuite.scala @@ -55,7 +55,7 @@ class UnsafeSuite extends HailSuite { @DataProvider(name = "codecs") def codecs(): Array[Array[Any]] = - withExecuteContext()(ctx => codecs(ctx)) + ExecuteContext.scoped(ctx => codecs(ctx)) def codecs(ctx: ExecuteContext): Array[Array[Any]] = (BufferSpec.specs ++ Array(TypedCodecSpec( diff --git a/hail/hail/test/src/is/hail/expr/ir/ExtractIntervalFiltersSuite.scala b/hail/hail/test/src/is/hail/expr/ir/ExtractIntervalFiltersSuite.scala index 079e16ec965..145a70cd76e 100644 --- a/hail/hail/test/src/is/hail/expr/ir/ExtractIntervalFiltersSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/ExtractIntervalFiltersSuite.scala @@ -36,7 +36,7 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => def wrappedIntervalEndpoint(x: Any, sign: Int) = IntervalEndpoint(Row(x), sign) - def grch38: ReferenceGenome = ctx.getReference(ReferenceGenome.GRCh38) + def grch38: ReferenceGenome = ctx.references(ReferenceGenome.GRCh38) def lt(l: IR, r: IR): IR = ApplyComparisonOp(LT(l.typ), l, r) def gt(l: IR, r: IR): IR = ApplyComparisonOp(GT(l.typ), l, r) diff --git a/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala b/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala index 54b085d09c7..5dea1d2ab26 100644 --- a/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala @@ -3251,7 +3251,7 @@ class IRSuite extends HailSuite { @DataProvider(name = "valueIRs") def valueIRs(): Array[Array[Object]] = - withExecuteContext()(ctx => valueIRs(ctx)) + ExecuteContext.scoped(ctx => valueIRs(ctx)) def valueIRs(ctx: ExecuteContext): Array[Array[Object]] = { val fs = ctx.fs @@ -3596,7 +3596,7 @@ class IRSuite extends HailSuite { @DataProvider(name = "tableIRs") def tableIRs(): Array[Array[TableIR]] = - withExecuteContext()(ctx => tableIRs(ctx)) + ExecuteContext.scoped(ctx => tableIRs(ctx)) def tableIRs(ctx: ExecuteContext): Array[Array[TableIR]] = { try { @@ -3705,7 +3705,7 @@ class IRSuite extends HailSuite { @DataProvider(name = "matrixIRs") def matrixIRs(): Array[Array[MatrixIR]] = - withExecuteContext()(ctx => matrixIRs(ctx)) + ExecuteContext.scoped(ctx => matrixIRs(ctx)) def matrixIRs(ctx: ExecuteContext): Array[Array[MatrixIR]] = { try { diff --git a/hail/hail/test/src/is/hail/expr/ir/LocusFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/LocusFunctionsSuite.scala index 0745d8ce14a..eeaa8df06e4 100644 --- a/hail/hail/test/src/is/hail/expr/ir/LocusFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/LocusFunctionsSuite.scala @@ -13,7 +13,7 @@ class LocusFunctionsSuite extends HailSuite { implicit val execStrats = ExecStrategy.javaOnly - private def grch38: ReferenceGenome = ctx.getReference(ReferenceGenome.GRCh38) + private def grch38: ReferenceGenome = ctx.references(ReferenceGenome.GRCh38) private def tlocus = TLocus(grch38.name) private def tvariant = TStruct("locus" -> tlocus, "alleles" -> TArray(TString)) @@ -83,8 +83,8 @@ class LocusFunctionsSuite extends HailSuite { assertEvalsTo( ir, Row( - Locus("1", 1, ctx.getReference(ReferenceGenome.GRCh37)), - Locus("chr1", 1, ctx.getReference(ReferenceGenome.GRCh38)), + Locus("1", 1, ctx.references(ReferenceGenome.GRCh37)), + Locus("chr1", 1, ctx.references(ReferenceGenome.GRCh38)), ), ) } diff --git a/hail/hail/test/src/is/hail/expr/ir/RequirednessSuite.scala b/hail/hail/test/src/is/hail/expr/ir/RequirednessSuite.scala index 63f516cc7f0..e62319bc968 100644 --- a/hail/hail/test/src/is/hail/expr/ir/RequirednessSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/RequirednessSuite.scala @@ -1,6 +1,7 @@ package is.hail.expr.ir import is.hail.HailSuite +import is.hail.backend.ExecuteContext import is.hail.expr.Nat import is.hail.expr.ir.agg.CallStatsState import is.hail.expr.ir.defs._ @@ -103,7 +104,7 @@ class RequirednessSuite extends HailSuite { def pinterval(point: PType, r: Boolean): PInterval = PCanonicalInterval(point, r) @DataProvider(name = "valueIR") - def valueIR(): Array[Array[Any]] = withExecuteContext() { ctx => + def valueIR(): Array[Array[Any]] = ExecuteContext.scoped { ctx => val nodes = new BoxedArrayBuilder[Array[Any]](50) val allRequired = Array( diff --git a/hail/hail/test/src/is/hail/expr/ir/StagedMinHeapSuite.scala b/hail/hail/test/src/is/hail/expr/ir/StagedMinHeapSuite.scala index be15fd59677..81145a20bca 100644 --- a/hail/hail/test/src/is/hail/expr/ir/StagedMinHeapSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/StagedMinHeapSuite.scala @@ -3,6 +3,7 @@ package is.hail.expr.ir import is.hail.HailSuite import is.hail.annotations.{Region, SafeIndexedSeq} import is.hail.asm4s._ +import is.hail.backend.ExecuteContext import is.hail.check.Gen import is.hail.check.Prop.forAll import is.hail.expr.ir.functions.LocusFunctions @@ -17,46 +18,19 @@ import is.hail.variant.{Locus, ReferenceGenome} import org.scalatest.matchers.should.Matchers.{be, convertToAnyShouldWrapper} import org.testng.annotations.Test -sealed trait StagedCoercions[A] { - def ti: TypeInfo[A] - def sType: SType - def fromType(cb: EmitCodeBuilder, region: Value[Region], a: Value[A]): SValue - def toType(cb: EmitCodeBuilder, sa: SValue): Value[A] -} +class StagedMinHeapSuite extends HailSuite { -sealed trait StagedCoercionInstances { implicit object StagedIntCoercions extends StagedCoercions[Int] { - override def ti: TypeInfo[Int] = - implicitly - - override def sType: SType = - SInt32 + override def ti: TypeInfo[Int] = implicitly + override def sType: SType = SInt32 - override def fromType(cb: EmitCodeBuilder, region: Value[Region], a: Value[Int]): SValue = + override def fromValue(cb: EmitCodeBuilder, region: Value[Region], a: Value[Int]): SValue = new SInt32Value(a) - override def toType(cb: EmitCodeBuilder, sa: SValue): Value[Int] = + override def toValue(cb: EmitCodeBuilder, sa: SValue): Value[Int] = sa.asInt.value } - def stagedLocusCoercions(rg: ReferenceGenome): StagedCoercions[Locus] = - new StagedCoercions[Locus] { - override def ti: TypeInfo[Locus] = - implicitly - - override def sType: SType = - PCanonicalLocus(rg.name, required = true).sType - - override def fromType(cb: EmitCodeBuilder, region: Value[Region], a: Value[Locus]): SValue = - LocusFunctions.emitLocus(cb, region, a, sType.storageType().asInstanceOf[PCanonicalLocus]) - - override def toType(cb: EmitCodeBuilder, sa: SValue): Value[Locus] = - sa.asLocus.getLocusObj(cb) - } -} - -class StagedMinHeapSuite extends HailSuite with StagedCoercionInstances { - @Test def testSorting(): Unit = forAll((xs: IndexedSeq[Int]) => sort(xs) == xs.sorted).check() @@ -70,7 +44,7 @@ class StagedMinHeapSuite extends HailSuite with StagedCoercionInstances { }.check() @Test def testNonEmpty(): Unit = - gen("NonEmpty") { (heap: IntHeap) => + gen(ctx, "NonEmpty") { (heap: IntHeap) => heap.nonEmpty should be(false) for (i <- 0 to 10) heap.push(i) heap.nonEmpty should be(true) @@ -86,10 +60,12 @@ class StagedMinHeapSuite extends HailSuite with StagedCoercionInstances { @Test def testLocus(): Unit = forAll(loci) { case (rg: ReferenceGenome, loci: IndexedSeq[Locus]) => - withReferenceGenome(rg) { + ctx.local(references = Map(rg.name -> rg)) { ctx => + implicit val coercions: StagedCoercions[Locus] = + stagedLocusCoercions(rg) val sortedLoci = - gen("Locus", stagedLocusCoercions(rg)) { (heap: LocusHeap) => + gen[Locus, LocusHeap, IndexedSeq[Locus]](ctx, "Locus") { (heap: LocusHeap) => loci.foreach(heap.push) IndexedSeq.fill(loci.size)(heap.pop()) } @@ -98,20 +74,14 @@ class StagedMinHeapSuite extends HailSuite with StagedCoercionInstances { } }.check() - def withReferenceGenome[A](rg: ReferenceGenome)(f: => A): A = { - ctx.backend.references += (rg.name -> rg) - try f - finally ctx.backend.references.remove(rg.name) - } - def sort(xs: IndexedSeq[Int]): IndexedSeq[Int] = - gen("Sort") { (heap: IntHeap) => + gen(ctx, "Sort") { (heap: IntHeap) => xs.foreach(heap.push) IndexedSeq.fill(xs.size)(heap.pop()) } def heapify(xs: IndexedSeq[Int]): IndexedSeq[Int] = - gen("Heapify") { (heap: IntHeap) => + gen(ctx, "Heapify") { (heap: IntHeap) => pool.scopedRegion { r => xs.foreach(heap.push) val ptr = heap.toArray(r) @@ -119,17 +89,9 @@ class StagedMinHeapSuite extends HailSuite with StagedCoercionInstances { } } - def gen[H <: Heap[A], A, B]( + def gen[A, H <: Heap[A], B]( + ctx: ExecuteContext, name: String, - A: StagedCoercions[A], - )( - f: H => B - )(implicit H: TypeInfo[H] - ): B = - gen[H, A, B](name)(f)(H, A) - - def gen[H <: Heap[A], A, B]( - name: String )( f: H => B )(implicit @@ -146,13 +108,13 @@ class StagedMinHeapSuite extends HailSuite with StagedCoercionInstances { Main.defineEmitMethod("push", FastSeq(A.ti), UnitInfo) { mb => mb.voidWithBuilder { cb => - MinHeap.push(cb, A.fromType(cb, Main.partitionRegion, mb.getCodeParam[A](1)(A.ti))) + MinHeap.push(cb, A.fromValue(cb, Main.partitionRegion, mb.getCodeParam[A](1)(A.ti))) } } Main.defineEmitMethod("pop", FastSeq(), A.ti) { mb => mb.emitWithBuilder[A] { cb => - val res = A.toType(cb, MinHeap.peek(cb)) + val res = A.toValue(cb, MinHeap.peek(cb)) MinHeap.pop(cb) MinHeap.realloc(cb) res @@ -182,11 +144,11 @@ class StagedMinHeapSuite extends HailSuite with StagedCoercionInstances { MinHeap.init(cb, Main.pool()) } } - Main.defineEmitMethod("close", FastSeq(), UnitInfo)(mb => mb.voidWithBuilder(MinHeap.close)) + Main.defineEmitMethod("close", FastSeq(), UnitInfo)(_.voidWithBuilder(MinHeap.close)) - pool.scopedRegion { r => + ctx.scopedExecution { (cl, fs, tc, r) => val heap = Main - .resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, r) + .resultWithIndex()(cl, fs, tc, r) .asInstanceOf[H with Resource] heap.init() @@ -208,4 +170,23 @@ class StagedMinHeapSuite extends HailSuite with StagedCoercionInstances { def nonEmpty: Boolean def toArray(r: Region): Long } + + sealed trait StagedCoercions[A] { + def ti: TypeInfo[A] + def sType: SType + def fromValue(cb: EmitCodeBuilder, region: Value[Region], a: Value[A]): SValue + def toValue(cb: EmitCodeBuilder, sa: SValue): Value[A] + } + + def stagedLocusCoercions(rg: ReferenceGenome): StagedCoercions[Locus] = + new StagedCoercions[Locus] { + override def ti: TypeInfo[Locus] = implicitly + override def sType: SType = PCanonicalLocus(rg.name, required = true).sType + + override def fromValue(cb: EmitCodeBuilder, region: Value[Region], a: Value[Locus]): SValue = + LocusFunctions.emitLocus(cb, region, a, sType.storageType().asInstanceOf[PCanonicalLocus]) + + override def toValue(cb: EmitCodeBuilder, sa: SValue): Value[Locus] = + sa.asLocus.getLocusObj(cb) + } } diff --git a/hail/hail/test/src/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala b/hail/hail/test/src/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala index d4617313034..6bd31b6d4b4 100644 --- a/hail/hail/test/src/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala @@ -62,10 +62,8 @@ class LowerDistributedSortSuite extends HailSuite { } // Only does ascending for now - def testDistributedSortHelper(myTable: TableIR, sortFields: IndexedSeq[SortField]): Unit = { - val originalShuffleCutoff = backend.flags.get("shuffle_cutoff_to_local_sort") - try { - backend.flags.set("shuffle_cutoff_to_local_sort", "40") + def testDistributedSortHelper(myTable: TableIR, sortFields: IndexedSeq[SortField]): Unit = + ctx.local(flags = ctx.flags + ("shuffle_cutoff_to_local_sort" -> "40")) { ctx => val analyses: LoweringAnalyses = LoweringAnalyses.apply(myTable, ctx) val rt = analyses.requirednessAnalysis.lookup(myTable).asInstanceOf[RTable] val stage = LowerTableIR.applyTable(myTable, DArrayLowering.All, ctx, analyses) @@ -105,9 +103,7 @@ class LowerDistributedSortSuite extends HailSuite { ans } assert(res == scalaSorted) - } finally - backend.flags.set("shuffle_cutoff_to_local_sort", originalShuffleCutoff) - } + } @Test def testDistributedSort(): Unit = { val tableRange = TableRange(100, 10) diff --git a/hail/hail/test/src/is/hail/utils/FlipbookIteratorSuite.scala b/hail/hail/test/src/is/hail/utils/FlipbookIteratorSuite.scala index 85a33e681db..7461984108d 100644 --- a/hail/hail/test/src/is/hail/utils/FlipbookIteratorSuite.scala +++ b/hail/hail/test/src/is/hail/utils/FlipbookIteratorSuite.scala @@ -4,48 +4,47 @@ import is.hail.HailSuite import scala.collection.generic.Growable import scala.collection.mutable.ArrayBuffer +import scala.reflect.runtime.universe._ import org.testng.annotations.Test class FlipbookIteratorSuite extends HailSuite { - class Box[A] extends AnyRef { + class Box[A: TypeTag] extends AnyRef { + def typ = typeOf[A] var value: A = _ - def canEqual(a: Any): Boolean = a.isInstanceOf[Box[A]] - override def equals(that: Any): Boolean = that match { - case that: Box[A] => value == that.value + case that: Box[_] if typ == that.typ => value == that.value.asInstanceOf[A] case _ => false } - } object Box { - def apply[A](): Box[A] = new Box + def apply[A: TypeTag]: Box[A] = new Box[A]() - def apply[A](a: A): Box[A] = { - val box = Box[A]() + def apply[A: TypeTag](a: A): Box[A] = { + val box = Box[A] box.value = a box } } - def boxOrdView[A](implicit ord: Ordering[A]): OrderingView[Box[A]] = new OrderingView[Box[A]] { + def boxOrdView[A: Ordering]: OrderingView[Box[A]] = new OrderingView[Box[A]] { var value: A = _ def setFiniteValue(a: Box[A]): Unit = value = a.value def compareFinite(a: Box[A]): Int = - ord.compare(value, a.value) + implicitly[Ordering[A]].compare(value, a.value) } - def boxBuffer[A]: Growable[Box[A]] with Iterable[Box[A]] = + def boxBuffer[A: TypeTag]: Growable[Box[A]] with Iterable[Box[A]] = new Growable[Box[A]] with Iterable[Box[A]] { val buf = ArrayBuffer[A]() - val box = Box[A]() + val box = Box[A] def clear(): Unit = buf.clear() def +=(x: Box[A]) = { @@ -72,10 +71,10 @@ class FlipbookIteratorSuite extends HailSuite { else l.value - r.value } - def makeTestIterator[A](elems: A*): StagingIterator[Box[A]] = { + def makeTestIterator[A: TypeTag](elems: A*): StagingIterator[Box[A]] = { val it = elems.iterator val sm = new StateMachine[Box[A]] { - val value: Box[A] = Box() + val value: Box[A] = Box[A] var isValid = true def advance(): Unit = if (it.hasNext) diff --git a/hail/hail/test/src/is/hail/variant/LocusIntervalSuite.scala b/hail/hail/test/src/is/hail/variant/LocusIntervalSuite.scala index 56027cbb1ac..0816b4afa89 100644 --- a/hail/hail/test/src/is/hail/variant/LocusIntervalSuite.scala +++ b/hail/hail/test/src/is/hail/variant/LocusIntervalSuite.scala @@ -6,7 +6,7 @@ import is.hail.utils._ import org.testng.annotations.Test class LocusIntervalSuite extends HailSuite { - def rg = ctx.getReference(ReferenceGenome.GRCh37) + def rg = ctx.references(ReferenceGenome.GRCh37) @Test def testParser(): Unit = { val xMax = rg.contigLength("X") @@ -224,8 +224,8 @@ class LocusIntervalSuite extends HailSuite { Locus.parseInterval("1:1.1111111M-2M", rg) } - val gr37 = ctx.getReference(ReferenceGenome.GRCh37) - val gr38 = ctx.getReference(ReferenceGenome.GRCh38) + val gr37 = ctx.references(ReferenceGenome.GRCh37) + val gr38 = ctx.references(ReferenceGenome.GRCh38) val x = "[GL000197.1:3739-GL000202.1:7538)" assert(Locus.parseInterval(x, gr37) == diff --git a/hail/hail/test/src/is/hail/variant/ReferenceGenomeSuite.scala b/hail/hail/test/src/is/hail/variant/ReferenceGenomeSuite.scala index 101b12b5ff6..693a56c6061 100644 --- a/hail/hail/test/src/is/hail/variant/ReferenceGenomeSuite.scala +++ b/hail/hail/test/src/is/hail/variant/ReferenceGenomeSuite.scala @@ -1,11 +1,11 @@ package is.hail.variant import is.hail.{HailSuite, TestUtils} -import is.hail.backend.HailStateManager +import is.hail.backend.{ExecuteContext, HailStateManager} import is.hail.check.Prop._ import is.hail.check.Properties import is.hail.expr.ir.EmitFunctionBuilder -import is.hail.io.reference.{FASTAReader, FASTAReaderConfig} +import is.hail.io.reference.{FASTAReader, FASTAReaderConfig, LiftOver} import is.hail.types.virtual.TLocus import is.hail.utils._ @@ -16,7 +16,7 @@ class ReferenceGenomeSuite extends HailSuite { def hasReference(name: String) = ctx.stateManager.referenceGenomes.contains(name) - def getReference(name: String) = ctx.getReference(name) + def getReference(name: String) = ctx.references(name) @Test def testGRCh37(): Unit = { assert(hasReference(ReferenceGenome.GRCh37)) @@ -222,8 +222,8 @@ class ReferenceGenomeSuite extends HailSuite { } @Test def testSerializeOnFB(): Unit = { - withExecuteContext() { ctx => - val grch38 = ctx.getReference(ReferenceGenome.GRCh38) + ExecuteContext.scoped { ctx => + val grch38 = ctx.references(ReferenceGenome.GRCh38) val fb = EmitFunctionBuilder[String, Boolean](ctx, "serialize_rg") val rgfield = fb.getReferenceGenome(grch38.name) fb.emit(rgfield.invoke[String, Boolean]("isValidContig", fb.getCodeParam[String](1))) @@ -234,11 +234,11 @@ class ReferenceGenomeSuite extends HailSuite { } @Test def testSerializeWithLiftoverOnFB(): Unit = { - withExecuteContext() { ctx => - val grch37 = ctx.getReference(ReferenceGenome.GRCh37) + ExecuteContext.scoped { ctx => + val grch37 = ctx.references(ReferenceGenome.GRCh37) val liftoverFile = getTestResource("grch37_to_grch38_chr20.over.chain.gz") - grch37.addLiftover(ctx, liftoverFile, "GRCh38") + grch37.addLiftover(ctx.references("GRCh38"), LiftOver(ctx.fs, liftoverFile)) val fb = EmitFunctionBuilder[String, Locus, Double, (Locus, Boolean)](ctx, "serialize_with_liftover")