-
Notifications
You must be signed in to change notification settings - Fork 248
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[query] Lift backend state into
{Service|Py4J}BackendApi
- Loading branch information
Showing
4 changed files
with
350 additions
and
362 deletions.
There are no files selected for viewing
123 changes: 0 additions & 123 deletions
123
hail/src/main/scala/is/hail/backend/BackendServer.scala
This file was deleted.
Oops, something went wrong.
154 changes: 154 additions & 0 deletions
154
hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
package is.hail.backend.py4j | ||
|
||
import com.google.api.client.http.HttpStatusCodes | ||
import com.sun.net.{httpserver => http} | ||
import is.hail.HailFeatureFlags | ||
import is.hail.asm4s.HailClassLoader | ||
import is.hail.backend.caching.BlockMatrixCache | ||
import is.hail.backend.{Backend, BackendContext, ExecuteContext, HttpLikeBackendRpc, TempFileManager} | ||
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer | ||
import is.hail.expr.ir.lowering.IrMetadata | ||
import is.hail.expr.ir.{BaseIR, CodeCacheKey, CompiledFunction} | ||
import is.hail.linalg.BlockMatrix | ||
import is.hail.types.virtual.Kinds.{BlockMatrix, Matrix, Table, Value} | ||
import is.hail.utils.ExecutionTimer.Timings | ||
import is.hail.utils._ | ||
import is.hail.variant.ReferenceGenome | ||
import org.json4s._ | ||
import org.json4s.jackson.{JsonMethods, Serialization} | ||
|
||
import java.io.Closeable | ||
import java.net.InetSocketAddress | ||
import java.util.concurrent._ | ||
import scala.collection.mutable | ||
|
||
final class Py4JBackendApi(override val backend: Backend) extends Py4JBackendExtensions with Closeable { | ||
|
||
override val references: mutable.Map[String, ReferenceGenome] = | ||
mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*) | ||
|
||
override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv() | ||
override def longLifeTempFileManager: TempFileManager = null | ||
|
||
private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader) | ||
private[this] val bmCache = new BlockMatrixCache() | ||
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50) | ||
private[this] val persistedIr: mutable.Map[Int, BaseIR] = mutable.Map() | ||
private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32) | ||
|
||
|
||
override def close(): Unit = { | ||
HttpServer.close() | ||
backend.close() | ||
} | ||
|
||
object HttpServer extends HttpLikeBackendRpc[http.HttpExchange] with Closeable { | ||
// 0 => let the OS pick an available port | ||
private[this] val httpServer = http.HttpServer.create(new InetSocketAddress(0), 10) | ||
Check failure on line 47 in hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala
|
||
|
||
private[this] val thread = { | ||
// This HTTP server *must not* start non-daemon threads because such threads keep the JVM | ||
// alive. A living JVM indicates to Spark that the job is incomplete. This does not manifest | ||
// when you run jobs in a local pyspark (because you'll Ctrl-C out of Python regardless of the | ||
// JVM's state) nor does it manifest in a Notebook (again, you'll kill the Notebook kernel | ||
// explicitly regardless of the JVM). It *does* manifest when submitting jobs with | ||
// | ||
// gcloud dataproc submit ... | ||
// | ||
// or | ||
// | ||
// spark-submit | ||
// | ||
// setExecutor(null) ensures the server creates no new threads: | ||
// | ||
/* > If this method is not called (before start()) or if it is called with a null Executor, | ||
* then */ | ||
// > a default implementation is used, which uses the thread which was created by the start() | ||
// > method. | ||
// | ||
/* Source: | ||
* https://docs.oracle.com/en/java/javase/11/docs/api/jdk.httpserver/com/sun/net/httpserver/HttpServer.html#setExecutor(java.util.concurrent.Executor) */ | ||
// | ||
httpServer.createContext("/", runRpc(_: http.HttpExchange)) | ||
httpServer.setExecutor(null) | ||
val t = Executors.defaultThreadFactory().newThread(() => httpServer.start()) | ||
t.setDaemon(true) | ||
t | ||
} | ||
|
||
def port: Int = httpServer.getAddress.getPort | ||
|
||
override def close(): Unit = | ||
httpServer.stop(10) | ||
|
||
thread.start() | ||
|
||
implicit private object Handler | ||
extends Routing with Write[http.HttpExchange] with Context[http.HttpExchange] | ||
with ErrorHandling { | ||
|
||
override def route(req: http.HttpExchange): Route = | ||
req.getRequestURI.getPath match { | ||
case "/value/type" => Routes.TypeOf(Value) | ||
case "/table/type" => Routes.TypeOf(Table) | ||
case "/matrixtable/type" => Routes.TypeOf(Matrix) | ||
case "/blockmatrix/type" => Routes.TypeOf(BlockMatrix) | ||
case "/execute" => Routes.Execute | ||
case "/vcf/metadata/parse" => Routes.ParseVcfMetadata | ||
case "/fam/import" => Routes.ImportFam | ||
case "/references/load" => Routes.LoadReferencesFromDataset | ||
case "/references/from_fasta" => Routes.LoadReferencesFromFASTA | ||
} | ||
|
||
override def payload(req: http.HttpExchange): JValue = | ||
using(req.getRequestBody)(JsonMethods.parse(_)) | ||
|
||
override def timings(req: http.HttpExchange)(t: Timings): Unit = { | ||
val ts = Serialization.write(Map("timings" -> t)) | ||
req.getResponseHeaders.add("X-Hail-Timings", ts) | ||
} | ||
|
||
override def result(req: http.HttpExchange)(result: Array[Byte]): Unit = | ||
respond(req)(HttpStatusCodes.STATUS_CODE_OK, result) | ||
|
||
override def error(req: http.HttpExchange)(t: Throwable): Unit = | ||
respond(req)( | ||
HttpStatusCodes.STATUS_CODE_SERVER_ERROR, | ||
jsonToBytes { | ||
val (shortMessage, expandedMessage, errorId) = handleForPython(t) | ||
JObject( | ||
"short" -> JString(shortMessage), | ||
"expanded" -> JString(expandedMessage), | ||
"error_id" -> JInt(errorId), | ||
) | ||
}, | ||
) | ||
|
||
private[this] def respond(req: http.HttpExchange)(code: Int, payload: Array[Byte]): Unit = { | ||
req.sendResponseHeaders(code, payload.length) | ||
using(req.getResponseBody)(_.write(payload)) | ||
} | ||
|
||
override def scoped[A](req: http.HttpExchange)(f: ExecuteContext => A): (A, Timings) = { | ||
ExecutionTimer.time { timer => | ||
ExecuteContext.scoped( | ||
tmpdir = String, | ||
localTmpdir = String, | ||
backend = Py4JBackendApi.this.backend, | ||
fs = Py4JBackendApi.this.fs, | ||
timer = timer, | ||
tempFileManager = null, | ||
theHailClassLoader = Py4JBackendApi.this.theHailClassLoader, | ||
flags = Py4JBackendApi.this.flags, | ||
irMetadata = IrMetadata(None), | ||
references = Py4JBackendApi.this.references, | ||
blockMatrixCache = Py4JBackendApi.this.bmCache, | ||
codeCache = Py4JBackendApi.this.codeCache, | ||
irCache = Py4JBackendApi.this.persistedIr, | ||
coercerCache = Py4JBackendApi.this.coercerCache, | ||
)(f) | ||
} | ||
} | ||
} | ||
} | ||
} |
Oops, something went wrong.