From 56548393d735bc5119c58e7c50940493b53a8cde Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Tue, 17 Sep 2024 19:56:58 -0400 Subject: [PATCH] [query] Remove `lookupOrCompileCachedFunction` from `Backend` interface --- hail/hail/src/is/hail/backend/Backend.scala | 26 +-- .../src/is/hail/backend/ExecuteContext.scala | 6 + .../is/hail/backend/local/LocalBackend.scala | 7 +- .../hail/backend/service/ServiceBackend.scala | 8 +- .../is/hail/backend/spark/SparkBackend.scala | 11 +- hail/hail/src/is/hail/expr/ir/Compile.scala | 186 ++++++++---------- .../is/hail/expr/ir/CompileAndEvaluate.scala | 1 + hail/hail/src/is/hail/expr/ir/Emit.scala | 1 + hail/hail/src/is/hail/expr/ir/Interpret.scala | 1 + hail/hail/src/is/hail/expr/ir/TableIR.scala | 28 +-- .../src/is/hail/expr/ir/agg/Extract.scala | 7 +- .../ir/lowering/LowerDistributedSort.scala | 1 + .../expr/ir/lowering/RVDToTableStage.scala | 1 + hail/hail/src/is/hail/utils/Cache.scala | 20 +- hail/hail/test/src/is/hail/TestUtils.scala | 5 +- .../is/hail/expr/ir/Aggregators2Suite.scala | 1 + .../src/is/hail/expr/ir/EmitStreamSuite.scala | 1 + .../test/src/is/hail/expr/ir/IRSuite.scala | 12 +- 18 files changed, 156 insertions(+), 167 deletions(-) diff --git a/hail/hail/src/is/hail/backend/Backend.scala b/hail/hail/src/is/hail/backend/Backend.scala index f0c5e958e56..a89a4b0c1e5 100644 --- a/hail/hail/src/is/hail/backend/Backend.scala +++ b/hail/hail/src/is/hail/backend/Backend.scala @@ -4,8 +4,7 @@ import is.hail.asm4s._ import is.hail.backend.Backend.jsonToBytes import is.hail.backend.spark.SparkBackend import is.hail.expr.ir.{ - BaseIR, CodeCacheKey, CompiledFunction, IR, IRParser, IRParserEnvironment, LoweringAnalyses, - SortField, TableIR, TableReader, + BaseIR, IR, IRParser, IRParserEnvironment, LoweringAnalyses, SortField, TableIR, TableReader, } import is.hail.expr.ir.lowering.{TableStage, TableStageDependency} import is.hail.io.{BufferSpec, TypedCodecSpec} @@ -92,9 +91,6 @@ abstract class Backend extends Closeable { def shouldCacheQueryInfo: Boolean = true - def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T]) - : CompiledFunction[T] - def lowerDistributedSort( ctx: ExecuteContext, stage: TableStage, @@ -193,23 +189,3 @@ abstract class Backend extends Closeable { def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)] } - -trait BackendWithCodeCache { - private[this] val codeCache: Cache[CodeCacheKey, CompiledFunction[_]] = new Cache(50) - - def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T]) - : CompiledFunction[T] = { - codeCache.get(k) match { - case Some(v) => v.asInstanceOf[CompiledFunction[T]] - case None => - val compiledFunction = f - codeCache += ((k, compiledFunction)) - compiledFunction - } - } -} - -trait BackendWithNoCodeCache { - def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T]) - : CompiledFunction[T] = f -} diff --git a/hail/hail/src/is/hail/backend/ExecuteContext.scala b/hail/hail/src/is/hail/backend/ExecuteContext.scala index bc853bc4c15..e1722ff299c 100644 --- a/hail/hail/src/is/hail/backend/ExecuteContext.scala +++ b/hail/hail/src/is/hail/backend/ExecuteContext.scala @@ -4,6 +4,7 @@ import is.hail.{HailContext, HailFeatureFlags} import is.hail.annotations.{Region, RegionPool} import is.hail.asm4s.HailClassLoader import is.hail.backend.local.LocalTaskContext +import is.hail.expr.ir.{CodeCacheKey, CompiledFunction} import is.hail.expr.ir.lowering.IrMetadata import is.hail.io.fs.FS import is.hail.linalg.BlockMatrix @@ -73,6 +74,7 @@ object ExecuteContext { backendContext: BackendContext, irMetadata: IrMetadata, blockMatrixCache: mutable.Map[String, BlockMatrix], + codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]], )( f: ExecuteContext => T ): T = { @@ -92,6 +94,7 @@ object ExecuteContext { backendContext, irMetadata, blockMatrixCache, + codeCache, ))(f(_)) } } @@ -122,6 +125,7 @@ class ExecuteContext( val backendContext: BackendContext, val irMetadata: IrMetadata, val BlockMatrixCache: mutable.Map[String, BlockMatrix], + val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]], ) extends Closeable { val rngNonce: Long = @@ -189,6 +193,7 @@ class ExecuteContext( backendContext: BackendContext = this.backendContext, irMetadata: IrMetadata = this.irMetadata, blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache, + codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache, )( f: ExecuteContext => A ): A = @@ -206,5 +211,6 @@ class ExecuteContext( backendContext, irMetadata, blockMatrixCache, + codeCache, ))(f) } diff --git a/hail/hail/src/is/hail/backend/local/LocalBackend.scala b/hail/hail/src/is/hail/backend/local/LocalBackend.scala index e2e730f2a1f..8a35ca7cf17 100644 --- a/hail/hail/src/is/hail/backend/local/LocalBackend.scala +++ b/hail/hail/src/is/hail/backend/local/LocalBackend.scala @@ -8,6 +8,7 @@ import is.hail.backend.py4j.Py4JBackendExtensions import is.hail.expr.Validate import is.hail.expr.ir._ import is.hail.expr.ir.analyses.SemanticHash +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.defs.MakeTuple import is.hail.expr.ir.lowering._ import is.hail.io.fs._ @@ -84,13 +85,14 @@ object LocalBackend { class LocalBackend( val tmpdir: String, override val references: mutable.Map[String, ReferenceGenome], -) extends Backend with BackendWithCodeCache with Py4JBackendExtensions { +) extends Backend with Py4JBackendExtensions { override def backend: Backend = this override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv() override def longLifeTempFileManager: TempFileManager = null - private[this] val theHailClassLoader = new HailClassLoader(getClass().getClassLoader()) + private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader) + private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50) // flags can be set after construction from python def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags)) @@ -114,6 +116,7 @@ class LocalBackend( }, new IrMetadata(), ImmutableMap.empty, + codeCache, )(f) } diff --git a/hail/hail/src/is/hail/backend/service/ServiceBackend.scala b/hail/hail/src/is/hail/backend/service/ServiceBackend.scala index 4ac82a6e74e..e81e40a710e 100644 --- a/hail/hail/src/is/hail/backend/service/ServiceBackend.scala +++ b/hail/hail/src/is/hail/backend/service/ServiceBackend.scala @@ -6,9 +6,10 @@ import is.hail.asm4s._ import is.hail.backend._ import is.hail.expr.Validate import is.hail.expr.ir.{ - Compile, IR, IRParser, IRSize, LoweringAnalyses, SortField, TableIR, TableReader, TypeCheck, + IR, IRParser, IRSize, LoweringAnalyses, SortField, TableIR, TableReader, TypeCheck, } import is.hail.expr.ir.analyses.SemanticHash +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.defs.MakeTuple import is.hail.expr.ir.functions.IRFunctionRegistry import is.hail.expr.ir.lowering._ @@ -51,7 +52,6 @@ class ServiceBackendContext( ) extends BackendContext with Serializable {} object ServiceBackend { - private val log = Logger.getLogger(getClass.getName()) def apply( jarLocation: String, @@ -130,8 +130,7 @@ class ServiceBackend( val fs: FS, val serviceBackendContext: ServiceBackendContext, val scratchDir: String, -) extends Backend with BackendWithNoCodeCache { - import ServiceBackend.log +) extends Backend with Logging { private[this] var stageCount = 0 private[this] val MAX_AVAILABLE_GCS_CONNECTIONS = 1000 @@ -393,6 +392,7 @@ class ServiceBackend( serviceBackendContext, new IrMetadata(), ImmutableMap.empty, + mutable.Map.empty, )(f) } diff --git a/hail/hail/src/is/hail/backend/spark/SparkBackend.scala b/hail/hail/src/is/hail/backend/spark/SparkBackend.scala index a35d0cbb091..7f7f4fcf1e4 100644 --- a/hail/hail/src/is/hail/backend/spark/SparkBackend.scala +++ b/hail/hail/src/is/hail/backend/spark/SparkBackend.scala @@ -4,15 +4,16 @@ import is.hail.{HailContext, HailFeatureFlags} import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend._ -import is.hail.backend.caching.BlockMatrixCache import is.hail.backend.py4j.Py4JBackendExtensions import is.hail.expr.Validate import is.hail.expr.ir._ import is.hail.expr.ir.analyses.SemanticHash +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.defs.MakeTuple import is.hail.expr.ir.lowering._ import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.io.fs._ +import is.hail.linalg.BlockMatrix import is.hail.rvd.RVD import is.hail.types._ import is.hail.types.physical.{PStruct, PTuple} @@ -26,9 +27,10 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionException import scala.reflect.ClassTag import scala.util.control.NonFatal + import java.io.PrintWriter + import com.fasterxml.jackson.core.StreamReadConstraints -import is.hail.linalg.BlockMatrix import org.apache.hadoop import org.apache.hadoop.conf.Configuration import org.apache.spark._ @@ -320,7 +322,7 @@ class SparkBackend( override val references: mutable.Map[String, ReferenceGenome], gcsRequesterPaysProject: String, gcsRequesterPaysBuckets: String, -) extends Backend with BackendWithCodeCache with Py4JBackendExtensions { +) extends Backend with Py4JBackendExtensions { assert(gcsRequesterPaysProject != null || gcsRequesterPaysBuckets == null) lazy val sparkSession: SparkSession = SparkSession.builder().config(sc.getConf).getOrCreate() @@ -352,6 +354,7 @@ class SparkBackend( new OwningTempFileManager(fs) private[this] val bmCache = mutable.Map.empty[String, BlockMatrix] + private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50) def createExecuteContextForTests( timer: ExecutionTimer, @@ -375,6 +378,7 @@ class SparkBackend( }, new IrMetadata(), ImmutableMap.empty, + mutable.Map.empty, ) override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T = @@ -395,6 +399,7 @@ class SparkBackend( }, new IrMetadata(), bmCache, + codeCache, )(f) } diff --git a/hail/hail/src/is/hail/expr/ir/Compile.scala b/hail/hail/src/is/hail/expr/ir/Compile.scala index 2dfe1c185fe..1b50071b78b 100644 --- a/hail/hail/src/is/hail/expr/ir/Compile.scala +++ b/hail/hail/src/is/hail/expr/ir/Compile.scala @@ -14,11 +14,12 @@ import is.hail.types.physical.stypes.{ PTypeReferenceSingleCodeType, SingleCodeType, StreamSingleCodeType, } import is.hail.types.physical.stypes.interfaces.{NoBoxLongIterator, SStream} -import is.hail.types.virtual.Type import is.hail.utils._ import java.io.PrintWriter +import sourcecode.Enclosing + case class CodeCacheKey( aggSigs: IndexedSeq[AggStateSig], args: Seq[(Name, EmitParamType)], @@ -33,8 +34,9 @@ case class CompiledFunction[T]( (typ, f) } -object Compile { - def apply[F: TypeInfo]( +object compile { + + def Compile[F: TypeInfo]( ctx: ExecuteContext, params: IndexedSeq[(Name, EmitParamType)], expectedCodeParamTypes: IndexedSeq[TypeInfo[_]], @@ -43,60 +45,18 @@ object Compile { optimize: Boolean = true, print: Option[PrintWriter] = None, ): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F) = - ctx.time { - val normalizedBody = NormalizeNames(ctx, body, allowFreeVariables = true) - val k = - CodeCacheKey(FastSeq[AggStateSig](), params.map { case (n, pt) => (n, pt) }, normalizedBody) - (ctx.backend.lookupOrCompileCachedFunction[F](k) { - - var ir = body - ir = Subst( - ir, - BindingEnv(params - .zipWithIndex - .foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }), - ) - ir = - LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing(ctx) - - TypeCheck(ctx, ir, BindingEnv.empty) - - val returnParam = CodeParamType(SingleCodeType.typeInfoFromType(ir.typ)) - - val fb = EmitFunctionBuilder[F]( - ctx, - "Compiled", - CodeParamType(typeInfo[Region]) +: params.map { case (_, pt) => - pt - }, - returnParam, - Some("Emit.scala"), - ) - - /* { def visit(x: IR): Unit = { println(f"${ System.identityHashCode(x) }%08x ${ - * x.getClass.getSimpleName } ${ x.pType }") Children(x).foreach { case c: IR => visit(c) } - * } - * - * visit(ir) } */ - - assert( - fb.mb.parameterTypeInfo == expectedCodeParamTypes, - s"expected $expectedCodeParamTypes, got ${fb.mb.parameterTypeInfo}", - ) - assert( - fb.mb.returnTypeInfo == expectedCodeReturnType, - s"expected $expectedCodeReturnType, got ${fb.mb.returnTypeInfo}", - ) - - val emitContext = EmitContext.analyze(ctx, ir) - val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length) - CompiledFunction(rt, fb.resultWithIndex(print)) - }).tuple - } -} + Impl[F, AnyVal]( + ctx, + params, + None, + expectedCodeParamTypes, + expectedCodeReturnType, + body, + optimize, + print, + ) -object CompileWithAggregators { - def apply[F: TypeInfo]( + def CompileWithAggregators[F: TypeInfo]( ctx: ExecuteContext, aggSigs: Array[AggStateSig], params: IndexedSeq[(Name, EmitParamType)], @@ -104,60 +64,74 @@ object CompileWithAggregators { expectedCodeReturnType: TypeInfo[_], body: IR, optimize: Boolean = true, + print: Option[PrintWriter] = None, ): ( Option[SingleCodeType], - (HailClassLoader, FS, HailTaskContext, Region) => (F with FunctionWithAggRegion), + (HailClassLoader, FS, HailTaskContext, Region) => F with FunctionWithAggRegion, ) = + Impl[F, FunctionWithAggRegion]( + ctx, + params, + Some(aggSigs), + expectedCodeParamTypes, + expectedCodeReturnType, + body, + optimize, + print, + ) + + private[this] def Impl[F: TypeInfo, Mixin]( + ctx: ExecuteContext, + params: IndexedSeq[(Name, EmitParamType)], + aggSigs: Option[Array[AggStateSig]], + expectedCodeParamTypes: IndexedSeq[TypeInfo[_]], + expectedCodeReturnType: TypeInfo[_], + body: IR, + optimize: Boolean, + print: Option[PrintWriter], + )(implicit + E: Enclosing, + N: sourcecode.Name, + ): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F with Mixin) = ctx.time { - val normalizedBody = - NormalizeNames(ctx, body, allowFreeVariables = true) - val k = CodeCacheKey(aggSigs, params.map { case (n, pt) => (n, pt) }, normalizedBody) - (ctx.backend.lookupOrCompileCachedFunction[F with FunctionWithAggRegion](k) { - - var ir = body - ir = Subst( - ir, - BindingEnv(params - .zipWithIndex - .foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }), - ) - ir = - LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing(ctx) - - TypeCheck( - ctx, - ir, - BindingEnv(Env.fromSeq[Type](params.map { case (name, t) => name -> t.virtualType })), - ) - - val fb = EmitFunctionBuilder[F]( - ctx, - "CompiledWithAggs", - CodeParamType(typeInfo[Region]) +: params.map { case (_, pt) => pt }, - SingleCodeType.typeInfoFromType(ir.typ), - Some("Emit.scala"), - ) - - /* { def visit(x: IR): Unit = { println(f"${ System.identityHashCode(x) }%08x ${ - * x.getClass.getSimpleName } ${ x.pType }") Children(x).foreach { case c: IR => visit(c) } - * } - * - * visit(ir) } */ - - val emitContext = EmitContext.analyze(ctx, ir) - val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length, Some(aggSigs)) - - val f = fb.resultWithIndex() - CompiledFunction( - rt, - f.asInstanceOf[( - HailClassLoader, - FS, - HailTaskContext, - Region, - ) => (F with FunctionWithAggRegion)], - ) - }).tuple + val normalizedBody = NormalizeNames(ctx, body, allowFreeVariables = true) + ctx.CodeCache.getOrElseUpdate( + CodeCacheKey(aggSigs.getOrElse(Array.empty).toFastSeq, params, normalizedBody), { + var ir = Subst( + body, + BindingEnv(Env.fromSeq(params.zipWithIndex.map { case ((n, t), i) => n -> In(i, t) })), + ) + ir = LoweringPipeline.compileLowerer(optimize)(ctx, ir).asInstanceOf[IR].noSharing(ctx) + TypeCheck(ctx, ir) + + val fb = EmitFunctionBuilder[F]( + ctx, + N.value, + CodeParamType(typeInfo[Region]) +: params.map(_._2), + CodeParamType(SingleCodeType.typeInfoFromType(ir.typ)), + Some("Emit.scala"), + ) + + /* { def visit(x: IR): Unit = { println(f"${ System.identityHashCode(x) }%08x ${ + * x.getClass.getSimpleName } ${ x.pType }") Children(x).foreach { case c: IR => visit(c) + * } } + * + * visit(ir) } */ + + assert( + fb.mb.parameterTypeInfo == expectedCodeParamTypes, + s"expected $expectedCodeParamTypes, got ${fb.mb.parameterTypeInfo}", + ) + assert( + fb.mb.returnTypeInfo == expectedCodeReturnType, + s"expected $expectedCodeReturnType, got ${fb.mb.returnTypeInfo}", + ) + + val emitContext = EmitContext.analyze(ctx, ir) + val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length, aggSigs) + CompiledFunction(rt, fb.resultWithIndex(print)) + }, + ).asInstanceOf[CompiledFunction[F with Mixin]].tuple } } diff --git a/hail/hail/src/is/hail/expr/ir/CompileAndEvaluate.scala b/hail/hail/src/is/hail/expr/ir/CompileAndEvaluate.scala index b099347e827..b4fe88bdb22 100644 --- a/hail/hail/src/is/hail/expr/ir/CompileAndEvaluate.scala +++ b/hail/hail/src/is/hail/expr/ir/CompileAndEvaluate.scala @@ -3,6 +3,7 @@ package is.hail.expr.ir import is.hail.annotations.{Region, SafeRow} import is.hail.asm4s._ import is.hail.backend.ExecuteContext +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.defs.{Begin, EncodedLiteral, Literal, MakeTuple, NA} import is.hail.expr.ir.lowering.LoweringPipeline import is.hail.types.physical.PTuple diff --git a/hail/hail/src/is/hail/expr/ir/Emit.scala b/hail/hail/src/is/hail/expr/ir/Emit.scala index be2e79a091b..07cabb71695 100644 --- a/hail/hail/src/is/hail/expr/ir/Emit.scala +++ b/hail/hail/src/is/hail/expr/ir/Emit.scala @@ -7,6 +7,7 @@ import is.hail.expr.ir.agg.{AggStateSig, ArrayAggStateSig, GroupedStateSig} import is.hail.expr.ir.analyses.{ ComputeMethodSplits, ControlFlowPreventsSplit, ParentPointers, SemanticHash, } +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.defs._ import is.hail.expr.ir.lowering.TableStageDependency import is.hail.expr.ir.ndarrays.EmitNDArray diff --git a/hail/hail/src/is/hail/expr/ir/Interpret.scala b/hail/hail/src/is/hail/expr/ir/Interpret.scala index 1c0328804b2..067dd8c1c4b 100644 --- a/hail/hail/src/is/hail/expr/ir/Interpret.scala +++ b/hail/hail/src/is/hail/expr/ir/Interpret.scala @@ -4,6 +4,7 @@ import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend.{ExecuteContext, HailTaskContext} import is.hail.backend.spark.SparkTaskContext +import is.hail.expr.ir.compile.{Compile, CompileWithAggregators} import is.hail.expr.ir.defs._ import is.hail.expr.ir.lowering.LoweringPipeline import is.hail.io.BufferSpec diff --git a/hail/hail/src/is/hail/expr/ir/TableIR.scala b/hail/hail/src/is/hail/expr/ir/TableIR.scala index 27404c71c01..5d4719fbeec 100644 --- a/hail/hail/src/is/hail/expr/ir/TableIR.scala +++ b/hail/hail/src/is/hail/expr/ir/TableIR.scala @@ -5,7 +5,7 @@ import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend.{ExecuteContext, HailStateManager, HailTaskContext, TaskFinalizer} import is.hail.backend.spark.{SparkBackend, SparkTaskContext} -import is.hail.expr.ir +import is.hail.expr.ir.compile.{Compile, CompileWithAggregators} import is.hail.expr.ir.defs._ import is.hail.expr.ir.functions.{ BlockMatrixToTableFunction, IntervalFunctions, MatrixToTableFunction, TableToTableFunction, @@ -1932,7 +1932,7 @@ case class TableNativeZippedReader( val leftRef = Ref(freshName(), pLeft.virtualType) val rightRef = Ref(freshName(), pRight.virtualType) val (Some(PTypeReferenceSingleCodeType(t: PStruct)), mk) = - ir.Compile[AsmFunction3RegionLongLongLong]( + Compile[AsmFunction3RegionLongLongLong]( ctx, FastSeq( leftRef.name -> SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(pLeft)), @@ -2421,7 +2421,7 @@ case class TableFilter(child: TableIR, pred: IR) extends TableIR { else if (pred == False()) return TableValueIntermediate(tv.copy(rvd = RVD.empty(ctx, typ.canonicalRVDType))) - val (Some(BooleanSingleCodeType), f) = ir.Compile[AsmFunction3RegionLongLongBoolean]( + val (Some(BooleanSingleCodeType), f) = Compile[AsmFunction3RegionLongLongBoolean]( ctx, FastSeq( ( @@ -3036,7 +3036,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { if (extracted.aggs.isEmpty) { val (Some(PTypeReferenceSingleCodeType(rTyp)), f) = - ir.Compile[AsmFunction3RegionLongLongLong]( + Compile[AsmFunction3RegionLongLongLong]( ctx, FastSeq( ( @@ -3102,7 +3102,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { // 3. load in partition aggregations, comb op as necessary, serialize. // 4. load in partStarts, calculate newRow based on those results. - val (_, initF) = ir.CompileWithAggregators[AsmFunction2RegionLongUnit]( + val (_, initF) = CompileWithAggregators[AsmFunction2RegionLongUnit]( ctx, extracted.states, FastSeq(( @@ -3116,7 +3116,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { val serializeF = extracted.serialize(ctx, spec) - val (_, eltSeqF) = ir.CompileWithAggregators[AsmFunction3RegionLongLongUnit]( + val (_, eltSeqF) = CompileWithAggregators[AsmFunction3RegionLongLongUnit]( ctx, extracted.states, FastSeq( @@ -3139,7 +3139,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { val combOpFNeedsPool = extracted.combOpFSerializedFromRegionPool(ctx, spec) val (Some(PTypeReferenceSingleCodeType(rTyp)), f) = - ir.CompileWithAggregators[AsmFunction3RegionLongLongLong]( + CompileWithAggregators[AsmFunction3RegionLongLongLong]( ctx, extracted.states, FastSeq( @@ -3698,7 +3698,7 @@ case class TableKeyByAndAggregate( val localKeyType = keyType val (Some(PTypeReferenceSingleCodeType(localKeyPType: PStruct)), makeKeyF) = - ir.Compile[AsmFunction3RegionLongLongLong]( + Compile[AsmFunction3RegionLongLongLong]( ctx, FastSeq( ( @@ -3724,7 +3724,7 @@ case class TableKeyByAndAggregate( val extracted = agg.Extract(expr, Requiredness(this, ctx)) - val (_, makeInit) = ir.CompileWithAggregators[AsmFunction2RegionLongUnit]( + val (_, makeInit) = CompileWithAggregators[AsmFunction2RegionLongUnit]( ctx, extracted.states, FastSeq(( @@ -3736,7 +3736,7 @@ case class TableKeyByAndAggregate( extracted.init, ) - val (_, makeSeq) = ir.CompileWithAggregators[AsmFunction3RegionLongLongUnit]( + val (_, makeSeq) = CompileWithAggregators[AsmFunction3RegionLongLongUnit]( ctx, extracted.states, FastSeq( @@ -3755,7 +3755,7 @@ case class TableKeyByAndAggregate( ) val (Some(PTypeReferenceSingleCodeType(rTyp: PStruct)), makeAnnotate) = - ir.CompileWithAggregators[AsmFunction2RegionLongLong]( + CompileWithAggregators[AsmFunction2RegionLongLong]( ctx, extracted.states, FastSeq(( @@ -3898,7 +3898,7 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR { val extracted = agg.Extract(expr, Requiredness(this, ctx)) - val (_, makeInit) = ir.CompileWithAggregators[AsmFunction2RegionLongUnit]( + val (_, makeInit) = CompileWithAggregators[AsmFunction2RegionLongUnit]( ctx, extracted.states, FastSeq(( @@ -3910,7 +3910,7 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR { extracted.init, ) - val (_, makeSeq) = ir.CompileWithAggregators[AsmFunction3RegionLongLongUnit]( + val (_, makeSeq) = CompileWithAggregators[AsmFunction3RegionLongLongUnit]( ctx, extracted.states, FastSeq( @@ -3934,7 +3934,7 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR { val key = Ref(freshName(), keyType.virtualType) val value = Ref(freshName(), valueIR.typ) val (Some(PTypeReferenceSingleCodeType(rowType: PStruct)), makeRow) = - ir.CompileWithAggregators[AsmFunction3RegionLongLongLong]( + CompileWithAggregators[AsmFunction3RegionLongLongLong]( ctx, extracted.states, FastSeq( diff --git a/hail/hail/src/is/hail/expr/ir/agg/Extract.scala b/hail/hail/src/is/hail/expr/ir/agg/Extract.scala index a9807b283dc..e23c5f9012b 100644 --- a/hail/hail/src/is/hail/expr/ir/agg/Extract.scala +++ b/hail/hail/src/is/hail/expr/ir/agg/Extract.scala @@ -6,6 +6,7 @@ import is.hail.backend.{ExecuteContext, HailTaskContext} import is.hail.backend.spark.SparkTaskContext import is.hail.expr.ir import is.hail.expr.ir._ +import is.hail.expr.ir.compile.CompileWithAggregators import is.hail.expr.ir.defs._ import is.hail.io.BufferSpec import is.hail.types.{TypeWithRequiredness, VirtualTypeWithReq} @@ -248,7 +249,7 @@ class Aggs( def deserialize(ctx: ExecuteContext, spec: BufferSpec) : ((HailClassLoader, HailTaskContext, Region, Array[Byte]) => Long) = { - val (_, f) = ir.CompileWithAggregators[AsmFunction1RegionUnit]( + val (_, f) = CompileWithAggregators[AsmFunction1RegionUnit]( ctx, states, FastSeq(), @@ -269,7 +270,7 @@ class Aggs( def serialize(ctx: ExecuteContext, spec: BufferSpec) : (HailClassLoader, HailTaskContext, Region, Long) => Array[Byte] = { - val (_, f) = ir.CompileWithAggregators[AsmFunction1RegionUnit]( + val (_, f) = CompileWithAggregators[AsmFunction1RegionUnit]( ctx, states, FastSeq(), @@ -306,7 +307,7 @@ class Aggs( : (() => (RegionPool, HailClassLoader, HailTaskContext)) => ( (Array[Byte], Array[Byte]) => Array[Byte], ) = { - val (_, f) = ir.CompileWithAggregators[AsmFunction1RegionUnit]( + val (_, f) = CompileWithAggregators[AsmFunction1RegionUnit]( ctx, states ++ states, FastSeq(), diff --git a/hail/hail/src/is/hail/expr/ir/lowering/LowerDistributedSort.scala b/hail/hail/src/is/hail/expr/ir/lowering/LowerDistributedSort.scala index 68e6c28651f..4d6d8b4dbbf 100644 --- a/hail/hail/src/is/hail/expr/ir/lowering/LowerDistributedSort.scala +++ b/hail/hail/src/is/hail/expr/ir/lowering/LowerDistributedSort.scala @@ -4,6 +4,7 @@ import is.hail.annotations.{Annotation, ExtendedOrdering, Region, SafeRow} import is.hail.asm4s.{classInfo, AsmFunction1RegionLong, LongInfo} import is.hail.backend.{ExecuteContext, HailStateManager} import is.hail.expr.ir._ +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.defs._ import is.hail.expr.ir.functions.{ArrayFunctions, IRRandomness, UtilFunctions} import is.hail.io.{BufferSpec, TypedCodecSpec} diff --git a/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala b/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala index d2cf0499a8a..85d8673cc59 100644 --- a/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala +++ b/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala @@ -5,6 +5,7 @@ import is.hail.asm4s._ import is.hail.backend.{BroadcastValue, ExecuteContext} import is.hail.backend.spark.{AnonymousDependency, SparkTaskContext} import is.hail.expr.ir._ +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.defs.{ GetField, In, Let, MakeStruct, ReadPartition, Ref, StreamRange, ToArray, } diff --git a/hail/hail/src/is/hail/utils/Cache.scala b/hail/hail/src/is/hail/utils/Cache.scala index 3aa40a6c547..3ec59a80498 100644 --- a/hail/hail/src/is/hail/utils/Cache.scala +++ b/hail/hail/src/is/hail/utils/Cache.scala @@ -2,20 +2,32 @@ package is.hail.utils import is.hail.annotations.{Region, RegionMemory} +import scala.collection.mutable +import scala.jdk.CollectionConverters.asScalaIteratorConverter + import java.io.Closeable import java.util import java.util.Map.Entry -class Cache[K, V](capacity: Int) { +class Cache[K, V](capacity: Int) extends mutable.AbstractMap[K, V] { private[this] val m = new util.LinkedHashMap[K, V](capacity, 0.75f, true) { override def removeEldestEntry(eldest: Entry[K, V]): Boolean = size() > capacity } - def get(k: K): Option[V] = synchronized(Option(m.get(k))) + override def +=(kv: (K, V)): Cache.this.type = + synchronized { m.put(kv._1, kv._2); this } + + override def -=(key: K): Cache.this.type = + synchronized { m.remove(key); this } - def +=(p: (K, V)): Unit = synchronized(m.put(p._1, p._2)) + override def get(key: K): Option[V] = + synchronized(Option(m.get(key))) - def size: Int = synchronized(m.size()) + override def iterator: Iterator[(K, V)] = + for { e <- m.entrySet().iterator().asScala } yield (e.getKey, e.getValue) + + override def clear(): Unit = + m.clear() } class LongToRegionValueCache(capacity: Int) extends Closeable { diff --git a/hail/hail/test/src/is/hail/TestUtils.scala b/hail/hail/test/src/is/hail/TestUtils.scala index 4e29c013d67..da061cc8dc6 100644 --- a/hail/hail/test/src/is/hail/TestUtils.scala +++ b/hail/hail/test/src/is/hail/TestUtils.scala @@ -4,9 +4,10 @@ import is.hail.annotations.{Region, RegionValueBuilder, SafeRow} import is.hail.asm4s._ import is.hail.backend.ExecuteContext import is.hail.expr.ir.{ - freshName, streamAggIR, BindingEnv, Compile, Env, IR, Interpret, MapIR, MatrixIR, MatrixRead, - Name, SingleCodeEmitParamType, Subst, + freshName, streamAggIR, BindingEnv, Env, IR, Interpret, MapIR, MatrixIR, MatrixRead, Name, + SingleCodeEmitParamType, Subst, } +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.defs.{GetField, GetTupleElement, In, MakeTuple, Ref, ToStream} import is.hail.expr.ir.lowering.LowererUnsupportedOperation import is.hail.io.vcf.MatrixVCFReader diff --git a/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala b/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala index abef7b459e9..0d50cf8a988 100644 --- a/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala @@ -4,6 +4,7 @@ import is.hail.{ExecStrategy, HailSuite} import is.hail.annotations._ import is.hail.asm4s._ import is.hail.expr.ir.agg._ +import is.hail.expr.ir.compile.CompileWithAggregators import is.hail.expr.ir.defs._ import is.hail.io.BufferSpec import is.hail.types.VirtualTypeWithReq diff --git a/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala b/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala index 7cd6fdf0053..7ca4f22567f 100644 --- a/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala @@ -6,6 +6,7 @@ import is.hail.annotations.{Region, SafeRow, ScalaToRegionValue} import is.hail.asm4s._ import is.hail.backend.ExecuteContext import is.hail.expr.ir.agg.{CollectStateSig, PhysicalAggSig, TypedStateSig} +import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.defs._ import is.hail.expr.ir.lowering.LoweringPipeline import is.hail.expr.ir.streams.{EmitStream, StreamUtils} diff --git a/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala b/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala index a2d632f5eeb..995b435c875 100644 --- a/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala @@ -5,7 +5,6 @@ import is.hail.ExecStrategy.ExecStrategy import is.hail.TestUtils._ import is.hail.annotations.{BroadcastRow, ExtendedOrdering, SafeNDArray} import is.hail.backend.ExecuteContext -import is.hail.backend.caching.BlockMatrixCache import is.hail.expr.Nat import is.hail.expr.ir.agg._ import is.hail.expr.ir.defs._ @@ -13,6 +12,7 @@ import is.hail.expr.ir.defs.ArrayZipBehavior.ArrayZipBehavior import is.hail.expr.ir.functions._ import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.io.bgen.MatrixBGENReader +import is.hail.linalg.BlockMatrix import is.hail.methods._ import is.hail.rvd.{PartitionBoundOrdering, RVD, RVDPartitioner} import is.hail.types.{tcoerce, VirtualTypeWithReq} @@ -24,6 +24,7 @@ import is.hail.types.virtual.TIterable.elementType import is.hail.utils.{FastSeq, _} import is.hail.variant.{Call2, Locus} +import scala.collection.mutable import scala.language.implicitConversions import org.apache.spark.sql.Row @@ -3907,8 +3908,9 @@ class IRSuite extends HailSuite { assert(x2 == x) } - @Test def testBlockMatrixIRParserPersist(): Unit = - using(new BlockMatrixCache()) { cache => + @Test def testBlockMatrixIRParserPersist(): Unit = { + val cache = mutable.Map.empty[String, BlockMatrix] + try { val bm = BlockMatrixRandom(0, gaussian = true, shape = Array(5L, 6L), blockSize = 3) backend.withExecuteContext { ctx => @@ -3926,7 +3928,9 @@ class IRSuite extends HailSuite { assert(x2 == persist) } } - } + } finally + cache.values.foreach(_.unpersist()) + } @Test def testCachedIR(): Unit = { val cached = Literal(TSet(TInt32), Set(1))