Skip to content

Commit

Permalink
Disable threads if messages.ASYNC_VALIDATION = False
Browse files Browse the repository at this point in the history
  • Loading branch information
Peter Astrand committed Nov 18, 2024
1 parent 71c5adb commit 05c6594
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 28 deletions.
17 changes: 5 additions & 12 deletions ocpp/charge_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,8 @@ async def _handle_call(self, msg):
return

if not handlers.get("_skip_schema_validation", False):
await asyncio.get_event_loop().run_in_executor(
None, validate_payload, msg, self._ocpp_version
)
await validate_payload(msg, self._ocpp_version)

# OCPP uses camelCase for the keys in the payload. It's more pythonic
# to use snake_case for keyword arguments. Therefore the keys must be
# 'translated'. Some examples:
Expand Down Expand Up @@ -344,9 +343,7 @@ async def _handle_call(self, msg):
response = msg.create_call_result(camel_case_payload)

if not handlers.get("_skip_schema_validation", False):
await asyncio.get_event_loop().run_in_executor(
None, validate_payload, response, self._ocpp_version
)
await validate_payload(response, self._ocpp_version)

await self._send(response.to_json())

Expand Down Expand Up @@ -415,9 +412,7 @@ async def call(
)

if not skip_schema_validation:
await asyncio.get_event_loop().run_in_executor(
None, validate_payload, call, self._ocpp_version
)
await validate_payload(call, self._ocpp_version)

# Use a lock to prevent make sure that only 1 message can be send at a
# a time.
Expand All @@ -440,9 +435,7 @@ async def call(
raise response.to_exception()
elif not skip_schema_validation:
response.action = call.action
await asyncio.get_event_loop().run_in_executor(
None, validate_payload, response, self._ocpp_version
)
await validate_payload(response, self._ocpp_version)

snake_case_payload = camel_to_snake_case(response.payload)
# Create the correct Payload instance based on the received payload. If
Expand Down
15 changes: 14 additions & 1 deletion ocpp/messages.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
""" Module containing classes that model the several OCPP messages types. It
also contain some helper functions for packing and unpacking messages. """

from __future__ import annotations

import asyncio
import decimal
import json
import os
Expand All @@ -24,6 +26,8 @@

_validators: Dict[str, Draft4Validator] = {}

ASYNC_VALIDATION = True


class _DecimalEncoder(json.JSONEncoder):
"""Encode values of type `decimal.Decimal` using 1 decimal point.
Expand Down Expand Up @@ -169,8 +173,17 @@ def get_validator(
return _validators[cache_key]


def validate_payload(message: Union[Call, CallResult], ocpp_version: str) -> None:
async def validate_payload(message: Union[Call, CallResult], ocpp_version: str) -> None:
"""Validate the payload of the message using JSON schemas."""
if ASYNC_VALIDATION:
await asyncio.get_event_loop().run_in_executor(
None, _validate_payload, message, ocpp_version
)
else:
_validate_payload(message, ocpp_version)


def _validate_payload(message: Union[Call, CallResult], ocpp_version: str) -> None:
if type(message) not in [Call, CallResult]:
raise ValidationError(
"Payload can't be validated because message "
Expand Down
6 changes: 3 additions & 3 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
ProtocolError,
TypeConstraintViolationError,
)
from ocpp.messages import Call, validate_payload
from ocpp.messages import Call, _validate_payload


def test_exception_with_error_details():
Expand Down Expand Up @@ -36,7 +36,7 @@ def test_exception_show_triggered_message_type_constraint():
)

with pytest.raises(TypeConstraintViolationError) as exception_info:
validate_payload(call, "1.6")
_validate_payload(call, "1.6")
assert ocpp_message in str(exception_info.value)


Expand All @@ -54,5 +54,5 @@ def test_exception_show_triggered_message_format():
)

with pytest.raises(FormatViolationError) as exception_info:
validate_payload(call, "1.6")
_validate_payload(call, "1.6")
assert ocpp_message in str(exception_info.value)
24 changes: 12 additions & 12 deletions tests/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
CallResult,
MessageType,
_DecimalEncoder,
_validate_payload,
_validators,
get_validator,
pack,
unpack,
validate_payload,
)
from ocpp.v16.enums import Action

Expand Down Expand Up @@ -137,7 +137,7 @@ def test_validate_set_charging_profile_payload():
},
)

validate_payload(message, ocpp_version="1.6")
_validate_payload(message, ocpp_version="1.6")


def test_validate_get_composite_profile_payload():
Expand All @@ -162,7 +162,7 @@ def test_validate_get_composite_profile_payload():
},
)

validate_payload(message, ocpp_version="1.6")
_validate_payload(message, ocpp_version="1.6")


@pytest.mark.parametrize("ocpp_version", ["1.6", "2.0.1"])
Expand All @@ -177,7 +177,7 @@ def test_validate_payload_with_valid_payload(ocpp_version):
payload={"currentTime": datetime.now().isoformat()},
)

validate_payload(message, ocpp_version=ocpp_version)
_validate_payload(message, ocpp_version=ocpp_version)


def test_validate_payload_with_invalid_additional_properties_payload():
Expand All @@ -192,7 +192,7 @@ def test_validate_payload_with_invalid_additional_properties_payload():
)

with pytest.raises(FormatViolationError):
validate_payload(message, ocpp_version="1.6")
_validate_payload(message, ocpp_version="1.6")


def test_validate_payload_with_invalid_type_payload():
Expand All @@ -212,7 +212,7 @@ def test_validate_payload_with_invalid_type_payload():
)

with pytest.raises(TypeConstraintViolationError):
validate_payload(message, ocpp_version="1.6")
_validate_payload(message, ocpp_version="1.6")


def test_validate_payload_with_invalid_missing_property_payload():
Expand All @@ -232,7 +232,7 @@ def test_validate_payload_with_invalid_missing_property_payload():
)

with pytest.raises(ProtocolError):
validate_payload(message, ocpp_version="1.6")
_validate_payload(message, ocpp_version="1.6")


def test_validate_payload_with_invalid_message_type_id():
Expand All @@ -241,7 +241,7 @@ def test_validate_payload_with_invalid_message_type_id():
a message type id other than 2, Call, or 3, CallError.
"""
with pytest.raises(ValidationError):
validate_payload(dict(), ocpp_version="1.6")
_validate_payload(dict(), ocpp_version="1.6")


def test_validate_payload_with_non_existing_schema():
Expand All @@ -256,7 +256,7 @@ def test_validate_payload_with_non_existing_schema():
)

with pytest.raises(NotImplementedError):
validate_payload(message, ocpp_version="1.6")
_validate_payload(message, ocpp_version="1.6")


def test_call_error_representation():
Expand Down Expand Up @@ -342,7 +342,7 @@ def test_serializing_custom_types():
)

try:
validate_payload(message, ocpp_version="1.6")
_validate_payload(message, ocpp_version="1.6")
except TypeConstraintViolationError as error:
# Before the fix, this call would fail with a TypError. Lack of any error
# makes this test pass.
Expand Down Expand Up @@ -376,7 +376,7 @@ def test_validate_meter_values_hertz():
},
)

validate_payload(message, ocpp_version="1.6")
_validate_payload(message, ocpp_version="1.6")


def test_validate_set_maxlength_violation_payload():
Expand All @@ -394,4 +394,4 @@ def test_validate_set_maxlength_violation_payload():
)

with pytest.raises(TypeConstraintViolationError):
validate_payload(message, ocpp_version="1.6")
_validate_payload(message, ocpp_version="1.6")

0 comments on commit 05c6594

Please sign in to comment.