diff --git a/tests/conftest.py b/tests/conftest.py index fcd912a83..41836fd5b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -98,6 +98,21 @@ def is_port_free(port: int) -> bool: return s.connect_ex(("localhost", port)) != 0 +def force_restart_spark_context(): + # Restart SparkContext will make sure that the new environment variables are available inside the JVM + # This is a hacky way to allow debugging in the same process. + from pyspark import SparkContext + + with SparkContext._lock: + # Need to shut down before creating a new SparkConf (Before SparkContext is not enough) + # Like this, the new environment variables are available inside the JVM + if SparkContext._active_spark_context: + SparkContext._active_spark_context.stop() + SparkContext._gateway.shutdown() + SparkContext._gateway = None + SparkContext._jvm = None + + def _setup_local_spark(out: TerminalReporter, verbosity=0): # TODO make a "spark_context" fixture instead of doing this through pytest_configure out.write_line("[conftest.py] Setting up local Spark") diff --git a/tests/deploy/test_batch_job.py b/tests/deploy/test_batch_job.py index 92f1dc412..b773beda7 100644 --- a/tests/deploy/test_batch_job.py +++ b/tests/deploy/test_batch_job.py @@ -1,5 +1,6 @@ import json import logging +import os import re import shutil import subprocess @@ -41,7 +42,7 @@ read_gdal_raster_metadata, ) from openeogeotrellis.testing import gps_config_overrides -from openeogeotrellis.utils import get_jvm, to_s3_url, s3_client +from openeogeotrellis.utils import get_jvm, to_s3_url, s3_client, stream_s3_binary_file_contents _log = logging.getLogger(__name__) @@ -1305,7 +1306,7 @@ def test_run_job_get_projection_extension_metadata_assets_in_s3_multiple_assets( ) -@pytest.mark.skip("Can only run manually") # TODO: Fix so it can run in Jenkins too +# @pytest.mark.skip("Can only run manually") # TODO: Fix so it can run in Jenkins too def test_run_job_to_s3( tmp_path, mock_s3_bucket, @@ -1335,37 +1336,61 @@ def test_run_job_to_s3( "result": True, }, } - json_path = tmp_path / "process_graph.json" - json.dump(process_graph, json_path.open("wt")) - - containing_folder = Path(__file__).parent - cmd = [ - sys.executable, - containing_folder.parent.parent / "openeogeotrellis/deploy/run_graph_locally.py", - json_path, - ] - # Run in separate subprocess so that all environment variables are - # set correctly at the moment the SparkContext is created: - try: - output = subprocess.check_output(cmd, stderr=subprocess.STDOUT, universal_newlines=True) - except subprocess.CalledProcessError as e: - _log.error("run_graph_locally failed. Output: " + e.output) - raise - print(output) + separate_process = True + if separate_process: + json_path = tmp_path / "process_graph.json" + json.dump(process_graph, json_path.open("wt"), indent=2) + containing_folder = Path(__file__).parent + cmd = [ + sys.executable, + containing_folder.parent.parent / "openeogeotrellis/deploy/run_graph_locally.py", + json_path, + ] + # Run in separate subprocess so that all environment variables are + # set correctly at the moment the SparkContext is created: + try: + output = subprocess.check_output(cmd, stderr=subprocess.STDOUT, universal_newlines=True) + except subprocess.CalledProcessError as e: + _log.error("run_graph_locally failed. Output: " + e.output) + raise + print(output) + else: + from openeogeotrellis.configparams import ConfigParams + + if ConfigParams().use_object_storage: + from tests.conftest import force_restart_spark_context + + force_restart_spark_context() + + # Run in the same process, so that we can check the output directly: + from openeogeotrellis.deploy.run_graph_locally import run_graph_locally + + run_graph_locally(process_graph, tmp_path) s3_instance = s3_client() from openeogeotrellis.config import get_backend_config - with open(json_path, "rb") as f: - s3_instance.upload_fileobj( - f, get_backend_config().s3_bucket_name, str((tmp_path / "test.json").relative_to("/")) - ) - - files = {o["Key"] for o in s3_instance.list_objects(Bucket=get_backend_config().s3_bucket_name)["Contents"]} - files = [f[len(str(tmp_path)) :] for f in files] + files_absolute = { + o["Key"] for o in s3_instance.list_objects(Bucket=get_backend_config().s3_bucket_name)["Contents"] + } + files = [f[len(str(tmp_path)) :] for f in files_absolute] assert files == ListSubSet(["collection.json", "openEO_2021-01-05Z.tif", "openEO_2021-01-05Z.tif.json"]) + metadata_file = next(f for f in files_absolute if str(f).__contains__("metadata")) + s3_file_object = s3_instance.get_object( + Bucket=get_backend_config().s3_bucket_name, + Key=str(metadata_file).strip("/"), + ) + streaming_body = s3_file_object["Body"] + with open(tmp_path / "metadata.json", "wb") as f: + f.write(streaming_body.read()) + + metadata = json.load(open(tmp_path / "metadata.json")) + s3_links = [metadata["assets"][a]["href"] for a in metadata["assets"]] + test = stream_s3_binary_file_contents(s3_links[0]) + print(test) + # TODO: Update this test to include statistics or not? Would need to update the json file. @pytest.mark.parametrize(