diff --git a/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala b/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala
index cfbe7112cc14..39beaaa4969d 100644
--- a/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala
+++ b/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala
@@ -32,10 +32,10 @@ import java.io.Closeable
 import java.net.InetSocketAddress
 import java.util
 import java.util.concurrent._
+import java.util.concurrent.locks.ReentrantReadWriteLock

 import com.google.api.client.http.HttpStatusCodes
 import com.sun.net.httpserver.{HttpExchange, HttpServer}
-import javax.annotation.Nullable
 import org.apache.hadoop
 import org.apache.hadoop.conf.Configuration
 import org.apache.spark.sql.DataFrame
@@ -46,6 +46,10 @@ import sourcecode.Enclosing

 final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling {

+  private[this] val rwlock = new ReentrantReadWriteLock()
+  private[this] def weak = rwlock.readLock()
+  private[this] def mutex = rwlock.writeLock()
+
   private[this] val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
   private[this] val hcl = new HailClassLoader(getClass.getClassLoader)
   private[this] val references = mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*)
@@ -79,62 +83,63 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
   }

   def pyFs: FS =
-    tmpFileManager.getFs
+    using(weak.acquire())(_ => tmpFileManager.getFs)

   def pyGetFlag(name: String): String =
-    flags.get(name)
+    using(weak.acquire())(_ => flags.get(name))

   def pySetFlag(name: String, value: String): Unit =
-    flags.set(name, value)
+    using(weak.acquire())(_ => flags.set(name, value))

   def pyAvailableFlags: java.util.ArrayList[String] =
     flags.available

   def pySetRemoteTmp(tmp: String): Unit =
-    tmpdir = tmp
+    using(weak.acquire())(_ => tmpdir = tmp)

   def pySetLocalTmp(tmp: String): Unit =
-    localTmpdir = tmp
-
-  def pySetGcsRequesterPaysConfig(@Nullable project: String, @Nullable buckets: util.List[String])
-    : Unit = {
-    val cloudfsConf = CloudStorageFSConfig.fromFlagsAndEnv(None, flags)
-
-    val rpConfig: Option[RequesterPaysConfig] =
-      (
-        Option(project).filter(_.nonEmpty),
-        Option(buckets).map(_.asScala.toSet.filterNot(_.isBlank)).filter(_.nonEmpty),
-      ) match {
-        case (Some(project), buckets) => Some(RequesterPaysConfig(project, buckets))
-        case (None, Some(_)) => fatal(
-            "A non-empty, non-null requester pays google project is required to configure requester pays buckets."
-          )
-        case (None, None) => None
-      }
-
-    val fs = newFs(
-      cloudfsConf.copy(
-        google = (cloudfsConf.google, rpConfig) match {
-          case (Some(gconf), _) => Some(gconf.copy(requester_pays_config = rpConfig))
-          case (None, Some(_)) => Some(GoogleStorageFSConfig(None, rpConfig))
-          case _ => None
-        }
-      )
-    )
-
-    tmpFileManager.setFs(fs)
-  }
+    using(weak.acquire())(_ => localTmpdir = tmp)
+
+  def pySetGcsRequesterPaysConfig(project: String, buckets: util.List[String]): Unit =
+    using(weak.acquire()) { _ =>
+      val cloudfsConf = CloudStorageFSConfig.fromFlagsAndEnv(None, flags)
+
+      val rpConfig: Option[RequesterPaysConfig] =
+        (
+          Option(project).filter(_.nonEmpty),
+          Option(buckets).map(_.asScala.toSet.filterNot(_.isBlank)).filter(_.nonEmpty),
+        ) match {
+          case (Some(project), buckets) => Some(RequesterPaysConfig(project, buckets))
+          case (None, Some(_)) => fatal(
+              "A non-empty, non-null requester pays google project is required to configure requester pays buckets."
+            )
+          case (None, None) => None
+        }
+
+      val fs = newFs(
+        cloudfsConf.copy(
+          google = (cloudfsConf.google, rpConfig) match {
+            case (Some(gconf), _) => Some(gconf.copy(requester_pays_config = rpConfig))
+            case (None, Some(_)) => Some(GoogleStorageFSConfig(None, rpConfig))
+            case _ => None
+          }
+        )
+      )
+
+      tmpFileManager.setFs(fs)
+    }

   def pyRemoveJavaIR(id: Int): Unit =
-    persistedIr.remove(id)
+    using(weak.acquire())(_ => persistedIr.remove(id))

   def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit =
-    references(name).addSequence(
-      IndexedFastaSequenceFile(tmpFileManager.getFs, fastaFile, indexFile)
-    )
+    using(weak.acquire()) { _ =>
+      val seq = IndexedFastaSequenceFile(tmpFileManager.getFs, fastaFile, indexFile)
+      references(name).addSequence(seq)
+    }

   def pyRemoveSequence(name: String): Unit =
-    references(name).removeSequence()
+    using(weak.acquire())(_ => references(name).removeSequence())

   def pyExportBlockMatrix(
     pathIn: String,
@@ -260,16 +265,21 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
     }._1

   def pyAddReference(jsonConfig: String): Unit =
-    addReference(ReferenceGenome.fromJSON(jsonConfig))
+    using(weak.acquire())(_ => addReference(ReferenceGenome.fromJSON(jsonConfig)))

   def pyRemoveReference(name: String): Unit =
-    removeReference(name)
+    using(weak.acquire())(_ => removeReference(name))

   def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit =
-    references(name).addLiftover(references(destRGName), LiftOver(tmpFileManager.getFs, chainFile))
+    using(weak.acquire()) { _ =>
+      references(name).addLiftover(
+        references(destRGName),
+        LiftOver(tmpFileManager.getFs, chainFile),
+      )
+    }

   def pyRemoveLiftover(name: String, destRGName: String): Unit =
-    references(name).removeLiftover(destRGName)
+    using(weak.acquire())(_ => references(name).removeLiftover(destRGName))

   def parse_blockmatrix_ir(s: String): BlockMatrixIR =
     withExecuteContext(selfContainedExecution = false) { ctx =>
@@ -300,25 +310,27 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
     f: ExecuteContext => T
   )(implicit E: Enclosing
   ): (T, Timings) =
-    ExecutionTimer.time { timer =>
-      ExecuteContext.scoped(
-        tmpdir = tmpdir,
-        localTmpdir = localTmpdir,
-        backend = backend,
-        fs = tmpFileManager.getFs,
-        timer = timer,
-        tempFileManager =
-          if (selfContainedExecution) null
-          else NonOwningTempFileManager(tmpFileManager),
-        theHailClassLoader = hcl,
-        flags = flags,
-        irMetadata = new IrMetadata(),
-        references = references,
-        blockMatrixCache = bmCache,
-        codeCache = codeCache,
-        irCache = persistedIr,
-        coercerCache = coercerCache,
-      )(f)
+    using(mutex.acquire()) { _ =>
+      ExecutionTimer.time { timer =>
+        ExecuteContext.scoped(
+          tmpdir = tmpdir,
+          localTmpdir = localTmpdir,
+          backend = backend,
+          fs = tmpFileManager.getFs,
+          timer = timer,
+          tempFileManager =
+            if (selfContainedExecution) null
+            else NonOwningTempFileManager(tmpFileManager),
+          theHailClassLoader = hcl,
+          flags = flags,
+          irMetadata = new IrMetadata(),
+          references = references,
+          blockMatrixCache = bmCache,
+          codeCache = codeCache,
+          irCache = persistedIr,
+          coercerCache = coercerCache,
+        )(f)
+      }
     }

   private[this] def newFs(cloudfsConfig: CloudStorageFSConfig): FS =
@@ -356,7 +368,7 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
   }

   override def close(): Unit =
-    synchronized {
+    using(weak.acquire()) { _ =>
       bmCache.close()
       codeCache.clear()
       persistedIr.clear()
diff --git a/hail/src/main/scala/is/hail/utils/richUtils/Implicits.scala b/hail/src/main/scala/is/hail/utils/richUtils/Implicits.scala
index a6870f50b569..6ce1116cd56e 100644
--- a/hail/src/main/scala/is/hail/utils/richUtils/Implicits.scala
+++ b/hail/src/main/scala/is/hail/utils/richUtils/Implicits.scala
@@ -12,8 +12,10 @@ import scala.reflect.ClassTag
 import scala.util.matching.Regex

 import java.io.InputStream
+import java.util.concurrent.locks.Lock

 import breeze.linalg.DenseMatrix
+import org.apache.hadoop.util.AutoCloseableLock
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix
 import org.apache.spark.rdd.RDD
@@ -144,4 +146,7 @@ trait Implicits {

   implicit def valueToRichCodeIterator[T](it: Value[Iterator[T]]): RichCodeIterator[T] =
     new RichCodeIterator[T](it)
+
+  implicit def lockToAutoClosableLock(l: Lock): AutoCloseableLock =
+    new AutoCloseableLock(l)
 }
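
For context, a minimal sketch (not part of the patch) of the locking pattern this diff introduces: wrapping a java.util.concurrent Lock in Hadoop's AutoCloseableLock lets `using(lock.acquire())(...)` bracket a critical section so the lock is released even if the body throws. The `using` helper below is a local stand-in for `is.hail.utils.using`, and the object name and example methods are illustrative only; the read lock is taken for cheap session-state mutations while the write lock serializes whole query executions, mirroring `weak`/`mutex` in Py4JBackendApi.

import scala.language.implicitConversions

import java.util.concurrent.locks.{Lock, ReentrantReadWriteLock}

import org.apache.hadoop.util.AutoCloseableLock

object LockingSketch {
  // Local stand-in for is.hail.utils.using: run f, then close the resource.
  def using[R <: AutoCloseable, T](resource: R)(f: R => T): T =
    try f(resource)
    finally resource.close()

  // Same shape as the conversion added to Implicits.scala: AutoCloseableLock.acquire()
  // takes the underlying lock and returns itself, and close() releases it.
  implicit def lockToAutoClosableLock(l: Lock): AutoCloseableLock =
    new AutoCloseableLock(l)

  private val rwlock = new ReentrantReadWriteLock()
  private def weak = rwlock.readLock() // shared: concurrent mutations of session state
  private def mutex = rwlock.writeLock() // exclusive: a whole query execution

  def readSomething(): String =
    using(weak.acquire())(_ => "guarded by the read lock") // released when `using` closes

  def runQuery(): Int =
    using(mutex.acquire())(_ => 42) // released even if the body throws
}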