Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[query] Remove BlockMatrix persist from Backend interface #14690

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions hail/hail/src/is/hail/backend/Backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@ import is.hail.io.{BufferSpec, TypedCodecSpec}
import is.hail.io.fs._
import is.hail.io.plink.LoadPlink
import is.hail.io.vcf.LoadVCF
import is.hail.linalg.BlockMatrix
import is.hail.types._
import is.hail.types.encoded.EType
import is.hail.types.physical.PTuple
import is.hail.types.virtual.{BlockMatrixType, TFloat64}
import is.hail.types.virtual.TFloat64
import is.hail.utils._
import is.hail.variant.ReferenceGenome

Expand Down Expand Up @@ -77,15 +76,6 @@ abstract class Backend extends Closeable {

def broadcast[T: ClassTag](value: T): BroadcastValue[T]

def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String)
: Unit

def unpersist(backendContext: BackendContext, id: String): Unit

def getPersistedBlockMatrix(backendContext: BackendContext, id: String): BlockMatrix

def getPersistedBlockMatrixType(backendContext: BackendContext, id: String): BlockMatrixType

def parallelizeAndComputeWithIndex(
backendContext: BackendContext,
fs: FS,
Expand Down
6 changes: 6 additions & 0 deletions hail/hail/src/is/hail/backend/ExecuteContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import is.hail.asm4s.HailClassLoader
import is.hail.backend.local.LocalTaskContext
import is.hail.expr.ir.lowering.IrMetadata
import is.hail.io.fs.FS
import is.hail.linalg.BlockMatrix
import is.hail.utils._
import is.hail.variant.ReferenceGenome

Expand Down Expand Up @@ -71,6 +72,7 @@ object ExecuteContext {
flags: HailFeatureFlags,
backendContext: BackendContext,
irMetadata: IrMetadata,
blockMatrixCache: mutable.Map[String, BlockMatrix],
)(
f: ExecuteContext => T
): T = {
Expand All @@ -89,6 +91,7 @@ object ExecuteContext {
flags,
backendContext,
irMetadata,
blockMatrixCache,
))(f(_))
}
}
Expand Down Expand Up @@ -118,6 +121,7 @@ class ExecuteContext(
val flags: HailFeatureFlags,
val backendContext: BackendContext,
val irMetadata: IrMetadata,
val BlockMatrixCache: mutable.Map[String, BlockMatrix],
) extends Closeable {

val rngNonce: Long =
Expand Down Expand Up @@ -184,6 +188,7 @@ class ExecuteContext(
flags: HailFeatureFlags = this.flags,
backendContext: BackendContext = this.backendContext,
irMetadata: IrMetadata = this.irMetadata,
blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache,
)(
f: ExecuteContext => A
): A =
Expand All @@ -200,5 +205,6 @@ class ExecuteContext(
flags,
backendContext,
irMetadata,
blockMatrixCache,
))(f)
}
13 changes: 2 additions & 11 deletions hail/hail/src/is/hail/backend/local/LocalBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.defs.MakeTuple
import is.hail.expr.ir.lowering._
import is.hail.io.fs._
import is.hail.linalg.BlockMatrix
import is.hail.types._
import is.hail.types.physical.PTuple
import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType
import is.hail.types.virtual.{BlockMatrixType, TVoid}
import is.hail.types.virtual.TVoid
import is.hail.utils._
import is.hail.variant.ReferenceGenome

Expand Down Expand Up @@ -114,6 +113,7 @@ class LocalBackend(
ExecutionCache.fromFlags(flags, fs, tmpdir)
},
new IrMetadata(),
ImmutableMap.empty,
)(f)
}

Expand Down Expand Up @@ -216,15 +216,6 @@ class LocalBackend(
): TableReader =
LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt, nPartitions)

def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String)
: Unit = ???

def unpersist(backendContext: BackendContext, id: String): Unit = ???

def getPersistedBlockMatrix(backendContext: BackendContext, id: String): BlockMatrix = ???

def getPersistedBlockMatrixType(backendContext: BackendContext, id: String): BlockMatrixType = ???

def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses)
: TableStage =
LowerTableIR.applyTable(inputIR, DArrayLowering.All, ctx, analyses)
Expand Down
11 changes: 1 addition & 10 deletions hail/hail/src/is/hail/backend/service/ServiceBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.expr.ir.lowering._
import is.hail.io.fs._
import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}
import is.hail.linalg.BlockMatrix
import is.hail.services.{BatchClient, JobGroupRequest, _}
import is.hail.services.JobGroupStates.{Cancelled, Failure, Running, Success}
import is.hail.types._
Expand Down Expand Up @@ -375,15 +374,6 @@ class ServiceBackend(
): TableReader =
LowerDistributedSort.distributedSort(ctx, inputStage, sortFields, rt, nPartitions)

def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String)
: Unit = ???

def unpersist(backendContext: BackendContext, id: String): Unit = ???

def getPersistedBlockMatrix(backendContext: BackendContext, id: String): BlockMatrix = ???

def getPersistedBlockMatrixType(backendContext: BackendContext, id: String): BlockMatrixType = ???

def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses)
: TableStage =
LowerTableIR.applyTable(inputIR, DArrayLowering.All, ctx, analyses)
Expand All @@ -402,6 +392,7 @@ class ServiceBackend(
flags,
serviceBackendContext,
new IrMetadata(),
ImmutableMap.empty,
)(f)
}

Expand Down
18 changes: 4 additions & 14 deletions hail/hail/src/is/hail/backend/spark/SparkBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -352,20 +352,7 @@ class SparkBackend(
override val longLifeTempFileManager: TempFileManager =
new OwningTempFileManager(fs)

val bmCache: SparkBlockMatrixCache = SparkBlockMatrixCache()

def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String)
: Unit = bmCache.persistBlockMatrix(id, value, storageLevel)

def unpersist(backendContext: BackendContext, id: String): Unit = unpersist(id)

def getPersistedBlockMatrix(backendContext: BackendContext, id: String): BlockMatrix =
bmCache.getPersistedBlockMatrix(id)

def getPersistedBlockMatrixType(backendContext: BackendContext, id: String): BlockMatrixType =
bmCache.getPersistedBlockMatrixType(id)

def unpersist(id: String): Unit = bmCache.unpersistBlockMatrix(id)
private[this] val bmCache = mutable.Map.empty[String, BlockMatrix]

def createExecuteContextForTests(
timer: ExecutionTimer,
Expand All @@ -388,6 +375,7 @@ class SparkBackend(
ExecutionCache.forTesting
},
new IrMetadata(),
ImmutableMap.empty,
)

override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T =
Expand All @@ -407,6 +395,7 @@ class SparkBackend(
ExecutionCache.fromFlags(flags, fs, tmpdir)
},
new IrMetadata(),
bmCache,
)(f)
}

Expand Down Expand Up @@ -471,6 +460,7 @@ class SparkBackend(
override def asSpark(op: String): SparkBackend = this

def close(): Unit = {
bmCache.values.foreach(_.unpersist())
SparkBackend.stop()
longLifeTempFileManager.close()
}
Expand Down
25 changes: 0 additions & 25 deletions hail/hail/src/is/hail/backend/spark/SparkBlockMatrixCache.scala

This file was deleted.

13 changes: 5 additions & 8 deletions hail/hail/src/is/hail/expr/ir/BlockMatrixIR.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package is.hail.expr.ir

import is.hail.HailContext
import is.hail.annotations.NDArray
import is.hail.backend.{BackendContext, ExecuteContext}
import is.hail.backend.ExecuteContext
import is.hail.expr.Nat
import is.hail.expr.ir.defs._
import is.hail.expr.ir.lowering.{BMSContexts, BlockMatrixStage2, LowererUnsupportedOperation}
Expand Down Expand Up @@ -107,7 +106,7 @@ object BlockMatrixReader {
def fromJValue(ctx: ExecuteContext, jv: JValue): BlockMatrixReader =
(jv \ "name").extract[String] match {
case "BlockMatrixNativeReader" => BlockMatrixNativeReader.fromJValue(ctx.fs, jv)
case "BlockMatrixPersistReader" => BlockMatrixPersistReader.fromJValue(ctx.backendContext, jv)
case "BlockMatrixPersistReader" => BlockMatrixPersistReader.fromJValue(ctx, jv)
case _ => jv.extract[BlockMatrixReader]
}
}
Expand Down Expand Up @@ -275,22 +274,20 @@ case class BlockMatrixBinaryReader(path: String, shape: IndexedSeq[Long], blockS
case class BlockMatrixNativePersistParameters(id: String)

object BlockMatrixPersistReader {
def fromJValue(ctx: BackendContext, jv: JValue): BlockMatrixPersistReader = {
def fromJValue(ctx: ExecuteContext, jv: JValue): BlockMatrixPersistReader = {
implicit val formats: Formats = BlockMatrixReader.formats
val params = jv.extract[BlockMatrixNativePersistParameters]
BlockMatrixPersistReader(
params.id,
HailContext.backend.getPersistedBlockMatrixType(ctx, params.id),
BlockMatrixType.fromBlockMatrix(ctx.BlockMatrixCache(params.id)),
)
}
}

case class BlockMatrixPersistReader(id: String, typ: BlockMatrixType) extends BlockMatrixReader {
def pathsUsed: Seq[String] = FastSeq()
lazy val fullType: BlockMatrixType = typ

def apply(ctx: ExecuteContext): BlockMatrix =
HailContext.backend.getPersistedBlockMatrix(ctx.backendContext, id)
def apply(ctx: ExecuteContext): BlockMatrix = ctx.BlockMatrixCache(id)
}

case class BlockMatrixMap(child: BlockMatrixIR, eltName: Name, f: IR, needsDense: Boolean)
Expand Down
3 changes: 1 addition & 2 deletions hail/hail/src/is/hail/expr/ir/BlockMatrixWriter.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package is.hail.expr.ir

import is.hail.HailContext
import is.hail.annotations.Region
import is.hail.asm4s._
import is.hail.backend.ExecuteContext
Expand Down Expand Up @@ -191,7 +190,7 @@ case class BlockMatrixPersistWriter(id: String, storageLevel: String) extends Bl
def pathOpt: Option[String] = None

def apply(ctx: ExecuteContext, bm: BlockMatrix): Unit =
HailContext.backend.persist(ctx.backendContext, id, bm, storageLevel)
ctx.BlockMatrixCache += id -> bm.persist(storageLevel)

def loweredTyp: Type = TVoid
}
Expand Down
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/utils/ErrorHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class HailWorkerException(
trait ErrorHandling {
def fatal(msg: String): Nothing = throw new HailException(msg)

def fatal(msg: String, errorId: Int) = throw new HailException(msg, errorId)
def fatal(msg: String, errorId: Int): Nothing = throw new HailException(msg, errorId)

def fatal(msg: String, cause: Throwable): Nothing = throw new HailException(msg, None, cause)

Expand Down
17 changes: 17 additions & 0 deletions hail/hail/src/is/hail/utils/ImmutableMap.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package is.hail.utils

import scala.collection.mutable

/** Source of a shared, permanently empty `mutable.Map`.
  *
  * The map returned by [[ImmutableMap.empty]] accepts the usual mutating calls but never stores
  * anything: additions and removals are no-ops and every lookup misses.
  */
object ImmutableMap {

  /** The always-empty map, viewed as a `mutable.Map[K, V]`.
    *
    * Every call hands back the same underlying instance, merely cast to the requested key and
    * value types; the instance itself never yields a key or a value through `get` or `iterator`.
    */
  def empty[K, V]: mutable.Map[K, V] =
    Blackhole.asInstanceOf[mutable.Map[K, V]]

  private[this] object Blackhole extends mutable.AbstractMap[Any, Any] {
    // Reads: nothing is ever stored, so lookups miss and iteration is empty.
    override def get(key: Any): Option[Any] = None
    override def iterator: Iterator[(Any, Any)] = Iterator.empty

    // Writes: accepted, then dropped.
    override def +=(kv: (Any, Any)): this.type = this
    override def -=(key: Any): this.type = this

    // Evaluates `op` and returns its result without recording it.
    override def getOrElseUpdate(key: Any, op: => Any): Any = op
  }
}
29 changes: 20 additions & 9 deletions hail/hail/test/src/is/hail/expr/ir/IRSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import is.hail.types.virtual.TIterable.elementType
import is.hail.utils.{FastSeq, _}
import is.hail.variant.{Call2, Locus}

import scala.collection.mutable
import scala.language.implicitConversions

import org.apache.spark.sql.Row
Expand Down Expand Up @@ -3907,16 +3908,26 @@ class IRSuite extends HailSuite {
assert(x2 == x)
}

def testBlockMatrixIRParserPersist(): Unit = {
val bm = BlockMatrix.fill(1, 1, 0.0, 5)
backend.persist(ctx.backendContext, "x", bm, "MEMORY_ONLY")
val persist =
BlockMatrixRead(BlockMatrixPersistReader("x", BlockMatrixType.fromBlockMatrix(bm)))
@Test def testBlockMatrixIRParserPersist(): Unit = {
val cache = mutable.Map.empty[String, BlockMatrix]
val bm = BlockMatrixRandom(0, gaussian = true, shape = Array(5L, 6L), blockSize = 3)
try {
backend.withExecuteContext { ctx =>
ctx.local(blockMatrixCache = cache) { ctx =>
backend.execute(ctx, BlockMatrixWrite(bm, BlockMatrixPersistWriter("x", "MEMORY_ONLY")))
}
}
backend.withExecuteContext { ctx =>
ctx.local(blockMatrixCache = cache) { ctx =>
val persist = BlockMatrixRead(BlockMatrixPersistReader("x", bm.typ))

val s = Pretty.sexprStyle(persist, elideLiterals = false)
val x2 = IRParser.parse_blockmatrix_ir(ctx, s)
assert(x2 == persist)
backend.unpersist(ctx.backendContext, "x")
val s = Pretty.sexprStyle(persist, elideLiterals = false)
val x2 = IRParser.parse_blockmatrix_ir(ctx, s)
assert(x2 == persist)
}
}
} finally
cache.values.foreach(_.unpersist())
}

@Test def testCachedIR(): Unit = {
Expand Down