Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[query] Use sourcecode.Enclosing to handle timed blocks implicitly #14683

Merged
merged 10 commits into from
Oct 17, 2024
2 changes: 2 additions & 0 deletions hail/build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ object Deps {
val log4j = ivy"org.apache.logging.log4j:log4j-1.2-api:2.17.2"
val hadoopClient = ivy"org.apache.hadoop:hadoop-client:3.3.4"
val jackson = ivy"com.fasterxml.jackson.core:jackson-core:2.15.2"
val sourcecode = ivy"com.lihaoyi::sourcecode:0.4.2"

object Plugins {
val betterModadicFor = ivy"com.olegpy::better-monadic-for:0.3.1"
Expand Down Expand Up @@ -200,6 +201,7 @@ object main extends RootModule with HailScalaModule { outer =>
Deps.jna,
Deps.json4s.excludeOrg("com.fasterxml.jackson.core"),
Deps.zstd,
Deps.sourcecode
)

override def runIvyDeps: T[Agg[Dep]] = Agg(
Expand Down
2 changes: 1 addition & 1 deletion hail/python/hail/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def execute(self, ir: BaseIR, timed: bool = False) -> Any:
return (value, timings) if timed else value

@abc.abstractmethod
def _rpc(self, action: ActionTag, payload: ActionPayload) -> Tuple[bytes, str]:
def _rpc(self, action: ActionTag, payload: ActionPayload) -> Tuple[bytes, Optional[dict]]:
pass

def _render_ir(self, ir):
Expand Down
18 changes: 15 additions & 3 deletions hail/python/hail/backend/py4j_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import socketserver
import sys
from threading import Thread
from typing import Mapping, Set, Tuple
from typing import Mapping, Optional, Set, Tuple

import orjson
import py4j
Expand Down Expand Up @@ -156,6 +156,18 @@ def connect_logger(utils_package_object, host, port):
}


def parse_timings(str: Optional[str]) -> Optional[dict]:
def parse(node):
return {
'name': node[0],
'total_time': node[1],
'self_time': node[2],
'children': [parse(c) for c in node[3]],
}

return None if str is None else parse(orjson.loads(str))


class Py4JBackend(Backend):
@abc.abstractmethod
def __init__(self, jvm: JVMView, jbackend: JavaObject, jhc: JavaObject):
Expand Down Expand Up @@ -211,7 +223,7 @@ def logger(self):
self._logger = Log4jLogger(self._utils_package_object)
return self._logger

def _rpc(self, action, payload) -> Tuple[bytes, str]:
def _rpc(self, action, payload) -> Tuple[bytes, Optional[dict]]:
data = orjson.dumps(payload)
path = action_routes[action]
port = self._backend_server_port
Expand All @@ -221,7 +233,7 @@ def _rpc(self, action, payload) -> Tuple[bytes, str]:
raise fatal_error_from_java_error_triplet(
error_json['short'], error_json['expanded'], error_json['error_id']
)
return resp.content, resp.headers.get('X-Hail-Timings', '')
return resp.content, parse_timings(resp.headers.get('X-Hail-Timings', None))

def persist_expression(self, expr):
t = expr.dtype
Expand Down
6 changes: 3 additions & 3 deletions hail/python/hail/backend/service_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ async def _run_on_batch(
progress: Optional[BatchProgressBar] = None,
driver_cores: Optional[Union[int, str]] = None,
driver_memory: Optional[str] = None,
) -> Tuple[bytes, str]:
) -> Tuple[bytes, Optional[dict]]:
timings = Timings()
async with TemporaryDirectory(ensure_exists=False) as iodir:
with timings.step("write input"):
Expand Down Expand Up @@ -414,7 +414,7 @@ async def _run_on_batch(

with timings.step("read output"):
result_bytes = await retry_transient_errors(self._read_output, iodir + '/out', iodir + '/in')
return result_bytes, str(timings.to_dict())
return result_bytes, timings.to_dict()

async def _read_output(self, output_uri: str, input_uri: str) -> bytes:
try:
Expand Down Expand Up @@ -462,7 +462,7 @@ def _cancel_on_ctrl_c(self, coro: Awaitable[T]) -> T:
self._batch_was_submitted = False
raise

def _rpc(self, action: ActionTag, payload: ActionPayload) -> Tuple[bytes, str]:
def _rpc(self, action: ActionTag, payload: ActionPayload) -> Tuple[bytes, Optional[dict]]:
return self._cancel_on_ctrl_c(self._async_rpc(action, payload))

async def _async_rpc(self, action: ActionTag, payload: ActionPayload):
Expand Down
3 changes: 2 additions & 1 deletion hail/python/hail/expr/expressions/expression_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ def eval_timed(expression):
uid = Env.get_uid()
ir = expression._indices.source.select_globals(**{uid: expression}).index_globals()[uid]._ir

return Env.backend().execute(MakeTuple([ir]), timed=True)[0]
(value, timings) = Env.backend().execute(MakeTuple([ir]), timed=True)
return value[0], timings


@typecheck(expression=expr_any)
Expand Down
8 changes: 5 additions & 3 deletions hail/python/hailtop/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,9 +1122,11 @@ def step(self, name: str):
d: Dict[str, int] = {}
self.timings[name] = d
d['start_time'] = time_msecs()
yield
d['finish_time'] = time_msecs()
d['duration'] = d['finish_time'] - d['start_time']
try:
yield
finally:
d['finish_time'] = time_msecs()
d['duration'] = d['finish_time'] - d['start_time']

def to_dict(self):
return self.timings
Expand Down
68 changes: 32 additions & 36 deletions hail/src/main/scala/is/hail/backend/Backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package is.hail.backend
import is.hail.asm4s._
import is.hail.backend.spark.SparkBackend
import is.hail.expr.ir.{
BaseIR, CodeCacheKey, CompiledFunction, IRParser, IRParserEnvironment, LoweringAnalyses,
BaseIR, CodeCacheKey, CompiledFunction, IR, IRParser, IRParserEnvironment, LoweringAnalyses,
SortField, TableIR, TableReader,
}
import is.hail.expr.ir.functions.IRFunctionRegistry
Expand All @@ -30,6 +30,7 @@ import java.nio.charset.StandardCharsets
import com.fasterxml.jackson.core.StreamReadConstraints
import org.json4s._
import org.json4s.jackson.{JsonMethods, Serialization}
import sourcecode.Enclosing

object Backend {

Expand All @@ -46,6 +47,25 @@ object Backend {
irID += 1
irID
}

def encodeToOutputStream(
ctx: ExecuteContext,
t: PTuple,
off: Long,
bufferSpecString: String,
os: OutputStream,
): Unit = {
val bs = BufferSpec.parseOrDefault(bufferSpecString)
assert(t.size == 1)
val elementType = t.fields(0).typ
val codec = TypedCodecSpec(
EType.fromPythonTypeEncoding(elementType.virtualType),
elementType.virtualType,
bs,
)
assert(t.isFieldDefined(off, 0))
codec.encode(ctx, elementType, t.loadField(off, 0), os)
}
}

abstract class BroadcastValue[T] { def value: T }
Expand Down Expand Up @@ -169,41 +189,41 @@ abstract class Backend {
def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses)
: TableStage

def withExecuteContext[T](methodName: String)(f: ExecuteContext => T): T
def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T

private[this] def jsonToBytes(f: => JValue): Array[Byte] =
JsonMethods.compact(f).getBytes(StandardCharsets.UTF_8)

final def valueType(s: String): Array[Byte] =
jsonToBytes {
withExecuteContext("valueType") { ctx =>
withExecuteContext { ctx =>
IRParser.parse_value_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
}
}

final def tableType(s: String): Array[Byte] =
jsonToBytes {
withExecuteContext("tableType") { ctx =>
withExecuteContext { ctx =>
IRParser.parse_table_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
}
}

final def matrixTableType(s: String): Array[Byte] =
jsonToBytes {
withExecuteContext("matrixTableType") { ctx =>
withExecuteContext { ctx =>
IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
}
}

final def blockMatrixType(s: String): Array[Byte] =
jsonToBytes {
withExecuteContext("blockMatrixType") { ctx =>
withExecuteContext { ctx =>
IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
}
}

def loadReferencesFromDataset(path: String): Array[Byte] = {
withExecuteContext("loadReferencesFromDataset") { ctx =>
withExecuteContext { ctx =>
val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path)
rgs.foreach(addReference)

Expand All @@ -221,14 +241,14 @@ abstract class Backend {
mtContigs: Array[String],
parInput: Array[String],
): Array[Byte] =
withExecuteContext("fromFASTAFile") { ctx =>
withExecuteContext { ctx =>
val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
xContigs, yContigs, mtContigs, parInput)
rg.toJSONString.getBytes(StandardCharsets.UTF_8)
}

def parseVCFMetadata(path: String): Array[Byte] = jsonToBytes {
withExecuteContext("parseVCFMetadata") { ctx =>
withExecuteContext { ctx =>
val metadata = LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path)
implicit val formats = defaultJSONFormats
Extraction.decompose(metadata)
Expand All @@ -237,7 +257,7 @@ abstract class Backend {

def importFam(path: String, isQuantPheno: Boolean, delimiter: String, missingValue: String)
: Array[Byte] =
withExecuteContext("importFam") { ctx =>
withExecuteContext { ctx =>
LoadPlink.importFamJSON(ctx.fs, path, isQuantPheno, delimiter, missingValue).getBytes(
StandardCharsets.UTF_8
)
Expand All @@ -251,7 +271,7 @@ abstract class Backend {
returnType: String,
bodyStr: String,
): Unit = {
withExecuteContext("pyRegisterIR") { ctx =>
withExecuteContext { ctx =>
IRFunctionRegistry.registerIR(
ctx,
name,
Expand All @@ -264,31 +284,7 @@ abstract class Backend {
}
}

def execute(
ir: String,
timed: Boolean,
)(
consume: (ExecuteContext, Either[Unit, (PTuple, Long)], String) => Unit
): Unit = ()

def encodeToOutputStream(
ctx: ExecuteContext,
t: PTuple,
off: Long,
bufferSpecString: String,
os: OutputStream,
): Unit = {
val bs = BufferSpec.parseOrDefault(bufferSpecString)
assert(t.size == 1)
val elementType = t.fields(0).typ
val codec = TypedCodecSpec(
EType.fromPythonTypeEncoding(elementType.virtualType),
elementType.virtualType,
bs,
)
assert(t.isFieldDefined(off, 0))
codec.encode(ctx, elementType, t.loadField(off, 0), os)
}
def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)]
}

trait BackendWithCodeCache {
Expand Down
29 changes: 23 additions & 6 deletions hail/src/main/scala/is/hail/backend/BackendServer.scala
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
package is.hail.backend

import is.hail.expr.ir.{IRParser, IRParserEnvironment}
import is.hail.utils._

import scala.util.control.NonFatal

import java.net.InetSocketAddress
import java.nio.charset.StandardCharsets
import java.util.concurrent._

import com.sun.net.httpserver.{HttpExchange, HttpHandler, HttpServer}
import org.json4s._
import org.json4s.jackson.JsonMethods
import org.json4s.jackson.JsonMethods.compact

case class IRTypePayload(ir: String)
case class LoadReferencesFromDatasetPayload(path: String)
Expand Down Expand Up @@ -84,15 +88,28 @@ class BackendHttpHandler(backend: Backend) extends HttpHandler {
try {
val body = using(exchange.getRequestBody)(JsonMethods.parse(_))
if (exchange.getRequestURI.getPath == "/execute") {
val config = body.extract[ExecutePayload]
backend.execute(config.ir, config.timed) { (ctx, res, timings) =>
exchange.getResponseHeaders().add("X-Hail-Timings", timings)
val ExecutePayload(irStr, streamCodec, timed) = body.extract[ExecutePayload]
backend.withExecuteContext { ctx =>
val (res, timings) = ExecutionTimer.time { timer =>
ctx.local(timer = timer) { ctx =>
val irData = IRParser.parse_value_ir(
irStr,
IRParserEnvironment(ctx, irMap = backend.persistedIR.toMap),
)
backend.execute(ctx, irData)
}
}

if (timed) {
exchange.getResponseHeaders.add("X-Hail-Timings", compact(timings.toJSON))
}

res match {
case Left(_) => exchange.sendResponseHeaders(200, -1L)
case Right((t, off)) =>
exchange.sendResponseHeaders(200, 0L) // 0 => an arbitrarily long response body
using(exchange.getResponseBody()) { os =>
backend.encodeToOutputStream(ctx, t, off, config.stream_codec, os)
using(exchange.getResponseBody) { os =>
Backend.encodeToOutputStream(ctx, t, off, streamCodec, os)
}
}
}
Expand Down Expand Up @@ -126,7 +143,7 @@ class BackendHttpHandler(backend: Backend) extends HttpHandler {
exchange.sendResponseHeaders(200, response.length)
using(exchange.getResponseBody())(_.write(response))
} catch {
case t: Throwable =>
case NonFatal(t) =>
val (shortMessage, expandedMessage, errorId) = handleForPython(t)
val errorJson = JObject(
"short" -> JString(shortMessage),
Expand Down
Loading