Skip to content

Commit

Permalink
[query] Remove BlockMatrix persist from Backend interface
Browse files Browse the repository at this point in the history
  • Loading branch information
ehigham committed Jan 29, 2025
1 parent bd12bb0 commit 387bde2
Show file tree
Hide file tree
Showing 11 changed files with 60 additions and 96 deletions.
12 changes: 1 addition & 11 deletions hail/hail/src/is/hail/backend/Backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@ import is.hail.io.{BufferSpec, TypedCodecSpec}
import is.hail.io.fs._
import is.hail.io.plink.LoadPlink
import is.hail.io.vcf.LoadVCF
import is.hail.linalg.BlockMatrix
import is.hail.types._
import is.hail.types.encoded.EType
import is.hail.types.physical.PTuple
import is.hail.types.virtual.{BlockMatrixType, TFloat64}
import is.hail.types.virtual.TFloat64
import is.hail.utils._
import is.hail.variant.ReferenceGenome

Expand Down Expand Up @@ -77,15 +76,6 @@ abstract class Backend extends Closeable {

def broadcast[T: ClassTag](value: T): BroadcastValue[T]

def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String)
: Unit

def unpersist(backendContext: BackendContext, id: String): Unit

def getPersistedBlockMatrix(backendContext: BackendContext, id: String): BlockMatrix

def getPersistedBlockMatrixType(backendContext: BackendContext, id: String): BlockMatrixType

def parallelizeAndComputeWithIndex(
backendContext: BackendContext,
fs: FS,
Expand Down
6 changes: 6 additions & 0 deletions hail/hail/src/is/hail/backend/ExecuteContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import is.hail.asm4s.HailClassLoader
import is.hail.backend.local.LocalTaskContext
import is.hail.expr.ir.lowering.IrMetadata
import is.hail.io.fs.FS
import is.hail.linalg.BlockMatrix
import is.hail.utils._
import is.hail.variant.ReferenceGenome

Expand Down Expand Up @@ -71,6 +72,7 @@ object ExecuteContext {
flags: HailFeatureFlags,
backendContext: BackendContext,
irMetadata: IrMetadata,
blockMatrixCache: mutable.Map[String, BlockMatrix],
)(
f: ExecuteContext => T
): T = {
Expand All @@ -89,6 +91,7 @@ object ExecuteContext {
flags,
backendContext,
irMetadata,
blockMatrixCache,
))(f(_))
}
}
Expand Down Expand Up @@ -118,6 +121,7 @@ class ExecuteContext(
val flags: HailFeatureFlags,
val backendContext: BackendContext,
val irMetadata: IrMetadata,
val BlockMatrixCache: mutable.Map[String, BlockMatrix],
) extends Closeable {

val rngNonce: Long =
Expand Down Expand Up @@ -184,6 +188,7 @@ class ExecuteContext(
flags: HailFeatureFlags = this.flags,
backendContext: BackendContext = this.backendContext,
irMetadata: IrMetadata = this.irMetadata,
blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache,
)(
f: ExecuteContext => A
): A =
Expand All @@ -200,5 +205,6 @@ class ExecuteContext(
flags,
backendContext,
irMetadata,
blockMatrixCache,
))(f)
}
13 changes: 2 additions & 11 deletions hail/hail/src/is/hail/backend/local/LocalBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.defs.MakeTuple
import is.hail.expr.ir.lowering._
import is.hail.io.fs._
import is.hail.linalg.BlockMatrix
import is.hail.types._
import is.hail.types.physical.PTuple
import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType
import is.hail.types.virtual.{BlockMatrixType, TVoid}
import is.hail.types.virtual.TVoid
import is.hail.utils._
import is.hail.variant.ReferenceGenome

Expand Down Expand Up @@ -114,6 +113,7 @@ class LocalBackend(
ExecutionCache.fromFlags(flags, fs, tmpdir)
},
new IrMetadata(),
ImmutableMap.empty,
)(f)
}

Expand Down Expand Up @@ -216,15 +216,6 @@ class LocalBackend(
): TableReader =
LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt, nPartitions)

def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String)
: Unit = ???

def unpersist(backendContext: BackendContext, id: String): Unit = ???

def getPersistedBlockMatrix(backendContext: BackendContext, id: String): BlockMatrix = ???

def getPersistedBlockMatrixType(backendContext: BackendContext, id: String): BlockMatrixType = ???

def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses)
: TableStage =
LowerTableIR.applyTable(inputIR, DArrayLowering.All, ctx, analyses)
Expand Down
11 changes: 1 addition & 10 deletions hail/hail/src/is/hail/backend/service/ServiceBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.expr.ir.lowering._
import is.hail.io.fs._
import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}
import is.hail.linalg.BlockMatrix
import is.hail.services.{BatchClient, JobGroupRequest, _}
import is.hail.services.JobGroupStates.{Cancelled, Failure, Running, Success}
import is.hail.types._
Expand Down Expand Up @@ -375,15 +374,6 @@ class ServiceBackend(
): TableReader =
LowerDistributedSort.distributedSort(ctx, inputStage, sortFields, rt, nPartitions)

def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String)
: Unit = ???

def unpersist(backendContext: BackendContext, id: String): Unit = ???

def getPersistedBlockMatrix(backendContext: BackendContext, id: String): BlockMatrix = ???

def getPersistedBlockMatrixType(backendContext: BackendContext, id: String): BlockMatrixType = ???

def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses)
: TableStage =
LowerTableIR.applyTable(inputIR, DArrayLowering.All, ctx, analyses)
Expand All @@ -402,6 +392,7 @@ class ServiceBackend(
flags,
serviceBackendContext,
new IrMetadata(),
ImmutableMap.empty,
)(f)
}

Expand Down
23 changes: 6 additions & 17 deletions hail/hail/src/is/hail/backend/spark/SparkBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import is.hail.{HailContext, HailFeatureFlags}
import is.hail.annotations._
import is.hail.asm4s._
import is.hail.backend._
import is.hail.backend.caching.BlockMatrixCache
import is.hail.backend.py4j.Py4JBackendExtensions
import is.hail.expr.Validate
import is.hail.expr.ir._
Expand All @@ -12,7 +13,6 @@ import is.hail.expr.ir.defs.MakeTuple
import is.hail.expr.ir.lowering._
import is.hail.io.{BufferSpec, TypedCodecSpec}
import is.hail.io.fs._
import is.hail.linalg.BlockMatrix
import is.hail.rvd.RVD
import is.hail.types._
import is.hail.types.physical.{PStruct, PTuple}
Expand All @@ -26,10 +26,9 @@ import scala.collection.mutable.ArrayBuffer
import scala.concurrent.ExecutionException
import scala.reflect.ClassTag
import scala.util.control.NonFatal

import java.io.PrintWriter

import com.fasterxml.jackson.core.StreamReadConstraints
import is.hail.linalg.BlockMatrix
import org.apache.hadoop
import org.apache.hadoop.conf.Configuration
import org.apache.spark._
Expand Down Expand Up @@ -352,20 +351,7 @@ class SparkBackend(
override val longLifeTempFileManager: TempFileManager =
new OwningTempFileManager(fs)

val bmCache: SparkBlockMatrixCache = SparkBlockMatrixCache()

def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String)
: Unit = bmCache.persistBlockMatrix(id, value, storageLevel)

def unpersist(backendContext: BackendContext, id: String): Unit = unpersist(id)

def getPersistedBlockMatrix(backendContext: BackendContext, id: String): BlockMatrix =
bmCache.getPersistedBlockMatrix(id)

def getPersistedBlockMatrixType(backendContext: BackendContext, id: String): BlockMatrixType =
bmCache.getPersistedBlockMatrixType(id)

def unpersist(id: String): Unit = bmCache.unpersistBlockMatrix(id)
private[this] val bmCache = mutable.Map.empty[String, BlockMatrix]

def createExecuteContextForTests(
timer: ExecutionTimer,
Expand All @@ -388,6 +374,7 @@ class SparkBackend(
ExecutionCache.forTesting
},
new IrMetadata(),
ImmutableMap.empty,
)

override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T =
Expand All @@ -407,6 +394,7 @@ class SparkBackend(
ExecutionCache.fromFlags(flags, fs, tmpdir)
},
new IrMetadata(),
bmCache,
)(f)
}

Expand Down Expand Up @@ -471,6 +459,7 @@ class SparkBackend(
override def asSpark(op: String): SparkBackend = this

def close(): Unit = {
bmCache.values.foreach(_.unpersist())
SparkBackend.stop()
longLifeTempFileManager.close()
}
Expand Down
25 changes: 0 additions & 25 deletions hail/hail/src/is/hail/backend/spark/SparkBlockMatrixCache.scala

This file was deleted.

13 changes: 5 additions & 8 deletions hail/hail/src/is/hail/expr/ir/BlockMatrixIR.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package is.hail.expr.ir

import is.hail.HailContext
import is.hail.annotations.NDArray
import is.hail.backend.{BackendContext, ExecuteContext}
import is.hail.backend.ExecuteContext
import is.hail.expr.Nat
import is.hail.expr.ir.defs._
import is.hail.expr.ir.lowering.{BMSContexts, BlockMatrixStage2, LowererUnsupportedOperation}
Expand Down Expand Up @@ -107,7 +106,7 @@ object BlockMatrixReader {
def fromJValue(ctx: ExecuteContext, jv: JValue): BlockMatrixReader =
(jv \ "name").extract[String] match {
case "BlockMatrixNativeReader" => BlockMatrixNativeReader.fromJValue(ctx.fs, jv)
case "BlockMatrixPersistReader" => BlockMatrixPersistReader.fromJValue(ctx.backendContext, jv)
case "BlockMatrixPersistReader" => BlockMatrixPersistReader.fromJValue(ctx, jv)
case _ => jv.extract[BlockMatrixReader]
}
}
Expand Down Expand Up @@ -275,22 +274,20 @@ case class BlockMatrixBinaryReader(path: String, shape: IndexedSeq[Long], blockS
case class BlockMatrixNativePersistParameters(id: String)

object BlockMatrixPersistReader {
def fromJValue(ctx: BackendContext, jv: JValue): BlockMatrixPersistReader = {
def fromJValue(ctx: ExecuteContext, jv: JValue): BlockMatrixPersistReader = {
implicit val formats: Formats = BlockMatrixReader.formats
val params = jv.extract[BlockMatrixNativePersistParameters]
BlockMatrixPersistReader(
params.id,
HailContext.backend.getPersistedBlockMatrixType(ctx, params.id),
BlockMatrixType.fromBlockMatrix(ctx.BlockMatrixCache(params.id)),
)
}
}

case class BlockMatrixPersistReader(id: String, typ: BlockMatrixType) extends BlockMatrixReader {
def pathsUsed: Seq[String] = FastSeq()
lazy val fullType: BlockMatrixType = typ

def apply(ctx: ExecuteContext): BlockMatrix =
HailContext.backend.getPersistedBlockMatrix(ctx.backendContext, id)
def apply(ctx: ExecuteContext): BlockMatrix = ctx.BlockMatrixCache(id)
}

case class BlockMatrixMap(child: BlockMatrixIR, eltName: Name, f: IR, needsDense: Boolean)
Expand Down
3 changes: 1 addition & 2 deletions hail/hail/src/is/hail/expr/ir/BlockMatrixWriter.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package is.hail.expr.ir

import is.hail.HailContext
import is.hail.annotations.Region
import is.hail.asm4s._
import is.hail.backend.ExecuteContext
Expand Down Expand Up @@ -191,7 +190,7 @@ case class BlockMatrixPersistWriter(id: String, storageLevel: String) extends Bl
def pathOpt: Option[String] = None

def apply(ctx: ExecuteContext, bm: BlockMatrix): Unit =
HailContext.backend.persist(ctx.backendContext, id, bm, storageLevel)
ctx.BlockMatrixCache += id -> bm.persist(storageLevel)

def loweredTyp: Type = TVoid
}
Expand Down
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/utils/ErrorHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class HailWorkerException(
trait ErrorHandling {
def fatal(msg: String): Nothing = throw new HailException(msg)

def fatal(msg: String, errorId: Int) = throw new HailException(msg, errorId)
def fatal(msg: String, errorId: Int): Nothing = throw new HailException(msg, errorId)

def fatal(msg: String, cause: Throwable): Nothing = throw new HailException(msg, None, cause)

Expand Down
17 changes: 17 additions & 0 deletions hail/hail/src/is/hail/utils/ImmutableMap.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package is.hail.utils

import scala.collection.mutable

object ImmutableMap {

private[this] object EmptyInstance extends mutable.AbstractMap[Any, Any] {
override def +=(kv: (Any, Any)): EmptyInstance.this.type = this
override def -=(key: Any): EmptyInstance.this.type = this
override def get(key: Any): Option[Any] = None
override def iterator: Iterator[(Any, Any)] = Iterator.empty
override def getOrElseUpdate(key: Any, op: => Any): Any = op
}

def empty[K, V]: mutable.Map[K, V] =
EmptyInstance.asInstanceOf[mutable.Map[K, V]]
}
31 changes: 20 additions & 11 deletions hail/hail/test/src/is/hail/expr/ir/IRSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ import is.hail.ExecStrategy.ExecStrategy
import is.hail.TestUtils._
import is.hail.annotations.{BroadcastRow, ExtendedOrdering, SafeNDArray}
import is.hail.backend.ExecuteContext
import is.hail.backend.caching.BlockMatrixCache
import is.hail.expr.Nat
import is.hail.expr.ir.agg._
import is.hail.expr.ir.defs._
import is.hail.expr.ir.defs.ArrayZipBehavior.ArrayZipBehavior
import is.hail.expr.ir.functions._
import is.hail.io.{BufferSpec, TypedCodecSpec}
import is.hail.io.bgen.MatrixBGENReader
import is.hail.linalg.BlockMatrix
import is.hail.methods._
import is.hail.rvd.{PartitionBoundOrdering, RVD, RVDPartitioner}
import is.hail.types.{tcoerce, VirtualTypeWithReq}
Expand Down Expand Up @@ -3907,17 +3907,26 @@ class IRSuite extends HailSuite {
assert(x2 == x)
}

def testBlockMatrixIRParserPersist(): Unit = {
val bm = BlockMatrix.fill(1, 1, 0.0, 5)
backend.persist(ctx.backendContext, "x", bm, "MEMORY_ONLY")
val persist =
BlockMatrixRead(BlockMatrixPersistReader("x", BlockMatrixType.fromBlockMatrix(bm)))
@Test def testBlockMatrixIRParserPersist(): Unit =
using(new BlockMatrixCache()) { cache =>
val bm = BlockMatrixRandom(0, gaussian = true, shape = Array(5L, 6L), blockSize = 3)

val s = Pretty.sexprStyle(persist, elideLiterals = false)
val x2 = IRParser.parse_blockmatrix_ir(ctx, s)
assert(x2 == persist)
backend.unpersist(ctx.backendContext, "x")
}
backend.withExecuteContext { ctx =>
ctx.local(blockMatrixCache = cache) { ctx =>
backend.execute(ctx, BlockMatrixWrite(bm, BlockMatrixPersistWriter("x", "MEMORY_ONLY")))
}
}

backend.withExecuteContext { ctx =>
ctx.local(blockMatrixCache = cache) { ctx =>
val persist = BlockMatrixRead(BlockMatrixPersistReader("x", bm.typ))

val s = Pretty.sexprStyle(persist, elideLiterals = false)
val x2 = IRParser.parse_blockmatrix_ir(ctx, s)
assert(x2 == persist)
}
}
}

@Test def testCachedIR(): Unit = {
val cached = Literal(TSet(TInt32), Set(1))
Expand Down

0 comments on commit 387bde2

Please sign in to comment.