From 523cf68148239fbd9bc91510e4cf82b4aa397887 Mon Sep 17 00:00:00 2001 From: natthan-pigoux Date: Mon, 4 Nov 2024 11:23:43 +0100 Subject: [PATCH] feat: lock credentials path to avoid concurrent access --- .../src/diracx/client/patches/utils.py | 58 +++-- diracx-core/src/diracx/core/utils.py | 6 +- diracx-core/tests/test_utils.py | 231 +++++++++++++++--- 3 files changed, 229 insertions(+), 66 deletions(-) diff --git a/diracx-client/src/diracx/client/patches/utils.py b/diracx-client/src/diracx/client/patches/utils.py index 5710a6f5..13ad931f 100644 --- a/diracx-client/src/diracx/client/patches/utils.py +++ b/diracx-client/src/diracx/client/patches/utils.py @@ -83,27 +83,37 @@ def get_token( fcntl.flock(f, fcntl.LOCK_UN) return None - # If we are here, it means the token needs to be refreshed - token_response = refresh_token( - token_endpoint, - client_id, - response.refresh_token, - verify=verify, - ) - - # Write the new credentials to the file - f.seek(0) - f.truncate() - f.write(serialize_credentials(token_response)) - f.flush() - os.fsync(f.fileno()) - - # Get an AccessToken instance - return AccessToken( - token=token_response.access_token, - expires_on=datetime.now(tz=timezone.utc) - + timedelta(seconds=token_response.expires_in - EXPIRES_GRACE_SECONDS), - ) + if response.status == TokenStatus.REFRESH and response.refresh_token: + # If we are here, it means the token needs to be refreshed + token_response = refresh_token( + token_endpoint, + client_id, + response.refresh_token, + verify=verify, + ) + + # Write the new credentials to the file + f.seek(0) + f.truncate() + f.write(serialize_credentials(token_response)) + f.flush() + os.fsync(f.fileno()) + + # Get an AccessToken instance + return AccessToken( + token=token_response.access_token, + expires_on=int( + ( + datetime.now(tz=timezone.utc) + + timedelta( + seconds=token_response.expires_in + - EXPIRES_GRACE_SECONDS + ) + ).timestamp() + ), + ) + else: + return None finally: # Release the lock fcntl.flock(f, 
fcntl.LOCK_UN) @@ -166,7 +176,7 @@ def extract_token_from_credentials( return TokenResult(TokenStatus.VALID, access_token=token) if is_refresh_token_valid(refresh_token): - return TokenResult(TokenStatus.REFRESH, refresh_token=credentials.refresh_token) + return TokenResult(TokenStatus.REFRESH, refresh_token=refresh_token) # If we are here, it means the refresh token is not valid anymore return TokenResult(TokenStatus.INVALID) @@ -243,7 +253,9 @@ def on_request(self, request: PipelineRequest) -> None: :type request: ~azure.core.pipeline.PipelineRequest :raises: :class:`~azure.core.exceptions.ServiceRequestError` """ - self._token = self._credential.get_token("", token=self._token) + self._token: AccessToken | None = self._credential.get_token( + "", token=self._token + ) if not self._token: # If we are here, it means the token is not available # we suppose it is not needed to perform the request diff --git a/diracx-core/src/diracx/core/utils.py b/diracx-core/src/diracx/core/utils.py index 95c1319a..8320dcc7 100644 --- a/diracx-core/src/diracx/core/utils.py +++ b/diracx-core/src/diracx/core/utils.py @@ -46,14 +46,14 @@ def read_credentials(location: Path) -> TokenResponse: try: with open(credentials_path, "r") as f: # Lock the file to prevent other processes from writing to it at the same time - fcntl.flock(f, fcntl.LOCK_SH | fcntl.LOCK_NB) + fcntl.flock(f, fcntl.LOCK_SH) # Read the credentials from the file try: credentials = json.load(f) finally: # Release the lock fcntl.flock(f, fcntl.LOCK_UN) - except (BlockingIOError, FileNotFoundError, json.JSONDecodeError) as e: + except (FileNotFoundError, json.JSONDecodeError) as e: raise RuntimeError(f"Error reading credentials: {e}") from e return TokenResponse( @@ -74,7 +74,7 @@ def write_credentials(token_response: TokenResponse, *, location: Path | None = with open(credentials_path, "w") as f: # Lock the file to prevent other processes from writing to it at the same time - fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB) 
+ fcntl.flock(f, fcntl.LOCK_EX) try: # Write the credentials to the file f.write(serialize_credentials(token_response)) diff --git a/diracx-core/tests/test_utils.py b/diracx-core/tests/test_utils.py index 0da0cc3d..879958f2 100644 --- a/diracx-core/tests/test_utils.py +++ b/diracx-core/tests/test_utils.py @@ -1,14 +1,18 @@ from __future__ import annotations import fcntl +import json import time from datetime import datetime, timedelta, timezone from multiprocessing import Pool from pathlib import Path from tempfile import NamedTemporaryFile +from unittest.mock import patch import pytest +from azure.core.credentials import AccessToken +from diracx.client.patches.utils import get_token from diracx.core.models import TokenResponse from diracx.core.utils import ( dotenv_files_from_environment, @@ -47,13 +51,13 @@ def test_dotenv_files_from_environment(monkeypatch): "token_type": "Bearer", "refresh_token": "test_refresh", } -CREDENTIALS_CONTENT = serialize_credentials(TokenResponse(**TOKEN_RESPONSE_DICT)) +CREDENTIALS_CONTENT: str = serialize_credentials(TokenResponse(**TOKEN_RESPONSE_DICT)) def lock_and_read_file(file_path): """Lock and read file.""" with open(file_path, "r") as f: - fcntl.flock(f, fcntl.LOCK_SH | fcntl.LOCK_NB) + fcntl.flock(f, fcntl.LOCK_SH) f.read() time.sleep(2) fcntl.flock(f, fcntl.LOCK_UN) @@ -62,18 +66,20 @@ def lock_and_read_file(file_path): def lock_and_write_file(file_path: Path): """Lock and write file.""" with open(file_path, "a") as f: - fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB) + fcntl.flock(f, fcntl.LOCK_EX) f.write(CREDENTIALS_CONTENT) time.sleep(2) fcntl.flock(f, fcntl.LOCK_UN) @pytest.fixture -def token_setup() -> tuple[TokenResponse, Path]: +def token_setup() -> tuple[TokenResponse, Path, AccessToken]: """Setup token response and location.""" token_location = Path(NamedTemporaryFile().name) token_response = TokenResponse(**TOKEN_RESPONSE_DICT) - return token_response, token_location + access_token = 
AccessToken(token_response.access_token, token_response.expires_in) + + return token_response, token_location, access_token @pytest.fixture @@ -104,7 +110,7 @@ def run_processes(proc_to_test, *, read=True): ), ) time.sleep(1) - pool.apply_async( + result = pool.apply_async( proc_to_test[0], kwds=proc_to_test[1], error_callback=lambda e: error_callback( @@ -113,7 +119,8 @@ def run_processes(proc_to_test, *, read=True): ) pool.close() pool.join() - return error_dict + res = result.get(timeout=1) + return res, error_dict return run_processes @@ -127,49 +134,37 @@ def assert_read_credentials_error_message(exc_info): assert "Error reading credentials:" in exc_info.value.args[0] +def create_temp_file(content=None) -> Path: + """Helper function to create a temporary file with optional content.""" + temp_file = NamedTemporaryFile(delete=False) + temp_path = Path(temp_file.name) + temp_file.close() + if content is not None: + temp_path.write_text(content) + return temp_path + + def test_read_credentials_reading_locked_file( token_setup, concurrent_access_to_lock_file ): - """Test that read_credentials reading a locked file end in error.""" - _, token_location = token_setup + """Test that read_credentials waits for a locked file instead of raising an error.""" + _, token_location, _ = token_setup process_to_test = (read_credentials, {"location": token_location}) - error_dict = concurrent_access_to_lock_file(process_to_test, read=False) - process_name = process_to_test[0].__name__ - if process_name in error_dict.keys(): - assert isinstance(error_dict[process_name], RuntimeError) - else: - raise AssertionError( - "Expected a RuntimeError while reading locked credentials." 
- ) + _, error_dict = concurrent_access_to_lock_file(process_to_test, read=False) + assert not error_dict def test_write_credentials_writing_locked_file( token_setup, concurrent_access_to_lock_file ): - """Test that write_credentials writing a locked file end in error.""" - token_response, token_location = token_setup + """Test that write_credentials waits for a locked file instead of raising an error.""" + token_response, token_location, _ = token_setup process_to_test = ( write_credentials, {"token_response": token_response, "location": token_location}, ) - error_dict = concurrent_access_to_lock_file(process_to_test) - process_name = process_to_test[0].__name__ - if process_name in error_dict.keys(): - assert isinstance(error_dict[process_name], BlockingIOError) - else: - raise AssertionError( - "Expected a BlockingIOError while writing locked credentials." - ) - - -def create_temp_file(content=None) -> Path: - """Helper function to create a temporary file with optional content.""" - temp_file = NamedTemporaryFile(delete=False) - temp_path = Path(temp_file.name) - temp_file.close() - if content is not None: - temp_path.write_text(content) - return temp_path + _, error_dict = concurrent_access_to_lock_file(process_to_test) + assert not error_dict def test_read_credentials_empty_file(): @@ -186,7 +181,7 @@ def test_read_credentials_empty_file(): def test_write_credentials_empty_file(token_setup): """Test that write_credentials raises an appropriate error for an empty file.""" temp_file = create_temp_file("") - token_response, _ = token_setup + token_response, _, _ = token_setup write_credentials(token_response, location=temp_file) temp_file.unlink() @@ -202,7 +197,7 @@ def test_read_credentials_missing_file(): def test_write_credentials_unavailable_path(token_setup): """Test that write_credentials raises error when it can't create path.""" wrong_path = Path("/wrong/path/file.txt") - token_response, _ = token_setup + token_response, _, _ = token_setup with 
pytest.raises(PermissionError): write_credentials(token_response, location=wrong_path) @@ -220,7 +215,7 @@ def test_read_credentials_invalid_content(): def test_read_credentials_valid_file(token_setup): """Test that read_credentials works correctly with a valid file.""" - token_response, _ = token_setup + token_response, _, _ = token_setup temp_file = create_temp_file(content=CREDENTIALS_CONTENT) credentials = read_credentials(location=temp_file) @@ -229,3 +224,159 @@ def test_read_credentials_valid_file(token_setup): assert credentials.expires_in < token_response.expires_in assert credentials.token_type == token_response.token_type assert credentials.refresh_token == token_response.refresh_token + + +# Testing get_token: + + +def test_get_token_accessing_lock_file(token_setup, concurrent_access_to_lock_file): + """Test get_token is waiting to read token from locked file.""" + token_response, token_location, _ = token_setup + process_to_test = ( + get_token, + { + "location": token_location, + "token": None, + "token_endpoint": "/endpoint", + "client_id": "ID", + "verify": False, + }, + ) + result, error_dict = concurrent_access_to_lock_file(process_to_test, read=False) + assert not error_dict + assert isinstance(result, AccessToken) + assert result.token == token_response.access_token + + +def test_get_token_valid_input_token(token_setup): + """Test that get_token return the valid token.""" + _, token_location, access_token = token_setup + result = get_token( + location=token_location, + token=access_token, + token_endpoint="", + client_id="ID", + verify=False, + ) + assert result == access_token + + +def test_get_token_valid_input_credential(): + """Test that get_token return the valid token given in the credential file.""" + temp_file = create_temp_file(content=CREDENTIALS_CONTENT) + result = get_token( + location=temp_file, token=None, token_endpoint="", client_id="ID", verify=False + ) + temp_file.unlink() + assert isinstance(result, AccessToken) + + +def 
test_get_token_input_token_not_exists(token_setup): + _, token_location, access_token = token_setup + result = get_token( + location=token_location, + token=access_token, + token_endpoint="", + client_id="ID", + verify=False, + ) + assert result is None + + +def test_get_token_invalid_input(token_setup): + """Test that get_token manage invalid input token.""" + # Test wrong key in credential + token_response, _, _ = token_setup + wrong_credential_content = "'{\"wrong_key\": False}'" + temp_file = create_temp_file(content=wrong_credential_content) + result = get_token( + location=temp_file, token=None, token_endpoint="", client_id="ID", verify=False + ) + temp_file.unlink() + assert result is None + + # Test with invalid token date + token_response = TOKEN_RESPONSE_DICT.copy() + token_response["expires_in"] = int(datetime.now(tz=timezone.utc).timestamp()) + credential_content = json.dumps(token_response) + temp_file = create_temp_file(content=credential_content) + result = get_token( + location=temp_file, token=None, token_endpoint="", client_id="ID", verify=False + ) + temp_file.unlink() + assert result is None + + +def test_get_token_refresh_valid(): + """Test that get_token refresh a valid outdated token.""" + token_response = TOKEN_RESPONSE_DICT.copy() + # the future content of the refreshed token + refresh_token = TokenResponse(**token_response) + # Create expired credential file + token_response["expires_on"] = int( + (datetime.now(tz=timezone.utc) - timedelta(seconds=10)).timestamp() + ) + token_response.pop("expires_in") + credentials = json.dumps(token_response) + temp_file = create_temp_file(content=credentials) + + with ( + patch( + "diracx.client.patches.utils.is_refresh_token_valid", return_value=True + ) as mock_is_refresh_valid, + patch( + "diracx.client.patches.utils.refresh_token", return_value=refresh_token + ) as mock_refresh_token, + ): + result = get_token( + location=temp_file, + token=None, + token_endpoint="", + client_id="ID", + 
verify=False, + ) + + # Verify that the credential file has been refreshed: + with open(temp_file, "r") as f: + content = f.read() + assert content == serialize_credentials(refresh_token) + + temp_file.unlink() + + assert result is not None + assert isinstance(result, AccessToken) + assert result.token == refresh_token.access_token + assert result.expires_on > refresh_token.expires_in + mock_is_refresh_valid.assert_called_once_with(refresh_token.refresh_token) + mock_refresh_token.assert_called_once_with( + "", "ID", refresh_token.refresh_token, verify=False + ) + + +def test_get_token_refresh_invalid(): + """Test that get_token manages an invalid refresh token.""" + token_response = TOKEN_RESPONSE_DICT.copy() + refresh_token = TokenResponse(**token_response) + token_response["expires_on"] = int( + (datetime.now(tz=timezone.utc) - timedelta(seconds=10)).timestamp() + ) + token_response.pop("expires_in") + credentials = json.dumps(token_response) + temp_file = create_temp_file(content=credentials) + + with ( + patch( + "diracx.client.patches.utils.is_refresh_token_valid", return_value=False + ) as mock_is_refresh_valid, + ): + result = get_token( + location=temp_file, + token=None, + token_endpoint="", + client_id="ID", + verify=False, + ) + + temp_file.unlink() + assert result is None + mock_is_refresh_valid.assert_called_once_with(refresh_token.refresh_token)