Commit 27771b4: prevent Python from modifying the Py4J backend while executing a query
ehigham committed Oct 30, 2024
1 parent d5c609c commit 27771b4
Showing 2 changed files with 78 additions and 61 deletions.
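
Why this change: Python drives this class over Py4J, and nothing previously stopped a py* call (setting flags, tmp dirs, references, liftovers) from mutating backend state while another thread was mid-query. The fix guards the class with a ReentrantReadWriteLock used in an inverted fashion: each short Python-facing call shares the read lock (named weak), while query execution holds the write lock (named mutex), so weak callers may interleave with one another but never with a running query. A minimal sketch of the scheme, assuming nothing from Hail (all names here are illustrative, not Hail code):

import java.util.concurrent.locks.ReentrantReadWriteLock

// Sketch only: mirrors the commit's locking scheme.
final class GuardedBackend {
  private val rwlock = new ReentrantReadWriteLock()
  private val flags = scala.collection.mutable.Map.empty[String, String]

  // Short Python-facing mutations share the read ("weak") lock: they can
  // interleave with each other, but never with a running query.
  def setFlag(name: String, value: String): Unit = {
    rwlock.readLock().lock()
    try flags(name) = value
    finally rwlock.readLock().unlock()
  }

  // Query execution holds the write ("mutex") lock, excluding every
  // weak-locked caller for the duration of the query.
  def runQuery[A](body: => A): A = {
    rwlock.writeLock().lock()
    try body
    finally rwlock.writeLock().unlock()
  }
}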
134 changes: 73 additions & 61 deletions hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala
@@ -32,10 +32,10 @@ import java.io.Closeable
 import java.net.InetSocketAddress
 import java.util
 import java.util.concurrent._
+import java.util.concurrent.locks.ReentrantReadWriteLock

 import com.google.api.client.http.HttpStatusCodes
 import com.sun.net.httpserver.{HttpExchange, HttpServer}
-import javax.annotation.Nullable
 import org.apache.hadoop
 import org.apache.hadoop.conf.Configuration
 import org.apache.spark.sql.DataFrame
@@ -46,6 +46,10 @@ import sourcecode.Enclosing

 final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling {

+  private[this] val rwlock = new ReentrantReadWriteLock()
+  private[this] def weak = rwlock.readLock()
+  private[this] def mutex = rwlock.writeLock()
+
   private[this] val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
   private[this] val hcl = new HailClassLoader(getClass.getClassLoader)
   private[this] val references = mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*)
@@ -79,62 +83,63 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling {
   }

   def pyFs: FS =
-    tmpFileManager.getFs
+    using(weak.acquire())(_ => tmpFileManager.getFs)

   def pyGetFlag(name: String): String =
-    flags.get(name)
+    using(weak.acquire())(_ => flags.get(name))

   def pySetFlag(name: String, value: String): Unit =
-    flags.set(name, value)
+    using(weak.acquire())(_ => flags.set(name, value))

   def pyAvailableFlags: java.util.ArrayList[String] =
     flags.available

   def pySetRemoteTmp(tmp: String): Unit =
-    tmpdir = tmp
+    using(weak.acquire())(_ => tmpdir = tmp)

   def pySetLocalTmp(tmp: String): Unit =
-    localTmpdir = tmp
-
-  def pySetGcsRequesterPaysConfig(@Nullable project: String, @Nullable buckets: util.List[String])
-    : Unit = {
-    val cloudfsConf = CloudStorageFSConfig.fromFlagsAndEnv(None, flags)
-
-    val rpConfig: Option[RequesterPaysConfig] =
-      (
-        Option(project).filter(_.nonEmpty),
-        Option(buckets).map(_.asScala.toSet.filterNot(_.isBlank)).filter(_.nonEmpty),
-      ) match {
-        case (Some(project), buckets) => Some(RequesterPaysConfig(project, buckets))
-        case (None, Some(_)) => fatal(
-            "A non-empty, non-null requester pays google project is required to configure requester pays buckets."
-          )
-        case (None, None) => None
-      }
-
-    val fs = newFs(
-      cloudfsConf.copy(
-        google = (cloudfsConf.google, rpConfig) match {
-          case (Some(gconf), _) => Some(gconf.copy(requester_pays_config = rpConfig))
-          case (None, Some(_)) => Some(GoogleStorageFSConfig(None, rpConfig))
-          case _ => None
-        }
-      )
-    )
-
-    tmpFileManager.setFs(fs)
-  }
+    using(weak.acquire())(_ => localTmpdir = tmp)
+
+  def pySetGcsRequesterPaysConfig(project: String, buckets: util.List[String]): Unit =
+    using(weak.acquire()) { _ =>
+      val cloudfsConf = CloudStorageFSConfig.fromFlagsAndEnv(None, flags)
+
+      val rpConfig: Option[RequesterPaysConfig] =
+        (
+          Option(project).filter(_.nonEmpty),
+          Option(buckets).map(_.asScala.toSet.filterNot(_.isBlank)).filter(_.nonEmpty),
+        ) match {
+          case (Some(project), buckets) => Some(RequesterPaysConfig(project, buckets))
+          case (None, Some(_)) => fatal(
+              "A non-empty, non-null requester pays google project is required to configure requester pays buckets."
+            )
+          case (None, None) => None
+        }
+
+      val fs = newFs(
+        cloudfsConf.copy(
+          google = (cloudfsConf.google, rpConfig) match {
+            case (Some(gconf), _) => Some(gconf.copy(requester_pays_config = rpConfig))
+            case (None, Some(_)) => Some(GoogleStorageFSConfig(None, rpConfig))
+            case _ => None
+          }
+        )
+      )
+
+      tmpFileManager.setFs(fs)
+    }

   def pyRemoveJavaIR(id: Int): Unit =
-    persistedIr.remove(id)
+    using(weak.acquire())(_ => persistedIr.remove(id))

   def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit =
-    references(name).addSequence(
-      IndexedFastaSequenceFile(tmpFileManager.getFs, fastaFile, indexFile)
-    )
+    using(weak.acquire()) { _ =>
+      val seq = IndexedFastaSequenceFile(tmpFileManager.getFs, fastaFile, indexFile)
+      references(name).addSequence(seq)
+    }

   def pyRemoveSequence(name: String): Unit =
-    references(name).removeSequence()
+    using(weak.acquire())(_ => references(name).removeSequence())

   def pyExportBlockMatrix(
     pathIn: String,
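
Each wrapper above combines two pieces: Hail's using combinator and the Lock to AutoCloseableLock implicit conversion added in Implicits.scala (the second file in this commit). weak.acquire() takes the read lock and returns an AutoCloseable whose close() releases it, so using unlocks on every exit path, including exceptions. A hedged sketch of the expansion, with a local stand-in for is.hail.utils.using (its assumed shape) and the implicit conversion written out by hand:

import java.util.concurrent.locks.ReentrantReadWriteLock

import org.apache.hadoop.util.AutoCloseableLock

object UsingSketch extends App {
  // Local stand-in for is.hail.utils.using (assumed shape).
  def using[R <: AutoCloseable, A](r: R)(f: R => A): A =
    try f(r)
    finally r.close()

  val rwlock = new ReentrantReadWriteLock()
  // What the implicit conversion would produce, spelled out explicitly here.
  val weak = new AutoCloseableLock(rwlock.readLock())

  // Roughly what `using(weak.acquire())(_ => flags.get(name))` evaluates:
  // acquire() locks and returns the AutoCloseableLock itself; using(...)
  // then calls close(), i.e. unlock, even if the body throws.
  val value: String = using(weak.acquire())(_ => "flag-value (placeholder)")
  println(value)
}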
@@ -260,16 +265,21 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling {
   }._1

   def pyAddReference(jsonConfig: String): Unit =
-    addReference(ReferenceGenome.fromJSON(jsonConfig))
+    using(weak.acquire())(_ => addReference(ReferenceGenome.fromJSON(jsonConfig)))

   def pyRemoveReference(name: String): Unit =
-    removeReference(name)
+    using(weak.acquire())(_ => removeReference(name))

   def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit =
-    references(name).addLiftover(references(destRGName), LiftOver(tmpFileManager.getFs, chainFile))
+    using(weak.acquire()) { _ =>
+      references(name).addLiftover(
+        references(destRGName),
+        LiftOver(tmpFileManager.getFs, chainFile),
+      )
+    }

   def pyRemoveLiftover(name: String, destRGName: String): Unit =
-    references(name).removeLiftover(destRGName)
+    using(weak.acquire())(_ => references(name).removeLiftover(destRGName))

   def parse_blockmatrix_ir(s: String): BlockMatrixIR =
     withExecuteContext(selfContainedExecution = false) { ctx =>
@@ -300,25 +310,27 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling {
     f: ExecuteContext => T
   )(implicit E: Enclosing
   ): (T, Timings) =
-    ExecutionTimer.time { timer =>
-      ExecuteContext.scoped(
-        tmpdir = tmpdir,
-        localTmpdir = localTmpdir,
-        backend = backend,
-        fs = tmpFileManager.getFs,
-        timer = timer,
-        tempFileManager =
-          if (selfContainedExecution) null
-          else NonOwningTempFileManager(tmpFileManager),
-        theHailClassLoader = hcl,
-        flags = flags,
-        irMetadata = new IrMetadata(),
-        references = references,
-        blockMatrixCache = bmCache,
-        codeCache = codeCache,
-        irCache = persistedIr,
-        coercerCache = coercerCache,
-      )(f)
+    using(mutex.acquire()) { _ =>
+      ExecutionTimer.time { timer =>
+        ExecuteContext.scoped(
+          tmpdir = tmpdir,
+          localTmpdir = localTmpdir,
+          backend = backend,
+          fs = tmpFileManager.getFs,
+          timer = timer,
+          tempFileManager =
+            if (selfContainedExecution) null
+            else NonOwningTempFileManager(tmpFileManager),
+          theHailClassLoader = hcl,
+          flags = flags,
+          irMetadata = new IrMetadata(),
+          references = references,
+          blockMatrixCache = bmCache,
+          codeCache = codeCache,
+          irCache = persistedIr,
+          coercerCache = coercerCache,
+        )(f)
+      }
     }

   private[this] def newFs(cloudfsConfig: CloudStorageFSConfig): FS =
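
Taking mutex here is the heart of the commit: while a query runs inside withExecuteContext, any py* call arriving from Python parks on the weak (read) lock until the write lock is released. A self-contained demonstration of that ordering (timings, names, and output are illustrative, not Hail code):

import java.util.concurrent.locks.ReentrantReadWriteLock

object BlockingDemo extends App {
  val rw = new ReentrantReadWriteLock()

  val query = new Thread(() => {
    rw.writeLock().lock() // like withExecuteContext taking `mutex`
    try Thread.sleep(200) // stand-in for executing a query
    finally rw.writeLock().unlock()
  })
  query.start()
  Thread.sleep(50) // let the query take the write lock first

  rw.readLock().lock() // like pySetFlag taking `weak`: blocks until the query ends
  try println("py* call ran only after the query released the mutex")
  finally rw.readLock().unlock()

  query.join()
}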
@@ -356,7 +368,7 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling {
   }

   override def close(): Unit =
-    synchronized {
+    using(weak.acquire()) { _ =>
       bmCache.close()
       codeCache.clear()
       persistedIr.clear()
5 changes: 5 additions & 0 deletions hail/src/main/scala/is/hail/utils/richUtils/Implicits.scala
@@ -12,8 +12,10 @@ import scala.reflect.ClassTag
 import scala.util.matching.Regex

 import java.io.InputStream
+import java.util.concurrent.locks.Lock

 import breeze.linalg.DenseMatrix
+import org.apache.hadoop.util.AutoCloseableLock
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix
 import org.apache.spark.rdd.RDD
@@ -144,4 +146,7 @@ trait Implicits {

   implicit def valueToRichCodeIterator[T](it: Value[Iterator[T]]): RichCodeIterator[T] =
     new RichCodeIterator[T](it)
+
+  implicit def lockToAutoClosableLock(l: Lock): AutoCloseableLock =
+    new AutoCloseableLock(l)
 }
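
This implicit is what makes weak.acquire() type-check in Py4JBackendApi: java.util.concurrent.locks.Lock has no acquire(), so the compiler inserts the conversion to Hadoop's AutoCloseableLock, whose acquire() locks and returns itself and whose close() unlocks. A small sketch of the conversion firing in isolation, assuming Hadoop's AutoCloseableLock(Lock) constructor behaves as documented:

import java.util.concurrent.locks.{Lock, ReentrantReadWriteLock}

import org.apache.hadoop.util.AutoCloseableLock

import scala.language.implicitConversions

object ImplicitLockSketch extends App {
  implicit def lockToAutoClosableLock(l: Lock): AutoCloseableLock =
    new AutoCloseableLock(l)

  val weak: Lock = new ReentrantReadWriteLock().readLock()

  // `weak` has no acquire(); the implicit converts it to AutoCloseableLock.
  val held: AutoCloseableLock = weak.acquire()
  try println("read lock held")
  finally held.close() // close() releases the underlying read lock
}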
