From b6106b523e13909bab23a8c6eb93d478c586f9d9 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Tue, 1 Oct 2024 12:08:19 -0400 Subject: [PATCH] wire up spark, local and py4j backends --- hail/python/hail/backend/local_backend.py | 15 ++-------- hail/python/hail/backend/py4j_backend.py | 29 +++++++++++++++++-- hail/python/hail/backend/spark_backend.py | 29 ++++--------------- hail/python/hail/context.py | 13 +++------ .../is/hail/backend/api/Py4JBackendApi.scala | 20 +++++++++---- .../hail/backend/api/ServiceBackendApi.scala | 4 ++- .../main/scala/is/hail/utils/package.scala | 1 - hail/src/test/scala/is/hail/HailSuite.scala | 4 +-- 8 files changed, 59 insertions(+), 56 deletions(-) diff --git a/hail/python/hail/backend/local_backend.py b/hail/python/hail/backend/local_backend.py index 0908ee06986f..311a5c7e9bb4 100644 --- a/hail/python/hail/backend/local_backend.py +++ b/hail/python/hail/backend/local_backend.py @@ -81,23 +81,14 @@ def __init__( ) jhc = hail_package.HailContext.apply(jbackend, branching_factor, optimizer_iterations) - super(LocalBackend, self).__init__(self._gateway.jvm, jbackend, jhc) + super().__init__(self._gateway.jvm, jbackend, jhc, tmpdir, tmpdir) + self.gcs_requester_pays_configuration = gcs_requester_pays_configuration self._fs = self._exit_stack.enter_context( RouterFS(gcs_kwargs={'gcs_requester_pays_configuration': gcs_requester_pays_configuration}) ) self._logger = None - - flags = {} - if gcs_requester_pays_configuration is not None: - if isinstance(gcs_requester_pays_configuration, str): - flags['gcs_requester_pays_project'] = gcs_requester_pays_configuration - else: - assert isinstance(gcs_requester_pays_configuration, tuple) - flags['gcs_requester_pays_project'] = gcs_requester_pays_configuration[0] - flags['gcs_requester_pays_buckets'] = ','.join(gcs_requester_pays_configuration[1]) - - self._initialize_flags(flags) + self._initialize_flags({}) def validate_file(self, uri: str) -> None: async_to_blocking(validate_file(uri, self._fs.afs)) diff --git a/hail/python/hail/backend/py4j_backend.py b/hail/python/hail/backend/py4j_backend.py index 49838ea4a851..d59ae456a937 100644 --- a/hail/python/hail/backend/py4j_backend.py +++ b/hail/python/hail/backend/py4j_backend.py @@ -13,8 +13,10 @@ import hail from hail.expr import construct_expr +from hail.fs.hadoop_fs import HadoopFS from hail.ir import JavaIR from hail.utils.java import Env, FatalError, scala_package_object +from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration from ..hail_logging import Logger from .backend import ActionTag, Backend, fatal_error_from_java_error_triplet @@ -193,11 +195,17 @@ def decode_bytearray(encoded): self._utils_package_object = scala_package_object(self._hail_package.utils) self._jhc = jhc - self._jbackend = self._hail_package.backend.api.P4jBackendApi(jbackend) + self._jbackend = self._hail_package.backend.api.Py4JBackendApi(jbackend) + self._jbackend.pySetLocalTmp(tmpdir) + self._jbackend.pySetRemoteTmp(remote_tmpdir) + self._jhttp_server = self._jbackend.pyHttpServer() - self._backend_server_port: int = self._jbackend.HttpServer.port() + self._backend_server_port: int = self._jhttp_server.port() self._requests_session = requests.Session() + self._gcs_requester_pays_config = None + self._fs = None + # This has to go after creating the SparkSession. Unclear why. # Maybe it does its own patch? 
install_exception_handler() @@ -221,6 +229,23 @@ def hail_package(self): def utils_package_object(self): return self._utils_package_object + @property + def gcs_requester_pays_configuration(self) -> Optional[GCSRequesterPaysConfiguration]: + return self._gcs_requester_pays_config + + @gcs_requester_pays_configuration.setter + def gcs_requester_pays_configuration(self, config: Optional[GCSRequesterPaysConfiguration]): + self._gcs_requester_pays_config = config + project, buckets = (None, None) if config is None else (config, None) if isinstance(config, str) else config + self._jbackend.pySetGcsRequesterPaysConfig(project, buckets) + self._fs = None # stale + + @property + def fs(self): + if self._fs is None: + self._fs = HadoopFS(self._utils_package_object, self._jbackend.pyFs()) + return self._fs + @property def logger(self): if self._logger is None: diff --git a/hail/python/hail/backend/spark_backend.py b/hail/python/hail/backend/spark_backend.py index 69f532924438..192ca57eb750 100644 --- a/hail/python/hail/backend/spark_backend.py +++ b/hail/python/hail/backend/spark_backend.py @@ -7,11 +7,11 @@ import pyspark.sql from hail.expr.table_type import ttable -from hail.fs.hadoop_fs import HadoopFS from hail.ir import BaseIR from hail.ir.renderer import CSERenderer from hail.table import Table from hail.utils import copy_log +from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration from hailtop.aiotools.router_fs import RouterAsyncFS from hailtop.aiotools.validators import validate_file from hailtop.utils import async_to_blocking @@ -47,12 +47,9 @@ def __init__( skip_logging_configuration, optimizer_iterations, *, - gcs_requester_pays_project: Optional[str] = None, - gcs_requester_pays_buckets: Optional[str] = None, + gcs_requester_pays_config: Optional[GCSRequesterPaysConfiguration] = None, copy_log_on_error: bool = False, ): - assert gcs_requester_pays_project is not None or gcs_requester_pays_buckets is None - try: local_jar_info = local_jar_information() except ValueError: @@ -120,10 +117,6 @@ def __init__( append, skip_logging_configuration, min_block_size, - tmpdir, - local_tmpdir, - gcs_requester_pays_project, - gcs_requester_pays_buckets, ) jhc = hail_package.HailContext.getOrCreate(jbackend, branching_factor, optimizer_iterations) else: @@ -137,10 +130,6 @@ def __init__( append, skip_logging_configuration, min_block_size, - tmpdir, - local_tmpdir, - gcs_requester_pays_project, - gcs_requester_pays_buckets, ) jhc = hail_package.HailContext.apply(jbackend, branching_factor, optimizer_iterations) @@ -149,12 +138,12 @@ def __init__( self.sc = sc else: self.sc = pyspark.SparkContext(gateway=self._gateway, jsc=jvm.JavaSparkContext(self._jsc)) - self._jspark_session = jbackend.sparkSession() + self._jspark_session = jbackend.sparkSession().apply() self._spark_session = pyspark.sql.SparkSession(self.sc, self._jspark_session) - super(SparkBackend, self).__init__(jvm, jbackend, jhc) + super().__init__(jvm, jbackend, jhc, local_tmpdir, tmpdir) + self.gcs_requester_pays_configuration = gcs_requester_pays_config - self._fs = None self._logger = None if not quiet: @@ -167,7 +156,7 @@ def __init__( self._initialize_flags({}) self._router_async_fs = RouterAsyncFS( - gcs_kwargs={"gcs_requester_pays_configuration": gcs_requester_pays_project} + gcs_kwargs={"gcs_requester_pays_configuration": gcs_requester_pays_config} ) self._tmpdir = tmpdir @@ -181,12 +170,6 @@ def stop(self): self.sc.stop() self.sc = None - @property - def fs(self): - if self._fs is None: - self._fs = 
HadoopFS(self._utils_package_object, self._jbackend.fs()) - return self._fs - def from_spark(self, df, key): result_tuple = self._jbackend.pyFromDF(df._jdf, key) tir_id, type_json = result_tuple._1(), result_tuple._2() diff --git a/hail/python/hail/context.py b/hail/python/hail/context.py index 5258f27fbc13..3d8689f4be10 100644 --- a/hail/python/hail/context.py +++ b/hail/python/hail/context.py @@ -474,14 +474,10 @@ def init_spark( optimizer_iterations = get_env_or_default(_optimizer_iterations, 'HAIL_OPTIMIZER_ITERATIONS', 3) app_name = app_name or 'Hail' - ( - gcs_requester_pays_project, - gcs_requester_pays_buckets, - ) = convert_gcs_requester_pays_configuration_to_hadoop_conf_style( - get_gcs_requester_pays_configuration( - gcs_requester_pays_configuration=gcs_requester_pays_configuration, - ) + gcs_requester_pays_configuration = get_gcs_requester_pays_configuration( + gcs_requester_pays_configuration=gcs_requester_pays_configuration, ) + backend = SparkBackend( idempotent, sc, @@ -498,8 +494,7 @@ def init_spark( local_tmpdir, skip_logging_configuration, optimizer_iterations, - gcs_requester_pays_project=gcs_requester_pays_project, - gcs_requester_pays_buckets=gcs_requester_pays_buckets, + gcs_requester_pays_config=gcs_requester_pays_configuration, copy_log_on_error=copy_log_on_error, ) if not backend.fs.exists(tmpdir): diff --git a/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala b/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala index ce069416e05f..24bd3400d2d0 100644 --- a/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala +++ b/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala @@ -6,7 +6,10 @@ import is.hail.backend._ import is.hail.backend.caching.BlockMatrixCache import is.hail.backend.spark.SparkBackend import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex} -import is.hail.expr.ir.{BaseIR, BlockMatrixIR, CodeCacheKey, CompiledFunction, EncodedLiteral, GetFieldByIdx, IRParser, Interpret, MatrixIR, MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue} +import is.hail.expr.ir.{ + BaseIR, BlockMatrixIR, CodeCacheKey, CompiledFunction, EncodedLiteral, GetFieldByIdx, IRParser, + Interpret, MatrixIR, MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue, +} import is.hail.expr.ir.IRParser.parseType import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer import is.hail.expr.ir.functions.IRFunctionRegistry @@ -21,14 +24,18 @@ import is.hail.utils._ import is.hail.utils.ExecutionTimer.Timings import is.hail.variant.ReferenceGenome +import scala.annotation.nowarn import scala.collection.mutable import scala.jdk.CollectionConverters.{asScalaBufferConverter, seqAsJavaListConverter} + import java.io.Closeable import java.net.InetSocketAddress import java.util import java.util.concurrent._ + import com.google.api.client.http.HttpStatusCodes import com.sun.net.httpserver.{HttpExchange, HttpServer} +import javax.annotation.Nullable import org.apache.hadoop import org.apache.hadoop.conf.Configuration import org.apache.spark.sql.DataFrame @@ -37,9 +44,6 @@ import org.json4s._ import org.json4s.jackson.{JsonMethods, Serialization} import sourcecode.Enclosing -import javax.annotation.Nullable -import scala.annotation.nowarn - final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling { private[this] val flags: HailFeatureFlags = HailFeatureFlags.fromEnv() @@ -74,6 +78,9 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin 
manager.close() } + def pyFs: FS = + tmpFileManager.getFs + def pyGetFlag(name: String): String = flags.get(name) @@ -83,13 +90,14 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin def pyAvailableFlags: java.util.ArrayList[String] = flags.available - def pySetTmpdir(tmp: String): Unit = + def pySetRemoteTmp(tmp: String): Unit = tmpdir = tmp def pySetLocalTmp(tmp: String): Unit = localTmpdir = tmp - def pySetRequesterPays(@Nullable project: String, @Nullable buckets: util.List[String]): Unit = { + def pySetGcsRequesterPaysConfig(@Nullable project: String, @Nullable buckets: util.List[String]) + : Unit = { val cloudfsConf = CloudStorageFSConfig.fromFlagsAndEnv(None, flags) val rpConfig: Option[RequesterPaysConfig] = diff --git a/hail/src/main/scala/is/hail/backend/api/ServiceBackendApi.scala b/hail/src/main/scala/is/hail/backend/api/ServiceBackendApi.scala index 4e5659dc13b2..f089d79fb84c 100644 --- a/hail/src/main/scala/is/hail/backend/api/ServiceBackendApi.scala +++ b/hail/src/main/scala/is/hail/backend/api/ServiceBackendApi.scala @@ -11,7 +11,9 @@ import is.hail.io.fs.{CloudStorageFSConfig, FS, RouterFS} import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} import is.hail.services._ import is.hail.types.virtual.Kinds -import is.hail.utils.{toRichIterable, using, ErrorHandling, ExecutionTimer, HailWorkerException, Logging} +import is.hail.utils.{ + toRichIterable, using, ErrorHandling, ExecutionTimer, HailWorkerException, Logging, +} import is.hail.utils.ExecutionTimer.Timings import is.hail.variant.ReferenceGenome diff --git a/hail/src/main/scala/is/hail/utils/package.scala b/hail/src/main/scala/is/hail/utils/package.scala index 4fc2acbaf268..36fc0b91c707 100644 --- a/hail/src/main/scala/is/hail/utils/package.scala +++ b/hail/src/main/scala/is/hail/utils/package.scala @@ -92,7 +92,6 @@ package utils { } } - class Lazy[A] private[utils] (f: => A) { private[this] var option: Option[A] = None diff --git a/hail/src/test/scala/is/hail/HailSuite.scala b/hail/src/test/scala/is/hail/HailSuite.scala index 057d3b777441..b5b40cdf7bde 100644 --- a/hail/src/test/scala/is/hail/HailSuite.scala +++ b/hail/src/test/scala/is/hail/HailSuite.scala @@ -70,8 +70,8 @@ class HailSuite extends TestNGSuite with TestUtils { var pool: RegionPool = _ private[this] var ctx_ : ExecuteContext = _ - def backend: Backend = ctx.backend - def sc: SparkContext = backend.asSpark.sc + def backend: Backend = hc.backend + def sc: SparkContext = hc.backend.asSpark.sc def timer: ExecutionTimer = ctx.timer def theHailClassLoader: HailClassLoader = ctx.theHailClassLoader override def ctx: ExecuteContext = ctx_
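
--

A minimal usage sketch of the Python surface this patch wires up (assuming a
Hail build containing this change; the GCP project and bucket names below are
placeholders, and `Env.backend()` is the repo's existing accessor for the
active backend):

    import hail as hl
    from hail.utils.java import Env

    hl.init()  # by default constructs a SparkBackend over Py4JBackendApi

    backend = Env.backend()

    # Assigning the new property pushes (project, buckets) through
    # pySetGcsRequesterPaysConfig on the JVM side and invalidates the cached
    # filesystem handle (the `self._fs = None  # stale` line in the setter).
    backend.gcs_requester_pays_configuration = ("my-project", ["my-bucket"])

    # The next access rebuilds HadoopFS over the Java FS returned by pyFs.
    fs = backend.fs

Passing a bare project string sets only the project, and assigning None clears
both the project and the bucket allow-list, mirroring the (None, None) branch
of the setter above.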