diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 13ce2d64256b4..4641bc0a11067 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -436,6 +436,7 @@ object SparkConnectService extends Logging {
       return
     }
 
+    sessionManager.initializeBaseSession(sc)
     startGRPCService()
     createListenerAndUI(sc)
 
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
index f28af0379a04c..391f09c884b7b 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
@@ -27,7 +27,7 @@ import scala.util.control.NonFatal
 
 import com.google.common.cache.CacheBuilder
 
-import org.apache.spark.{SparkEnv, SparkSQLException}
+import org.apache.spark.{SparkContext, SparkEnv, SparkSQLException}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys.{INTERVAL, SESSION_HOLD_INFO}
 import org.apache.spark.sql.classic.SparkSession
@@ -39,6 +39,9 @@ import org.apache.spark.util.ThreadUtils
  */
 class SparkConnectSessionManager extends Logging {
 
+  // Base SparkSession created from the SparkContext, used to create new isolated sessions
+  @volatile private var baseSession: Option[SparkSession] = None
+
   private val sessionStore: ConcurrentMap[SessionKey, SessionHolder] =
     new ConcurrentHashMap[SessionKey, SessionHolder]()
 
@@ -48,6 +51,19 @@ class SparkConnectSessionManager extends Logging {
       .maximumSize(SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE))
       .build[SessionKey, SessionHolderInfo]()
 
+  /**
+   * Initialize the base SparkSession from the provided SparkContext. This should be called once
+   * during SparkConnectService startup; calls made after the base session is set are no-ops.
+   *
+   * @param sc
+   *   the SparkContext used to build the base session
+   */
+  def initializeBaseSession(sc: SparkContext): Unit = {
+    if (baseSession.isEmpty) {
+      baseSession = Some(SparkSession.builder().sparkContext(sc).getOrCreate().newSession())
+    }
+  }
+
   /** Executor for the periodic maintenance */
   private val scheduledExecutor: AtomicReference[ScheduledExecutorService] =
     new AtomicReference[ScheduledExecutorService]()
@@ -333,13 +349,18 @@ class SparkConnectSessionManager extends Logging {
   }
 
+  /**
+   * Create a new session by calling newSession() on the base session.
+   *
+   * Fails with an IllegalStateException carrying an explicit message, rather than a bare
+   * NoSuchElementException from Option.get, when initializeBaseSession has not run yet.
+   */
   private def newIsolatedSession(): SparkSession = {
-    val active = SparkSession.active
-    if (active.sparkContext.isStopped) {
-      assert(SparkSession.getDefaultSession.nonEmpty)
-      SparkSession.getDefaultSession.get.newSession()
-    } else {
-      active.newSession()
-    }
+    val base = baseSession.getOrElse {
+      throw new IllegalStateException(
+        "Base SparkSession has not been initialized; " +
+          "initializeBaseSession must be called before creating sessions.")
+    }
+    base.newSession()
   }
 
   private def validateSessionCreate(key: SessionKey): Unit = {
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
index a3d851c1ce7b9..deed5214907ed 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
@@ -23,6 +23,7 @@ import org.scalatest.BeforeAndAfterEach
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.SparkSQLException
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl}
 import org.apache.spark.sql.pipelines.logging.PipelineEvent
 import org.apache.spark.sql.test.SharedSparkSession
@@ -32,6 +33,7 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA
   override def beforeEach(): Unit = {
     super.beforeEach()
     SparkConnectService.sessionManager.invalidateAllSessions()
+    SparkConnectService.sessionManager.initializeBaseSession(spark.sparkContext)
   }
 
   test("sessionId needs to be an UUID") {
@@ -171,4 +173,60 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA
       sessionHolder.getPipelineExecution(graphId).isEmpty,
       "pipeline execution was not removed")
   }
+
+  test("baseSession allows creating sessions after default session is cleared") {
+    // Create a new session manager to test initialization
+    val sessionManager = new SparkConnectSessionManager()
+
+    // Initialize the base session with the test SparkContext
+    sessionManager.initializeBaseSession(spark.sparkContext)
+
+    // Remember the current default and active sessions so they can be restored afterwards;
+    // otherwise clearing them below would leak into the tests that run after this one
+    val savedDefaultSession = SparkSession.getDefaultSession
+    val savedActiveSession = SparkSession.getActiveSession
+    try {
+      // Clear the default and active sessions to simulate the scenario where
+      // SparkSession.active or SparkSession.getDefaultSession would fail
+      SparkSession.clearDefaultSession()
+      SparkSession.clearActiveSession()
+
+      // Create an isolated session - this should still work because we have baseSession
+      val key = SessionKey("user", UUID.randomUUID().toString)
+      val sessionHolder = sessionManager.getOrCreateIsolatedSession(key, None)
+
+      // Verify the session was created successfully
+      assert(sessionHolder != null)
+      assert(sessionHolder.session != null)
+
+      // Clean up
+      sessionManager.closeSession(key)
+    } finally {
+      savedDefaultSession.foreach(s => SparkSession.setDefaultSession(s))
+      savedActiveSession.foreach(s => SparkSession.setActiveSession(s))
+    }
+  }
+
+  test("initializeBaseSession is idempotent") {
+    // Create a new session manager to test initialization
+    val sessionManager = new SparkConnectSessionManager()
+
+    // Initialize the base session and create a first isolated session from it
+    sessionManager.initializeBaseSession(spark.sparkContext)
+    val key1 = SessionKey("user1", UUID.randomUUID().toString)
+    val sessionHolder1 = sessionManager.getOrCreateIsolatedSession(key1, None)
+    val sessionUUID1 = sessionHolder1.session.sessionUUID
+
+    // Initialize again - session creation must keep working afterwards
+    sessionManager.initializeBaseSession(spark.sparkContext)
+    val key2 = SessionKey("user2", UUID.randomUUID().toString)
+    val sessionHolder2 = sessionManager.getOrCreateIsolatedSession(key2, None)
+
+    // Both sessions should be isolated from each other
+    assert(sessionUUID1 != sessionHolder2.session.sessionUUID)
+
+    // Clean up
+    sessionManager.closeSession(key1)
+    sessionManager.closeSession(key2)
+  }
 }