Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[query] Extract Backend Methods called from Python into Py4JBackendExtensions #14767

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions hail/python/hail/backend/py4j_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,17 +237,17 @@ def _rpc(self, action, payload) -> Tuple[bytes, Optional[dict]]:

def persist_expression(self, expr):
t = expr.dtype
return construct_expr(JavaIR(t, self._jbackend.executeLiteral(self._render_ir(expr._ir))), t)
return construct_expr(JavaIR(t, self._jbackend.pyExecuteLiteral(self._render_ir(expr._ir))), t)

def _is_registered_ir_function_name(self, name: str) -> bool:
return name in self._registered_ir_function_names

def set_flags(self, **flags: Mapping[str, str]):
available = self._jbackend.availableFlags()
available = self._jbackend.pyAvailableFlags()
invalid = []
for flag, value in flags.items():
if flag in available:
self._jbackend.setFlag(flag, value)
self._jbackend.pySetFlag(flag, value)
else:
invalid.append(flag)
if len(invalid) != 0:
Expand All @@ -256,7 +256,7 @@ def set_flags(self, **flags: Mapping[str, str]):
)

def get_flags(self, *flags) -> Mapping[str, str]:
return {flag: self._jbackend.getFlag(flag) for flag in flags}
return {flag: self._jbackend.pyGetFlag(flag) for flag in flags}

def _add_reference_to_scala_backend(self, rg):
self._jbackend.pyAddReference(orjson.dumps(rg._config).decode('utf-8'))
Expand Down
2 changes: 1 addition & 1 deletion hail/python/hail/ir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -3880,7 +3880,7 @@ def __del__(self):
if Env._hc:
backend = Env.backend()
assert isinstance(backend, Py4JBackend)
backend._jbackend.removeJavaIR(self._id)
backend._jbackend.pyRemoveJavaIR(self._id)


class JavaIR(IR):
Expand Down
2 changes: 1 addition & 1 deletion hail/python/hail/ir/table_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,4 +1215,4 @@ def __del__(self):
if Env._hc:
backend = Env.backend()
assert isinstance(backend, Py4JBackend)
backend._jbackend.removeJavaIR(self._id)
backend._jbackend.pyRemoveJavaIR(self._id)
2 changes: 1 addition & 1 deletion hail/python/test/hail/genetics/test_reference_genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def assert_rg_loaded_correctly(name):
# loading different reference genome with same name should fail
# (different `test_rg_o` definition)
with pytest.raises(FatalError):
hl.read_matrix_table(resource('custom_references_2.t')).count()
hl.read_table(resource('custom_references_2.t')).count()

assert hl.read_matrix_table(resource('custom_references.mt')).count_rows() == 14
assert_rg_loaded_correctly('test_rg_1')
Expand Down
115 changes: 19 additions & 96 deletions hail/src/main/scala/is/hail/backend/Backend.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package is.hail.backend

import is.hail.asm4s._
import is.hail.backend.Backend.jsonToBytes
import is.hail.backend.spark.SparkBackend
import is.hail.expr.ir.{
BaseIR, CodeCacheKey, CompiledFunction, IR, IRParser, IRParserEnvironment, LoweringAnalyses,
SortField, TableIR, TableReader,
}
import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
import is.hail.io.{BufferSpec, TypedCodecSpec}
import is.hail.io.fs._
Expand All @@ -20,16 +20,14 @@ import is.hail.types.virtual.{BlockMatrixType, TFloat64}
import is.hail.utils._
import is.hail.variant.ReferenceGenome

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.reflect.ClassTag

import java.io._
import java.nio.charset.StandardCharsets

import com.fasterxml.jackson.core.StreamReadConstraints
import org.json4s._
import org.json4s.jackson.{JsonMethods, Serialization}
import org.json4s.jackson.JsonMethods
import sourcecode.Enclosing

object Backend {
Expand All @@ -41,13 +39,6 @@ object Backend {
s"hail_query_$id"
}

private var irID: Int = 0

def nextIRID(): Int = {
irID += 1
irID
}

def encodeToOutputStream(
ctx: ExecuteContext,
t: PTuple,
Expand All @@ -66,6 +57,9 @@ object Backend {
assert(t.isFieldDefined(off, 0))
codec.encode(ctx, elementType, t.loadField(off, 0), os)
}

def jsonToBytes(f: => JValue): Array[Byte] =
JsonMethods.compact(f).getBytes(StandardCharsets.UTF_8)
}

abstract class BroadcastValue[T] { def value: T }
Expand All @@ -75,28 +69,8 @@ trait BackendContext {
}

abstract class Backend extends Closeable {
// From https://github.com/hail-is/hail/issues/14580 :
// IR can get quite big, especially as it can contain an arbitrary
// amount of encoded literals from the user's python session. This
// was a (controversial) restriction imposed by Jackson and should be lifted.
//
// We remove this restriction for all backends, and we do so here, in the
// constructor since constructing a backend is one of the first things that
// happens and this constraint should be overrided as early as possible.
StreamReadConstraints.overrideDefaultStreamReadConstraints(
StreamReadConstraints.builder().maxStringLength(Integer.MAX_VALUE).build()
)

val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map()

protected[this] def addJavaIR(ir: BaseIR): Int = {
val id = Backend.nextIRID()
persistedIR += (id -> ir)
id
}

def removeJavaIR(id: Int): Unit = persistedIR.remove(id)

def defaultParallelism: Int

def canExecuteParallelTasksOnDriver: Boolean = true
Expand Down Expand Up @@ -131,30 +105,7 @@ abstract class Backend extends Closeable {
def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
: CompiledFunction[T]

var references: Map[String, ReferenceGenome] = Map.empty

def addDefaultReferences(): Unit =
references = ReferenceGenome.builtinReferences()

def addReference(rg: ReferenceGenome): Unit = {
references.get(rg.name) match {
case Some(rg2) =>
if (rg != rg2) {
fatal(
s"Cannot add reference genome '${rg.name}', a different reference with that name already exists. Choose a reference name NOT in the following list:\n " +
s"@1",
references.keys.truncatable("\n "),
)
}
case None =>
references += (rg.name -> rg)
}
}

def hasReference(name: String) = references.contains(name)

def removeReference(name: String): Unit =
references -= name
def references: mutable.Map[String, ReferenceGenome]

def lowerDistributedSort(
ctx: ExecuteContext,
Expand Down Expand Up @@ -189,9 +140,6 @@ abstract class Backend extends Closeable {

def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T

private[this] def jsonToBytes(f: => JValue): Array[Byte] =
JsonMethods.compact(f).getBytes(StandardCharsets.UTF_8)

final def valueType(s: String): Array[Byte] =
jsonToBytes {
withExecuteContext { ctx =>
Expand Down Expand Up @@ -220,15 +168,7 @@ abstract class Backend extends Closeable {
}
}

def loadReferencesFromDataset(path: String): Array[Byte] = {
withExecuteContext { ctx =>
val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path)
rgs.foreach(addReference)

implicit val formats: Formats = defaultJSONFormats
Serialization.write(rgs.map(_.toJSON).toFastSeq).getBytes(StandardCharsets.UTF_8)
}
}
def loadReferencesFromDataset(path: String): Array[Byte]
patrick-schultz marked this conversation as resolved.
Show resolved Hide resolved

def fromFASTAFile(
name: String,
Expand All @@ -240,18 +180,22 @@ abstract class Backend extends Closeable {
parInput: Array[String],
): Array[Byte] =
withExecuteContext { ctx =>
val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
xContigs, yContigs, mtContigs, parInput)
rg.toJSONString.getBytes(StandardCharsets.UTF_8)
jsonToBytes {
Extraction.decompose {
ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
xContigs, yContigs, mtContigs, parInput).toJSON
}(defaultJSONFormats)
}
}

def parseVCFMetadata(path: String): Array[Byte] = jsonToBytes {
def parseVCFMetadata(path: String): Array[Byte] =
withExecuteContext { ctx =>
val metadata = LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path)
implicit val formats = defaultJSONFormats
Extraction.decompose(metadata)
jsonToBytes {
Extraction.decompose {
LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path)
}(defaultJSONFormats)
}
}
}

def importFam(path: String, isQuantPheno: Boolean, delimiter: String, missingValue: String)
: Array[Byte] =
Expand All @@ -261,27 +205,6 @@ abstract class Backend extends Closeable {
)
}

def pyRegisterIR(
name: String,
typeParamStrs: java.util.ArrayList[String],
argNameStrs: java.util.ArrayList[String],
argTypeStrs: java.util.ArrayList[String],
returnType: String,
bodyStr: String,
): Unit = {
withExecuteContext { ctx =>
IRFunctionRegistry.registerIR(
ctx,
name,
typeParamStrs.asScala.toArray,
argNameStrs.asScala.toArray,
argTypeStrs.asScala.toArray,
returnType,
bodyStr,
)
}
}

def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)]
}

Expand Down
1 change: 1 addition & 0 deletions hail/src/main/scala/is/hail/backend/BackendServer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ class BackendHttpHandler(backend: Backend) extends HttpHandler {
}
return
}

val response: Array[Byte] = exchange.getRequestURI.getPath match {
case "/value/type" => backend.valueType(body.extract[IRTypePayload].ir)
case "/table/type" => backend.tableType(body.extract[IRTypePayload].ir)
Expand Down
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/backend/ExecuteContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class ExecuteContext(
)
}

val stateManager = HailStateManager(backend.references)
def stateManager = HailStateManager(backend.references.toMap)

val tempFileManager: TempFileManager =
if (_tempFileManager != null) _tempFileManager else new OwningTempFileManager(fs)
Expand Down
Loading