Changes from all commits

SparkConnectService.scala
@@ -436,6 +436,7 @@ object SparkConnectService extends Logging {
       return
     }

+    sessionManager.initializeBaseSession(sc)
     startGRPCService()
     createListenerAndUI(sc)

Contributor:

I think this is fine for now; at some point we should consider making the initialisation logic of Connect less singleton-heavy, so we can pass the SparkContext in as a constructor parameter.

Contributor (Author):

But the session manager is already a regular class rather than a singleton; what difference does it make for SparkConnectService to be a class too?
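
As an aside, here is a rough sketch of the constructor-injection shape suggested above, reusing the builder call this PR already adds; the changed class signature is hypothetical and not part of this change:

    // Hypothetical: inject the SparkContext at construction time, so the
    // base session exists from the start and initializeBaseSession() is
    // no longer needed.
    class SparkConnectSessionManager(sc: SparkContext) extends Logging {

      // Same builder call the PR uses, just run eagerly in the constructor.
      private val baseSession: SparkSession =
        SparkSession.builder().sparkContext(sc).create()

      private def newIsolatedSession(): SparkSession = baseSession.newSession()
    }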

SparkConnectSessionManager.scala
@@ -27,7 +27,7 @@ import scala.util.control.NonFatal

 import com.google.common.cache.CacheBuilder

-import org.apache.spark.{SparkEnv, SparkSQLException}
+import org.apache.spark.{SparkContext, SparkEnv, SparkSQLException}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys.{INTERVAL, SESSION_HOLD_INFO}
 import org.apache.spark.sql.classic.SparkSession
@@ -39,6 +39,9 @@ import org.apache.spark.util.ThreadUtils
  */
 class SparkConnectSessionManager extends Logging {

+  // Base SparkSession created from the SparkContext, used to create new isolated sessions
+  @volatile private var baseSession: Option[SparkSession] = None
+
   private val sessionStore: ConcurrentMap[SessionKey, SessionHolder] =
     new ConcurrentHashMap[SessionKey, SessionHolder]()

@@ -48,6 +51,16 @@ class SparkConnectSessionManager extends Logging {
       .maximumSize(SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE))
       .build[SessionKey, SessionHolderInfo]()

+  /**
+   * Initialize the base SparkSession from the provided SparkContext.
+   * This should be called once during SparkConnectService startup.
+   */
+  def initializeBaseSession(sc: SparkContext): Unit = {
+    if (baseSession.isEmpty) {
+      baseSession = Some(SparkSession.builder().sparkContext(sc).create())
+    }
+  }
+
   /** Executor for the periodic maintenance */
   private val scheduledExecutor: AtomicReference[ScheduledExecutorService] =
     new AtomicReference[ScheduledExecutorService]()
@@ -332,14 +345,8 @@ class SparkConnectSessionManager extends Logging {
     logDebug("Finished periodic run of SparkConnectSessionManager maintenance.")
   }

-  private def newIsolatedSession(): SparkSession = {
-    val active = SparkSession.active
-    if (active.sparkContext.isStopped) {
-      assert(SparkSession.getDefaultSession.nonEmpty)
-      SparkSession.getDefaultSession.get.newSession()
-    } else {
-      active.newSession()
-    }
-  }
+  private def newIsolatedSession(): SparkSession = synchronized {
+    baseSession.get.newSession()
+  }

   private def validateSessionCreate(key: SessionKey): Unit = {

Contributor:

@garlandz-db, can you figure out why this branch is here? We may have to recreate the session if this is an actual problem...

Contributor (Author):

The original PR is #43701, by Kent Yao. If the SparkContext is stopped, then active.newSession() would throw an exception:

 org.apache.spark.SparkException: com.google.common.util.concurrent.UncheckedExecutionException: java.lang.IllegalStateException: Cannot call methods on a stopped SparkContext.

My guess: the Spark cluster probably isn't useful at that point, but technically you can still call Spark Connect APIs, so the branch creates a valid SparkSession and continues handling the RPC.

However, our fix is tangential to that error: we do not need to use SparkSession.active in this case at all.
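
For illustration, the failure mode described above can be reproduced with a snippet along these lines (not taken from the PR; it assumes a live active session whose context then gets stopped):

    // Derive a new session from the active one after its context stops.
    // Per the exception quoted above, this fails with "Cannot call
    // methods on a stopped SparkContext".
    val active = SparkSession.active
    active.sparkContext.stop()
    val isolated = active.newSession() // throws against the stopped context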

SparkConnectSessionManagerSuite.scala
@@ -23,6 +23,7 @@ import org.scalatest.BeforeAndAfterEach
 import org.scalatest.time.SpanSugar._

 import org.apache.spark.SparkSQLException
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl}
 import org.apache.spark.sql.pipelines.logging.PipelineEvent
 import org.apache.spark.sql.test.SharedSparkSession
@@ -32,6 +33,7 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndAfterEach {
   override def beforeEach(): Unit = {
     super.beforeEach()
     SparkConnectService.sessionManager.invalidateAllSessions()
+    SparkConnectService.sessionManager.initializeBaseSession(spark.sparkContext)
   }

   test("sessionId needs to be an UUID") {
@@ -171,4 +173,51 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndAfterEach {
       sessionHolder.getPipelineExecution(graphId).isEmpty,
       "pipeline execution was not removed")
   }
+
+  test("baseSession allows creating sessions after default session is cleared") {
+    // Create a new session manager to test initialization
+    val sessionManager = new SparkConnectSessionManager()
+
+    // Initialize the base session with the test SparkContext
+    sessionManager.initializeBaseSession(spark.sparkContext)
+
+    // Clear the default and active sessions to simulate the scenario where
+    // SparkSession.active or SparkSession.getDefaultSession would fail
+    SparkSession.clearDefaultSession()
+    SparkSession.clearActiveSession()
+
+    // Create an isolated session - this should still work because we have baseSession
+    val key = SessionKey("user", UUID.randomUUID().toString)
+    val sessionHolder = sessionManager.getOrCreateIsolatedSession(key, None)
+
+    // Verify the session was created successfully
+    assert(sessionHolder != null)
+    assert(sessionHolder.session != null)
+
+    // Clean up
+    sessionManager.closeSession(key)
+  }
+
+  test("initializeBaseSession is idempotent") {
+    // Create a new session manager to test initialization
+    val sessionManager = new SparkConnectSessionManager()
+
+    // Initialize the base session multiple times
+    sessionManager.initializeBaseSession(spark.sparkContext)
+    val key1 = SessionKey("user1", UUID.randomUUID().toString)
+    val sessionHolder1 = sessionManager.getOrCreateIsolatedSession(key1, None)
+    val baseSessionUUID1 = sessionHolder1.session.sessionUUID
+
+    // Initialize again - should not change the base session
+    sessionManager.initializeBaseSession(spark.sparkContext)
+    val key2 = SessionKey("user2", UUID.randomUUID().toString)
+    val sessionHolder2 = sessionManager.getOrCreateIsolatedSession(key2, None)
+
+    // Both sessions should be isolated from each other
+    assert(sessionHolder1.session.sessionUUID != sessionHolder2.session.sessionUUID)
+
+    // Clean up
+    sessionManager.closeSession(key1)
+    sessionManager.closeSession(key2)
+  }
 }
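
Taken together, the intended lifecycle after this change looks roughly like the sketch below, written against the same entry points the suite exercises; it assumes a SharedSparkSession-style spark value and java.util.UUID in scope:

    // Startup: seed the manager once with the server's SparkContext.
    // Repeated calls are no-ops, per the idempotency test above.
    SparkConnectService.sessionManager.initializeBaseSession(spark.sparkContext)

    // Per RPC: each (userId, sessionId) pair gets an isolated session
    // derived from the base session instead of from SparkSession.active.
    val key = SessionKey("userA", UUID.randomUUID().toString)
    val holder = SparkConnectService.sessionManager.getOrCreateIsolatedSession(key, None)
    assert(holder.session ne spark) // fresh session state over the same context

    // Tear down the example session.
    SparkConnectService.sessionManager.closeSession(key)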
}