From c520ba3a6a3c734ba0f9543492bb9922c82106c1 Mon Sep 17 00:00:00 2001 From: Ihor Liubymov Date: Tue, 24 Sep 2024 14:07:16 +0300 Subject: [PATCH] fix: add lock to CachingCryptoMaterialsManager --- .../materials_managers/caching.py | 82 ++++++++++--------- test/unit/test_material_managers_caching.py | 34 ++++++++ 2 files changed, 77 insertions(+), 39 deletions(-) diff --git a/src/aws_encryption_sdk/materials_managers/caching.py b/src/aws_encryption_sdk/materials_managers/caching.py index b1a0ecab5..81822c6f5 100644 --- a/src/aws_encryption_sdk/materials_managers/caching.py +++ b/src/aws_encryption_sdk/materials_managers/caching.py @@ -3,6 +3,7 @@ """Caching crypto material manager.""" import logging import uuid +from threading import RLock import attr import six @@ -109,6 +110,8 @@ def __attrs_post_init__(self): if self.partition_name is None: self.partition_name = to_bytes(str(uuid.uuid4())) + self._cache_lock = RLock() + def _cache_entry_has_encrypted_too_many_bytes(self, entry): """Determines if a cache entry has exceeded the max allowed bytes encrypted. @@ -188,32 +191,33 @@ def get_encryption_materials(self, request): ) cache_key = build_encryption_materials_cache_key(partition=self.partition_name, request=inner_request) - # Attempt to retrieve from cache - try: - cache_entry = self.cache.get_encryption_materials( - cache_key=cache_key, plaintext_length=request.plaintext_length - ) - except CacheKeyError: - pass - else: - if self._cache_entry_has_exceeded_limits(cache_entry): - self.cache.remove(cache_entry) + with self._cache_lock: + # Attempt to retrieve from cache + try: + cache_entry = self.cache.get_encryption_materials( + cache_key=cache_key, plaintext_length=request.plaintext_length + ) + except CacheKeyError: + pass else: - return cache_entry.value - - # Nothing found in cache: try the material manager - new_result = self.backing_materials_manager.get_encryption_materials(inner_request) - - if not new_result.algorithm.safe_to_cache() or request.plaintext_length >= self.max_bytes_encrypted: - return new_result - - # Add results into cache - self.cache.put_encryption_materials( - cache_key=cache_key, - encryption_materials=new_result, - plaintext_length=request.plaintext_length, - entry_hints=CryptoMaterialsCacheEntryHints(lifetime=self.max_age), - ) + if self._cache_entry_has_exceeded_limits(cache_entry): + self.cache.remove(cache_entry) + else: + return cache_entry.value + + # Nothing found in cache: try the material manager + new_result = self.backing_materials_manager.get_encryption_materials(inner_request) + + if not new_result.algorithm.safe_to_cache() or request.plaintext_length >= self.max_bytes_encrypted: + return new_result + + # Add results into cache + self.cache.put_encryption_materials( + cache_key=cache_key, + encryption_materials=new_result, + plaintext_length=request.plaintext_length, + entry_hints=CryptoMaterialsCacheEntryHints(lifetime=self.max_age), + ) return new_result def decrypt_materials(self, request): @@ -225,21 +229,21 @@ def decrypt_materials(self, request): :rtype: aws_encryption_sdk.materials_managers.DecryptionMaterials """ cache_key = build_decryption_materials_cache_key(partition=self.partition_name, request=request) - - # Attempt to retrieve from cache - try: - cache_entry = self.cache.get_decryption_materials(cache_key) - except CacheKeyError: - pass - else: - if self._cache_entry_is_too_old(cache_entry): - self.cache.remove(cache_entry) + with self._cache_lock: + # Attempt to retrieve from cache + try: + cache_entry = self.cache.get_decryption_materials(cache_key) + except CacheKeyError: + pass else: - return cache_entry.value + if self._cache_entry_is_too_old(cache_entry): + self.cache.remove(cache_entry) + else: + return cache_entry.value - # Nothing found in cache: try the material manager - new_result = self.backing_materials_manager.decrypt_materials(request) + # Nothing found in cache: try the material manager + new_result = self.backing_materials_manager.decrypt_materials(request) - # Add results into cache - self.cache.put_decryption_materials(cache_key=cache_key, decryption_materials=new_result) + # Add results into cache + self.cache.put_decryption_materials(cache_key=cache_key, decryption_materials=new_result) return new_result diff --git a/test/unit/test_material_managers_caching.py b/test/unit/test_material_managers_caching.py index 7f186becc..62c768cfc 100644 --- a/test/unit/test_material_managers_caching.py +++ b/test/unit/test_material_managers_caching.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 """Unit test suite for CachingCryptoMaterialsManager""" +import concurrent.futures + import pytest from mock import MagicMock, sentinel from pytest_mock import mocker # noqa pylint: disable=unused-import @@ -371,6 +373,26 @@ def test_get_encryption_materials_cache_miss_algorithm_not_safe_to_cache( assert test is ccmm.backing_materials_manager.get_encryption_materials.return_value +def test_get_encryption_materials_cache_thread_safe( + patch_encryption_materials_request, + patch_should_cache_encryption_request, + patch_cache_entry_has_exceeded_limits, + patch_build_encryption_materials_cache_key, +): + patch_cache_entry_has_exceeded_limits.return_value = False + mock_request = fake_encryption_request() + mock_request.plaintext_length = 10 + ccmm = build_ccmm() + ccmm.cache.get_encryption_materials.side_effect = [CacheKeyError, MagicMock(), MagicMock()] + ccmm.backing_materials_manager.get_encryption_materials.return_value.algorithm.safe_to_cache.return_value = True + + arguments = [mock_request, mock_request, mock_request] + with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: + results = [item for item in executor.map(ccmm.get_encryption_materials, arguments)] + + assert ccmm.backing_materials_manager.get_encryption_materials.call_count == 1 + + @pytest.fixture def patch_build_decryption_materials_cache_key(mocker): mocker.patch.object(aws_encryption_sdk.materials_managers.caching, "build_decryption_materials_cache_key") @@ -428,3 +450,15 @@ def test_decrypt_materials_cache_miss(patch_build_decryption_materials_cache_key assert not patch_cache_entry_is_too_old.called assert not ccmm.cache.remove.called assert test is ccmm.backing_materials_manager.decrypt_materials.return_value + + +def test_decrypt_materials_cache_thread_safe(patch_build_decryption_materials_cache_key, patch_cache_entry_is_too_old): + patch_cache_entry_is_too_old.return_value = False + ccmm = build_ccmm() + ccmm.cache.get_decryption_materials.side_effect = [CacheKeyError, MagicMock(), MagicMock()] + + arguments = [sentinel.request, sentinel.request, sentinel.request] + with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: + results = [item for item in executor.map(ccmm.decrypt_materials, arguments)] + + assert ccmm.backing_materials_manager.decrypt_materials.call_count == 1