[query] Lift backend state into {Service|Py4J}BackendApi
ehigham committed Jan 24, 2025
1 parent 23f6536 commit 1825e8e
Showing 59 changed files with 1,300 additions and 1,421 deletions.
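In short, this commit lifts per-backend state (local and remote tmpdirs, the GCS requester-pays configuration, the filesystem handle, and the HTTP server) out of the Python backend subclasses and into a Java-side Py4JBackendApi wrapper. A minimal sketch of the new boot sequence, using only calls that appear in the py4j_backend.py diff below; the helper function itself is hypothetical and its arguments are Py4J proxies:

def boot_backend_api(hail_package, jbackend, tmpdir: str, remote_tmpdir: str):
    # Wrap the raw Java backend in the new API object ...
    api = hail_package.backend.api.Py4JBackendApi(jbackend)
    # ... then push state that each subclass previously threaded through
    # its own constructor arguments and flags.
    api.pySetLocalTmp(tmpdir)
    api.pySetRemoteTmp(remote_tmpdir)
    # The API object also serves HTTP, replacing is.hail.backend.BackendServer.
    http_server = api.pyHttpServer()
    return api, http_server, http_server.port()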
16 changes: 3 additions & 13 deletions hail/python/hail/backend/local_backend.py
@@ -74,31 +74,21 @@ def __init__(
         hail_package = getattr(self._gateway.jvm, 'is').hail

         jbackend = hail_package.backend.local.LocalBackend.apply(
             tmpdir,
             log,
             True,
             append,
             skip_logging_configuration,
         )
         jhc = hail_package.HailContext.apply(jbackend, branching_factor, optimizer_iterations)

-        super(LocalBackend, self).__init__(self._gateway.jvm, jbackend, jhc)
+        super().__init__(self._gateway.jvm, jbackend, jhc, tmpdir, tmpdir)
+        self.gcs_requester_pays_configuration = gcs_requester_pays_configuration
         self._fs = self._exit_stack.enter_context(
             RouterFS(gcs_kwargs={'gcs_requester_pays_configuration': gcs_requester_pays_configuration})
         )

-        self._logger = None
-
-        flags = {}
-        if gcs_requester_pays_configuration is not None:
-            if isinstance(gcs_requester_pays_configuration, str):
-                flags['gcs_requester_pays_project'] = gcs_requester_pays_configuration
-            else:
-                assert isinstance(gcs_requester_pays_configuration, tuple)
-                flags['gcs_requester_pays_project'] = gcs_requester_pays_configuration[0]
-                flags['gcs_requester_pays_buckets'] = ','.join(gcs_requester_pays_configuration[1])
-
-        self._initialize_flags(flags)
+        self._initialize_flags({})

     def validate_file(self, uri: str) -> None:
         async_to_blocking(validate_file(uri, self._fs.afs))
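The requester-pays flag mapping deleted above now happens behind the gcs_requester_pays_configuration setter that Py4JBackend gains below. For reference, a standalone sketch of what the removed block computed; the type alias mirrors the one imported from hailtop.aiocloud.aiogoogle, and the helper name is hypothetical:

from typing import List, Optional, Tuple, Union

# Either a bare project name, or a (project, buckets) pair.
GCSRequesterPaysConfiguration = Union[str, Tuple[str, List[str]]]

def requester_pays_flags(config: Optional[GCSRequesterPaysConfiguration]) -> dict:
    if config is None:
        return {}
    if isinstance(config, str):
        return {'gcs_requester_pays_project': config}
    project, buckets = config
    return {
        'gcs_requester_pays_project': project,
        'gcs_requester_pays_buckets': ','.join(buckets),
    }

assert requester_pays_flags('my-project') == {'gcs_requester_pays_project': 'my-project'}
assert requester_pays_flags(('my-project', ['a', 'b'])) == {
    'gcs_requester_pays_project': 'my-project',
    'gcs_requester_pays_buckets': 'a,b',
}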
45 changes: 38 additions & 7 deletions hail/python/hail/backend/py4j_backend.py
@@ -13,8 +13,10 @@

 import hail
 from hail.expr import construct_expr
+from hail.fs.hadoop_fs import HadoopFS
 from hail.ir import JavaIR
 from hail.utils.java import Env, FatalError, scala_package_object
+from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration

 from ..hail_logging import Logger
 from .backend import ActionTag, Backend, fatal_error_from_java_error_triplet
@@ -170,8 +172,15 @@ def parse(node):

 class Py4JBackend(Backend):
     @abc.abstractmethod
-    def __init__(self, jvm: JVMView, jbackend: JavaObject, jhc: JavaObject):
-        super(Py4JBackend, self).__init__()
+    def __init__(
+        self,
+        jvm: JVMView,
+        jbackend: JavaObject,
+        jhc: JavaObject,
+        tmpdir: str,
+        remote_tmpdir: str,
+    ):
+        super().__init__()
         import base64

         def decode_bytearray(encoded):
@@ -184,14 +193,19 @@ def decode_bytearray(encoded):
         self._jvm = jvm
         self._hail_package = getattr(self._jvm, 'is').hail
         self._utils_package_object = scala_package_object(self._hail_package.utils)
-        self._jbackend = jbackend
         self._jhc = jhc

-        self._backend_server = self._hail_package.backend.BackendServer(self._jbackend)
-        self._backend_server_port: int = self._backend_server.port()
-        self._backend_server.start()
+        self._jbackend = self._hail_package.backend.api.Py4JBackendApi(jbackend)
+        self._jbackend.pySetLocalTmp(tmpdir)
+        self._jbackend.pySetRemoteTmp(remote_tmpdir)
+
+        self._jhttp_server = self._jbackend.pyHttpServer()
+        self._backend_server_port: int = self._jhttp_server.port()
         self._requests_session = requests.Session()

+        self._gcs_requester_pays_config = None
+        self._fs = None
+
         # This has to go after creating the SparkSession. Unclear why.
         # Maybe it does its own patch?
         install_exception_handler()
@@ -215,6 +229,23 @@ def hail_package(self):
     def utils_package_object(self):
         return self._utils_package_object

+    @property
+    def gcs_requester_pays_configuration(self) -> Optional[GCSRequesterPaysConfiguration]:
+        return self._gcs_requester_pays_config
+
+    @gcs_requester_pays_configuration.setter
+    def gcs_requester_pays_configuration(self, config: Optional[GCSRequesterPaysConfiguration]):
+        self._gcs_requester_pays_config = config
+        project, buckets = (None, None) if config is None else (config, None) if isinstance(config, str) else config
+        self._jbackend.pySetGcsRequesterPaysConfig(project, buckets)
+        self._fs = None  # stale
+
+    @property
+    def fs(self):
+        if self._fs is None:
+            self._fs = HadoopFS(self._utils_package_object, self._jbackend.pyFs())
+        return self._fs
+
     @property
     def logger(self):
         if self._logger is None:
@@ -289,7 +320,7 @@ def _to_java_blockmatrix_ir(self, ir):
         return self._parse_blockmatrix_ir(self._render_ir(ir))

     def stop(self):
-        self._backend_server.close()
+        self._jhttp_server.close()
         self._jbackend.close()
         self._jhc.stop()
         self._jhc = None
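The one-line conditional in the new setter packs three cases into a single expression. An equivalent spelled-out version (helper name hypothetical), with the mapping it implies:

def split_requester_pays(config):
    # Same semantics as:
    #   project, buckets = (None, None) if config is None \
    #       else (config, None) if isinstance(config, str) else config
    if config is None:
        return None, None    # clear the config on the Java side
    if isinstance(config, str):
        return config, None  # project only, no bucket restriction
    return config            # already a (project, buckets) pair

assert split_requester_pays(None) == (None, None)
assert split_requester_pays('my-project') == ('my-project', None)
assert split_requester_pays(('my-project', ['b1'])) == ('my-project', ['b1'])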
29 changes: 6 additions & 23 deletions hail/python/hail/backend/spark_backend.py
@@ -7,11 +7,11 @@
 import pyspark.sql

 from hail.expr.table_type import ttable
-from hail.fs.hadoop_fs import HadoopFS
 from hail.ir import BaseIR
 from hail.ir.renderer import CSERenderer
 from hail.table import Table
 from hail.utils import copy_log
+from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration
 from hailtop.aiotools.router_fs import RouterAsyncFS
 from hailtop.aiotools.validators import validate_file
 from hailtop.utils import async_to_blocking
@@ -47,12 +47,9 @@ def __init__(
         skip_logging_configuration,
         optimizer_iterations,
         *,
-        gcs_requester_pays_project: Optional[str] = None,
-        gcs_requester_pays_buckets: Optional[str] = None,
+        gcs_requester_pays_config: Optional[GCSRequesterPaysConfiguration] = None,
         copy_log_on_error: bool = False,
     ):
-        assert gcs_requester_pays_project is not None or gcs_requester_pays_buckets is None
-
         try:
             local_jar_info = local_jar_information()
         except ValueError:
@@ -120,10 +117,6 @@ def __init__(
                 append,
                 skip_logging_configuration,
                 min_block_size,
-                tmpdir,
-                local_tmpdir,
-                gcs_requester_pays_project,
-                gcs_requester_pays_buckets,
             )
             jhc = hail_package.HailContext.getOrCreate(jbackend, branching_factor, optimizer_iterations)
         else:
@@ -137,10 +130,6 @@ def __init__(
                 append,
                 skip_logging_configuration,
                 min_block_size,
-                tmpdir,
-                local_tmpdir,
-                gcs_requester_pays_project,
-                gcs_requester_pays_buckets,
             )
             jhc = hail_package.HailContext.apply(jbackend, branching_factor, optimizer_iterations)

@@ -149,12 +138,12 @@ def __init__(
             self.sc = sc
         else:
             self.sc = pyspark.SparkContext(gateway=self._gateway, jsc=jvm.JavaSparkContext(self._jsc))
-        self._jspark_session = jbackend.sparkSession()
+        self._jspark_session = jbackend.sparkSession().apply()
         self._spark_session = pyspark.sql.SparkSession(self.sc, self._jspark_session)

-        super(SparkBackend, self).__init__(jvm, jbackend, jhc)
+        super().__init__(jvm, jbackend, jhc, local_tmpdir, tmpdir)
+        self.gcs_requester_pays_configuration = gcs_requester_pays_config

-        self._fs = None
         self._logger = None

         if not quiet:
@@ -167,7 +156,7 @@ def __init__(
         self._initialize_flags({})

         self._router_async_fs = RouterAsyncFS(
-            gcs_kwargs={"gcs_requester_pays_configuration": gcs_requester_pays_project}
+            gcs_kwargs={"gcs_requester_pays_configuration": gcs_requester_pays_config}
         )

         self._tmpdir = tmpdir
@@ -181,12 +170,6 @@ def stop(self):
         self.sc.stop()
         self.sc = None

-    @property
-    def fs(self):
-        if self._fs is None:
-            self._fs = HadoopFS(self._utils_package_object, self._jbackend.fs())
-        return self._fs
-
     def from_spark(self, df, key):
         result_tuple = self._jbackend.pyFromDF(df._jdf, key)
         tir_id, type_json = result_tuple._1(), result_tuple._2()
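SparkBackend's own fs property is gone: the lazily built HadoopFS now lives on Py4JBackend (see the py4j_backend.py hunk above), where assigning gcs_requester_pays_configuration drops the cached instance so the next access rebuilds it. A minimal standalone sketch of that cache-and-invalidate pattern, with hypothetical names:

class CachedFS:
    def __init__(self):
        self._config = None
        self._fs = None

    @property
    def fs(self):
        if self._fs is None:
            self._fs = ('HadoopFS', self._config)  # stands in for HadoopFS(...)
        return self._fs

    @property
    def gcs_requester_pays_configuration(self):
        return self._config

    @gcs_requester_pays_configuration.setter
    def gcs_requester_pays_configuration(self, config):
        self._config = config
        self._fs = None  # stale: rebuild with the new config on next access

b = CachedFS()
first = b.fs
b.gcs_requester_pays_configuration = 'my-project'
assert b.fs is not first  # rebuilt after the config change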
13 changes: 4 additions & 9 deletions hail/python/hail/context.py
@@ -474,14 +474,10 @@ def init_spark(
     optimizer_iterations = get_env_or_default(_optimizer_iterations, 'HAIL_OPTIMIZER_ITERATIONS', 3)

     app_name = app_name or 'Hail'
-    (
-        gcs_requester_pays_project,
-        gcs_requester_pays_buckets,
-    ) = convert_gcs_requester_pays_configuration_to_hadoop_conf_style(
-        get_gcs_requester_pays_configuration(
-            gcs_requester_pays_configuration=gcs_requester_pays_configuration,
-        )
-    )
+    gcs_requester_pays_configuration = get_gcs_requester_pays_configuration(
+        gcs_requester_pays_configuration=gcs_requester_pays_configuration,
+    )
+
     backend = SparkBackend(
         idempotent,
         sc,
@@ -498,8 +494,7 @@ def init_spark(
         local_tmpdir,
         skip_logging_configuration,
         optimizer_iterations,
-        gcs_requester_pays_project=gcs_requester_pays_project,
-        gcs_requester_pays_buckets=gcs_requester_pays_buckets,
+        gcs_requester_pays_config=gcs_requester_pays_configuration,
         copy_log_on_error=copy_log_on_error,
     )
     if not backend.fs.exists(tmpdir):
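With the Hadoop-conf-style conversion removed, the resolved configuration reaches SparkBackend unchanged in either of its two shapes. A hypothetical usage sketch; the keyword comes from the signature in the diff above, and Spark-specific arguments are omitted:

from hail.context import init_spark

# project-only requester-pays billing ...
init_spark(gcs_requester_pays_configuration='my-project')

# ... or a project plus an explicit bucket list:
# init_spark(gcs_requester_pays_configuration=('my-project', ['bucket-a', 'bucket-b']))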
5 changes: 2 additions & 3 deletions hail/src/main/scala/is/hail/backend/Backend.scala
@@ -9,7 +9,6 @@ import is.hail.io.fs.FS
 import is.hail.types.RTable
 import is.hail.types.encoded.EType
 import is.hail.types.physical.PTuple
-import is.hail.utils.ExecutionTimer.Timings
 import is.hail.utils.fatal

 import scala.reflect.ClassTag
@@ -105,7 +104,7 @@ abstract class Backend extends Closeable {
   def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses)
     : TableStage

-  def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings)
-
+  def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)]
+
   def backendContext(ctx: ExecuteContext): BackendContext
 }
4 changes: 0 additions & 4 deletions hail/src/main/scala/is/hail/backend/BackendRpc.scala
@@ -247,8 +247,4 @@ trait HttpLikeBackendRpc[A] extends BackendRpc {
       )
     }
   }
-
-  implicit protected def Ask: Routing
-  implicit protected def Write: Write[A]
-  implicit protected def Context: Context[A]
 }
123 changes: 0 additions & 123 deletions hail/src/main/scala/is/hail/backend/BackendServer.scala

This file was deleted.
