Merged
80 changes: 80 additions & 0 deletions .github/workflows/integration-tests.yaml
@@ -0,0 +1,80 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Integration Tests with Service Account Authentication
#
# Required GitHub Secrets:
# - GCP_SA_KEY: Service account JSON key (project_id and client_email extracted automatically)
# - GCP_REGION: Google Cloud Region (optional, defaults to us-central1)
# - GCP_SUBNET: Dataproc subnet URI
#
# See INTEGRATION_TESTS.md for setup instructions.

name: Integration Tests
on:
  pull_request:
    branches: [ main ]
  workflow_dispatch:

jobs:
  integration-test:
    name: Run integration tests
    runs-on: ubuntu-latest

    # Only run integration tests if secrets are available
    if: github.event_name == 'workflow_dispatch' || (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository)

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.12"

      - name: Cache pip dependencies
        uses: actions/cache@v4
        with:
          path: ~/.cache/pip
          key: ${{ runner.os }}-pip-integration-${{ hashFiles('requirements-dev.txt', 'requirements-test.txt') }}
          restore-keys: |
            ${{ runner.os }}-pip-integration-
            ${{ runner.os }}-pip-

      - name: Install dependencies
        run: |
          pip install -r requirements-dev.txt
          pip install -r requirements-test.txt

      - name: Authenticate to Google Cloud
        uses: google-github-actions/auth@v2
        with:
          credentials_json: ${{ secrets.GCP_SA_KEY }}

      - name: Set up Cloud SDK
        uses: google-github-actions/setup-gcloud@v2

      - name: Run integration tests
        env:
          CI: "true"
          # Extract from service account JSON automatically
          GOOGLE_CLOUD_PROJECT: ${{ fromJson(secrets.GCP_SA_KEY).project_id }}
          DATAPROC_SPARK_CONNECT_SERVICE_ACCOUNT: ${{ fromJson(secrets.GCP_SA_KEY).client_email }}
          # Infrastructure-specific secrets
          GOOGLE_CLOUD_REGION: ${{ secrets.GCP_REGION || 'us-central1' }}
          DATAPROC_SPARK_CONNECT_SUBNET: ${{ secrets.GCP_SUBNET }}
          DATAPROC_SPARK_CONNECT_AUTH_TYPE: "SERVICE_ACCOUNT"
        run: |
          python -m pytest tests/integration/ -v --tb=short -x
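The workflow passes all configuration to the test suite through environment variables. As a rough local sketch (not part of this PR; the variable names are copied from the env block above and the preflight helper itself is hypothetical), a developer could check the same variables before invoking pytest:

import os

# Variables the integration tests read; GOOGLE_CLOUD_REGION falls back to
# us-central1 in the workflow, so it is treated as optional here.
REQUIRED_VARS = [
    "GOOGLE_CLOUD_PROJECT",
    "DATAPROC_SPARK_CONNECT_SUBNET",
    "DATAPROC_SPARK_CONNECT_AUTH_TYPE",
]

missing = [name for name in REQUIRED_VARS if not os.getenv(name)]
if missing:
    raise SystemExit(f"Set these variables before running the tests: {missing}")
print("Environment looks complete; run: python -m pytest tests/integration/ -v")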
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -1,3 +1,9 @@
[tool.pyink]
line-length = 80 # (default is 88)
pyink-indentation = 4 # (default is 4)

[tool.pytest.ini_options]
markers = [
"integration: marks tests as integration tests",
"ci_safe: marks tests that work in CI environment",
]
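The two markers registered above silence pytest's unknown-marker warnings and make the suites selectable. A minimal illustration (the test name is hypothetical, not part of this change) of how a test opts in and how a subset is chosen with -m:

import pytest

@pytest.mark.integration
@pytest.mark.ci_safe
def test_marker_example():
    # Collected by "pytest -m integration" or "pytest -m ci_safe";
    # excluded by "pytest -m 'not integration'".
    assert True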
109 changes: 76 additions & 33 deletions tests/integration/test_session.py
@@ -32,7 +32,6 @@
    TerminateSessionRequest,
)
from pyspark.errors.exceptions import connect as connect_exceptions
from pyspark.sql import functions as F
from pyspark.sql.types import StringType


@@ -49,9 +48,28 @@ def test_project():
    return os.getenv("GOOGLE_CLOUD_PROJECT")


def is_ci_environment():
    """Detect if running in CI environment."""
    return os.getenv("CI") == "true" or os.getenv("GITHUB_ACTIONS") == "true"


@pytest.fixture
def auth_type(request):
    return getattr(request, "param", "SERVICE_ACCOUNT")
    """Auto-detect authentication type based on environment.

    CI environment (CI=true or GITHUB_ACTIONS=true): Uses SERVICE_ACCOUNT
    Local environment: Uses END_USER_CREDENTIALS
    Test parametrization can still override this default.
    """
    # Allow test parametrization to override
    if hasattr(request, "param"):
        return request.param

    # Auto-detect based on environment
    if is_ci_environment():
        return "SERVICE_ACCOUNT"
    else:
        return "END_USER_CREDENTIALS"

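# Illustrative sketch, not part of this diff: because the fixture above honors
# request.param, an individual test can still pin the auth type through
# indirect parametrization (the test name below is hypothetical).
@pytest.mark.parametrize("auth_type", ["END_USER_CREDENTIALS"], indirect=True)
def test_example_forcing_end_user_credentials(auth_type):
    assert auth_type == "END_USER_CREDENTIALS"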

@pytest.fixture
@@ -113,23 +131,29 @@ def session_template_controller_client(test_client_options):

@pytest.fixture
def connect_session(test_project, test_region, os_environment):
    return (
    session = (
        DataprocSparkSession.builder.projectId(test_project)
        .location(test_region)
        .getOrCreate()
    )
    yield session
    # Clean up the session after each test to prevent resource conflicts
    try:
        session.stop()
    except Exception:
        # Ignore cleanup errors to avoid masking the actual test failure
        pass


@pytest.fixture
def session_name(test_project, test_region, connect_session):
return f"projects/{test_project}/locations/{test_region}/sessions/{DataprocSparkSession._active_s8s_session_id}"


@pytest.mark.parametrize("auth_type", ["END_USER_CREDENTIALS"], indirect=True)
def test_create_spark_session_with_default_notebook_behavior(
    auth_type, connect_session, session_name, session_controller_client
):
    """Test creating a Spark session with default notebook behavior using end user credentials."""
    """Test creating a Spark session with default notebook behavior using auto-detected authentication."""
    get_session_request = GetSessionRequest()
    get_session_request.name = session_name
    session = session_controller_client.get_session(get_session_request)
@@ -348,8 +372,14 @@ def test_create_spark_session_with_session_template_and_user_provided_dataproc_c
    assert DataprocSparkSession._active_s8s_session_uuid is None


@pytest.mark.skip(
    reason="Skipping PyPI package installation test since it's not supported yet"
)
def test_add_artifacts_pypi_package():
    """Test adding PyPI packages as artifacts to a Spark session."""
    """Test adding PyPI packages as artifacts to a Spark session.

    Note: Skipped in CI due to infrastructure issues with PyPI package installation.
    """
    connect_session = DataprocSparkSession.builder.getOrCreate()
    from pyspark.sql.connect.functions import udf, sum
    from pyspark.sql.types import IntegerType
@@ -377,35 +407,38 @@ def generate_random2(row) -> int:

def test_sql_functions(connect_session):
"""Test basic SQL functions like col(), sum(), count(), etc."""
# Import SparkConnect-compatible functions
from pyspark.sql.connect.functions import col, sum, count

# Create a test DataFrame
df = connect_session.createDataFrame(
[(1, "Alice", 100), (2, "Bob", 200), (3, "Charlie", 150)],
["id", "name", "amount"],
)

# Test col() function
result_col = df.select(F.col("name")).collect()
result_col = df.select(col("name")).collect()
assert len(result_col) == 3
assert result_col[0]["name"] == "Alice"

# Test aggregation functions
sum_result = df.select(F.sum("amount")).collect()[0][0]
sum_result = df.select(sum("amount")).collect()[0][0]
assert sum_result == 450

count_result = df.select(F.count("id")).collect()[0][0]
count_result = df.select(count("id")).collect()[0][0]
assert count_result == 3

# Test with where clause using col()
filtered_df = df.where(F.col("amount") > 150)
filtered_df = df.where(col("amount") > 150)
filtered_count = filtered_df.count()
assert filtered_count == 1

# Test multiple column operations
df_with_calc = df.select(
F.col("id"),
F.col("name"),
F.col("amount"),
(F.col("amount") * 0.1).alias("tax"),
col("id"),
col("name"),
col("amount"),
(col("amount") * 0.1).alias("tax"),
)
tax_results = df_with_calc.collect()
assert tax_results[0]["tax"] == 10.0
@@ -415,6 +448,9 @@ def test_sql_functions(connect_session):

def test_sql_udf(connect_session):
"""Test SQL UDF registration and usage."""
# Import SparkConnect-compatible functions
from pyspark.sql.connect.functions import col, udf

# Create a test DataFrame
df = connect_session.createDataFrame(
[(1, "hello"), (2, "world"), (3, "spark")], ["id", "text"]
Expand All @@ -428,9 +464,9 @@ def uppercase_func(text):
return text.upper() if text else None

# Test UDF with DataFrame API
uppercase_udf = F.udf(uppercase_func, StringType())
uppercase_udf = udf(uppercase_func, StringType())
df_with_udf = df.select(
"id", "text", uppercase_udf(F.col("text")).alias("upper_text")
"id", "text", uppercase_udf(col("text")).alias("upper_text")
)
df_result = df_with_udf.collect()

@@ -441,7 +477,6 @@ def uppercase_func(text):
    connect_session.sql("DROP VIEW IF EXISTS test_table")


@pytest.mark.parametrize("auth_type", ["END_USER_CREDENTIALS"], indirect=True)
def test_session_reuse_with_custom_id(
    auth_type,
    test_project,
@@ -450,7 +485,8 @@ def test_session_reuse_with_custom_id(
    os_environment,
):
    """Test the real-world session reuse scenario: create → terminate → recreate with same ID."""
    custom_session_id = "ml-pipeline-session"
    # Use a randomized session ID to avoid conflicts between test runs
    custom_session_id = f"ml-pipeline-session-{uuid.uuid4().hex[:8]}"

    # Stop any existing session first to ensure clean state
    if DataprocSparkSession._active_s8s_session_id:
@@ -462,9 +498,12 @@ def test_session_reuse_with_custom_id(
            pass

    # PHASE 1: Create initial session with custom ID
    spark1 = DataprocSparkSession.builder.dataprocSessionId(
        custom_session_id
    ).getOrCreate()
    spark1 = (
        DataprocSparkSession.builder.dataprocSessionId(custom_session_id)
        .projectId(test_project)
        .location(test_region)
        .getOrCreate()
    )

    # Verify session is created with custom ID
    assert DataprocSparkSession._active_s8s_session_id == custom_session_id
@@ -479,9 +518,12 @@ def test_session_reuse_with_custom_id(
    # Clear cache to force session lookup
    DataprocSparkSession._default_session = None

    spark2 = DataprocSparkSession.builder.dataprocSessionId(
        custom_session_id
    ).getOrCreate()
    spark2 = (
        DataprocSparkSession.builder.dataprocSessionId(custom_session_id)
        .projectId(test_project)
        .location(test_region)
        .getOrCreate()
    )

    # Should reuse the same active session
    assert DataprocSparkSession._active_s8s_session_id == custom_session_id
@@ -492,7 +534,7 @@ def test_session_reuse_with_custom_id(
    result2 = df2.count()
    assert result2 == 1

    # PHASE 3: Terminate session explicitly
    # PHASE 3: Stop should not terminate named session
    spark2.stop()

    # PHASE 4: Recreate with same ID - this tests the cleanup and recreation logic
@@ -501,16 +543,19 @@ def test_session_reuse_with_custom_id(
    DataprocSparkSession._active_s8s_session_id = None
    DataprocSparkSession._active_s8s_session_uuid = None

    spark3 = DataprocSparkSession.builder.dataprocSessionId(
        custom_session_id
    ).getOrCreate()
    spark3 = (
        DataprocSparkSession.builder.dataprocSessionId(custom_session_id)
        .projectId(test_project)
        .location(test_region)
        .getOrCreate()
    )

    # Should be a new session with same ID but different UUID
    # Should be the same session with the same ID
    assert DataprocSparkSession._active_s8s_session_id == custom_session_id
    third_session_uuid = spark3._active_s8s_session_uuid

    # Should be different UUID (new session instance)
    assert third_session_uuid != first_session_uuid
    # Should be the same UUID (session was reused, not recreated)
    assert third_session_uuid == first_session_uuid

    # Test functionality on recreated session
    df3 = spark3.createDataFrame([(3, "recreated")], ["id", "stage"])
@@ -543,7 +588,6 @@ def test_session_id_validation_in_integration(
    assert builder._custom_session_id == valid_id


@pytest.mark.parametrize("auth_type", ["END_USER_CREDENTIALS"], indirect=True)
def test_sparksql_magic_library_available(connect_session):
"""Test that sparksql-magic library can be imported and loaded."""
pytest.importorskip(
@@ -577,7 +621,6 @@ def test_sparksql_magic_library_available(connect_session):
    assert data[0]["test_column"] == "integration_test"


@pytest.mark.parametrize("auth_type", ["END_USER_CREDENTIALS"], indirect=True)
def test_sparksql_magic_with_dataproc_session(connect_session):
"""Test that sparksql-magic works with registered DataprocSparkSession."""
pytest.importorskip(