import os
import time
import random
from packaging import version
from integration_tests.dataproc_test_case import DataprocTestCase

DEFAULT_TIMEOUT = 45  # minutes

class GpuTestCaseBase(DataprocTestCase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def run_dataproc_job(self,
                         cluster_name,
                         job_type,
                         job_params,
                         timeout_in_minutes=DEFAULT_TIMEOUT):
        """Executes a Dataproc job on a cluster and returns the results.

        Args:
            cluster_name: name of the cluster to submit the job to
            job_type: type of the job, e.g. spark, hadoop, pyspark
            job_params: job parameters
            timeout_in_minutes: timeout in minutes

        Returns:
            ret_code: the return code of the job
            stdout: standard output of the job
            stderr: error output of the job
        """

        ret_code, stdout, stderr = DataprocTestCase.run_command(
            'gcloud dataproc jobs submit {} --cluster={} --region={} {}'.
            format(job_type, cluster_name, self.REGION,
                   job_params), timeout_in_minutes)
        return ret_code, stdout, stderr
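
    # Example usage (hypothetical cluster name and job file, for illustration only):
    #
    #   ret_code, stdout, stderr = self.run_dataproc_job(
    #       "my-gpu-cluster", "pyspark", "gs://my-bucket/sample_job.py",
    #       timeout_in_minutes=30)
    #   self.assertEqual(ret_code, 0, "job failed: {}".format(stderr))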

    # Tests for PyTorch
    TORCH_TEST_SCRIPT_FILE_NAME = "verify_pytorch.py"
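    # Minimal sketch of what verify_pytorch.py is assumed to contain (the real
    # script lives under ../gpu/ and is not part of this file); verify_pytorch()
    # below only checks that its stdout contains "True":
    #
    #   import torch
    #   print(torch.cuda.is_available())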

    # Tests for TensorFlow
    TF_TEST_SCRIPT_FILE_NAME = "verify_tensorflow.py"
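    # Assumed sketch of verify_tensorflow.py (also not part of this file);
    # verify_tensorflow() only asserts that the script exits successfully, so the
    # script is assumed to exit non-zero when no GPU is usable:
    #
    #   import tensorflow as tf
    #   assert tf.config.list_physical_devices('GPU'), "no GPU visible to TensorFlow"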

    def assert_instance_command(self,
                                instance,
                                cmd,
                                timeout_in_minutes=DEFAULT_TIMEOUT):
        retry_count = 5
        # Escape embedded double quotes so cmd survives the --command="..." wrapper.
        ssh_cmd = 'gcloud compute ssh -q {} --zone={} --command="{}" -- -o ConnectTimeout=60 -o StrictHostKeyChecking=no'.format(
            instance, self.cluster_zone, cmd.replace('"', '\\"'))

        while retry_count > 0:
            try:
                # Use self.assert_command from DataprocTestCase
                ret_code, stdout, stderr = self.assert_command(
                    ssh_cmd, timeout_in_minutes)
                return ret_code, stdout, stderr
            except Exception as e:
                print(f"An error occurred in assert_instance_command: {e}")
                retry_count -= 1
                if retry_count > 0:
                    print("Retrying in 10 seconds...")
                    time.sleep(10)
                    continue
                else:
                    print("Max retries reached.")
                    raise

    def verify_instance(self, name):
        # Verify that nvidia-smi works
        self.assert_instance_command(name, "nvidia-smi", 1)
        print(f"OK: nvidia-smi on {name}")

    def verify_instance_gpu_agent(self, name):
        print(f"--- Verifying GPU Agent on {name} ---")
        self.assert_instance_command(
            name, "systemctl is-active gpu-utilization-agent.service")
        print(f"OK: GPU Agent on {name}")

    def get_dataproc_image_version(self, instance):
        _, stdout, _ = self.assert_instance_command(
            instance, "grep DATAPROC_IMAGE_VERSION /etc/environment | cut -d= -f2")
        return stdout.strip()

    def version_lt(self, v1, v2):
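        # e.g. version_lt("2.2", "2.3") -> True; version_lt("2.3", "2.3") -> False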
        return version.parse(v1) < version.parse(v2)

    def verify_pytorch(self, name):
        print(f"--- Verifying PyTorch on {name} ---")
        test_filename = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                     "..", "gpu", self.TORCH_TEST_SCRIPT_FILE_NAME)
        self.upload_test_file(test_filename, name)

        image_version = self.get_dataproc_image_version(name)
        # On 2.3+ images conda lives at /opt/conda rather than /opt/conda/miniconda3.
        conda_root_path = "/opt/conda/miniconda3"
        if not self.version_lt(image_version, "2.3"):
            conda_root_path = "/opt/conda"

        conda_env = "dpgce"
        env_path = f"{conda_root_path}/envs/{conda_env}"
        python_bin = f"{env_path}/bin/python3"

        verify_cmd = (
            f"for f in /sys/module/nvidia/drivers/pci:nvidia/*/numa_node; do "
            f"  if [[ -e \"$f\" ]]; then echo 0 > \"$f\"; fi; "
            f"done; "
            f"if /usr/share/google/get_metadata_value attributes/include-pytorch; then"
            f"  {python_bin} {self.TORCH_TEST_SCRIPT_FILE_NAME}; "
            f"else echo 'PyTorch test skipped as include-pytorch is not set'; fi"
        )
        _, stdout, _ = self.assert_instance_command(name, verify_cmd)
        if "PyTorch test skipped" not in stdout:
            self.assertTrue(
                "True" in stdout,
                f"PyTorch CUDA not available or python not found in {env_path}")
        print(f"OK: PyTorch on {name}")
        self.remove_test_script(self.TORCH_TEST_SCRIPT_FILE_NAME, name)

    def verify_tensorflow(self, name):
        print(f"--- Verifying TensorFlow on {name} ---")
        test_filename = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                     "..", "gpu", self.TF_TEST_SCRIPT_FILE_NAME)
        self.upload_test_file(test_filename, name)

        image_version = self.get_dataproc_image_version(name)
        conda_root_path = "/opt/conda/miniconda3"
        if not self.version_lt(image_version, "2.3"):
            conda_root_path = "/opt/conda"

        conda_env = "dpgce"
        env_path = f"{conda_root_path}/envs/{conda_env}"
        python_bin = f"{env_path}/bin/python3"

        verify_cmd = (
            f"for f in $(ls /sys/module/nvidia/drivers/pci:nvidia/*/numa_node) ; do echo 0 > ${{f}} ; done ; "
            f"{python_bin} {self.TF_TEST_SCRIPT_FILE_NAME}"
        )
        self.assert_instance_command(name, verify_cmd)
        print(f"OK: TensorFlow on {name}")
        self.remove_test_script(self.TF_TEST_SCRIPT_FILE_NAME, name)