diff --git a/ocpp/charge_point.py b/ocpp/charge_point.py index 7314804b1..08a33e6e3 100644 --- a/ocpp/charge_point.py +++ b/ocpp/charge_point.py @@ -356,10 +356,25 @@ async def _handle_call(self, msg): call_unique_id_required = "call_unique_id" in handler_signature.parameters # call_unique_id should be passed as kwarg only if is defined explicitly # in the handler signature + inject_response = getattr(handler, "_inject_response", False) if call_unique_id_required: - response = handler(**snake_case_payload, call_unique_id=msg.unique_id) + if inject_response: + response = handler( + **snake_case_payload, + call_unique_id=msg.unique_id, + on_response=response_payload, + ) + else: + response = handler( + **snake_case_payload, call_unique_id=msg.unique_id + ) else: - response = handler(**snake_case_payload) + if inject_response: + response = handler( + **snake_case_payload, on_response=response_payload + ) + else: + response = handler(**snake_case_payload) # Create task to avoid blocking when making a call inside the # after handler if inspect.isawaitable(response): diff --git a/ocpp/routing.py b/ocpp/routing.py index dbdc641cf..8c9406fb1 100644 --- a/ocpp/routing.py +++ b/ocpp/routing.py @@ -56,7 +56,7 @@ def inner(*args, **kwargs): return decorator -def after(action): +def after(action, inject_response=False): """Function decorator to mark function as hook to post-request hook. This hook's arguments are the data that is in the payload for the specific @@ -76,6 +76,7 @@ def inner(*args, **kwargs): return func(*args, **kwargs) inner._after_action = action + inner._inject_response = inject_response if func.__name__ not in routables: routables.append(func.__name__) return inner diff --git a/tests/test_charge_point.py b/tests/test_charge_point.py index 00571f9fd..cedc3ddb6 100644 --- a/tests/test_charge_point.py +++ b/tests/test_charge_point.py @@ -472,3 +472,32 @@ def after_boot_notification(self, *args, **kwargs): assert ChargerA.after_boot_notification_call_count == 1 assert ChargerB.on_boot_notification_call_count == 1 assert ChargerB.after_boot_notification_call_count == 1 + + +@pytest.mark.asyncio +async def test_response_injected_to_after_handler(connection): + """ + This test ensures that the response is injected to the `after` handler + when `inject_response` is set to True. + """ + + class TestChargePoint(cp_16): + @on(Action.BootNotification) + def on_boot_notification(self, **kwargs): + return BootNotificationResult( + current_time="2024-11-01T00:00:00Z", + interval=300, + status=RegistrationStatus.accepted, + ) + + @after(Action.BootNotification, inject_response=True) + def after_boot_notification(self, on_response, **kwargs): + + assert on_response["current_time"] == "2024-11-01T00:00:00Z" + assert on_response["interval"] == 300 + assert on_response["status"] == RegistrationStatus.accepted + + charge_point = TestChargePoint("test_cp", connection) + payload = {"chargePointVendor": "vendor", "chargePointModel": "model"} + msg = Call(unique_id="1234", action=Action.BootNotification.value, payload=payload) + await charge_point._handle_call(msg)