diff --git a/pydiator_core/pipelines/cache_pipeline.py b/pydiator_core/pipelines/cache_pipeline.py index 9801705..fb4cf00 100644 --- a/pydiator_core/pipelines/cache_pipeline.py +++ b/pydiator_core/pipelines/cache_pipeline.py @@ -7,33 +7,35 @@ def __init__(self, cache_provider: BaseCacheProvider) -> None: self.cache_provider = cache_provider async def handle(self, req: BaseRequest) -> object: - if self.cache_provider is None: - return await self.next().handle(req) - if self.next() is None: raise Exception("pydiator_cache_pipeline_has_no_next_pipeline") - response = None + next_handle = getattr(self.next(), "handle", None) + if next_handle is None or not callable(next_handle): + raise Exception("handle_function_of_next_pipeline_is_not_valid_for_cache_pipeline") + + if self.cache_provider is None: + return await next_handle(req) if isinstance(req, BaseCacheable): - if req.is_no_cache() is False: + cache_type = req.get_cache_type() + if cache_type != CacheType.NONE and req.is_no_cache() is False: cache_key = req.get_cache_key() if cache_key is not None and cache_key != "": - cache_type = req.get_cache_type() if cache_type == CacheType.DISTRIBUTED: cached_obj = self.__get_from_cache(cache_key) if cached_obj is not None: return cached_obj - cache_duration = req.get_cache_duration() - response = await self.next().handle(req) - if response is not None and response != "" and cache_duration > 0: - self.__add_to_cache(response, cache_key, cache_duration) + response = await next_handle(req) + + cache_duration = req.get_cache_duration() + if response is not None and response != "" and cache_duration > 0: + self.__add_to_cache(response, cache_key, cache_duration) - if response is None: - response = await self.next().handle(req) + return response - return response + return await next_handle(req) def __get_from_cache(self, cache_key) -> object: cached_obj_str = self.cache_provider.get(cache_key) diff --git a/setup.py b/setup.py index 51fb957..23cb83b 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="pydiator-core", - version="1.0.6", + version="1.0.7", author="Özgür Kara", author_email="ozgurkara85@gmail.com", description="Pydiator", diff --git a/tests/pipelines/test_cache_pipeline.py b/tests/pipelines/test_cache_pipeline.py index 7c948ed..57f777f 100644 --- a/tests/pipelines/test_cache_pipeline.py +++ b/tests/pipelines/test_cache_pipeline.py @@ -14,6 +14,17 @@ def setUp(self): def tearDown(self): pass + def test_handle_return_exception_when_next_is_none(self): + # Given + cache_pipeline = CachePipeline(FakeCacheProvider()) + + # When + with self.assertRaises(Exception) as context: + self.async_loop(cache_pipeline.handle(TestRequest())) + + # Then + assert context.exception.args[0] == 'pydiator_cache_pipeline_has_no_next_pipeline' + def test_handle_when_cache_provider_is_none(self): # Given next_response = TestResponse(success=True) @@ -34,16 +45,35 @@ async def next_handle(req): assert response is not None assert response == next_response - def test_handle_return_exception_when_next_is_none(self): + def test_handle_when_next_handle_is_none(self): # Given - cache_pipeline = CachePipeline(FakeCacheProvider()) + mock_test_pipeline = MagicMock() + mock_test_pipeline.handle = None + + cache_pipeline = CachePipeline(None) + cache_pipeline.set_next(mock_test_pipeline) # When with self.assertRaises(Exception) as context: self.async_loop(cache_pipeline.handle(TestRequest())) # Then - assert context.exception.args[0] == 'pydiator_cache_pipeline_has_no_next_pipeline' + assert context.exception.args[0] == 'handle_function_of_next_pipeline_is_not_valid_for_cache_pipeline' + + def test_handle_when_next_handle_is_not_callable(self): + # Given + mock_test_pipeline = MagicMock() + mock_test_pipeline.handle = 1 + + cache_pipeline = CachePipeline(None) + cache_pipeline.set_next(mock_test_pipeline) + + # When + with self.assertRaises(Exception) as context: + self.async_loop(cache_pipeline.handle(TestRequest())) + + # Then + assert context.exception.args[0] == 'handle_function_of_next_pipeline_is_not_valid_for_cache_pipeline' def test_handle_when_req_is_no_cache(self): # Given @@ -67,6 +97,27 @@ async def next_handle(req): assert response is not None assert response == next_response + def test_handle_when_req_cache_type_is_none(self): + # Given + next_response = TestResponse(success=True) + + async def next_handle(req): + return next_response + + cache_pipeline = CachePipeline(FakeCacheProvider()) + mock_test_pipeline = MagicMock() + mock_test_pipeline.handle = next_handle + cache_pipeline.set_next(mock_test_pipeline) + + test_request = TestRequestWithCacheable("cache_key", 1, CacheType.NONE) + + # When + response = self.async_loop(cache_pipeline.handle(test_request)) + + # Then + assert response is not None + assert response == next_response + def test_handle_when_req_cache_key_is_none(self): # Given next_response = TestResponse(success=True)