Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable threads if messages.ASYNC_VALIDATION = False #678

Merged
merged 3 commits into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions ocpp/charge_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,8 @@ async def _handle_call(self, msg):
return

if not handlers.get("_skip_schema_validation", False):
await asyncio.get_event_loop().run_in_executor(
None, validate_payload, msg, self._ocpp_version
)
await validate_payload(msg, self._ocpp_version)

# OCPP uses camelCase for the keys in the payload. It's more pythonic
# to use snake_case for keyword arguments. Therefore the keys must be
# 'translated'. Some examples:
Expand Down Expand Up @@ -344,9 +343,7 @@ async def _handle_call(self, msg):
response = msg.create_call_result(camel_case_payload)

if not handlers.get("_skip_schema_validation", False):
await asyncio.get_event_loop().run_in_executor(
None, validate_payload, response, self._ocpp_version
)
await validate_payload(response, self._ocpp_version)

await self._send(response.to_json())

Expand Down Expand Up @@ -415,9 +412,7 @@ async def call(
)

if not skip_schema_validation:
await asyncio.get_event_loop().run_in_executor(
None, validate_payload, call, self._ocpp_version
)
await validate_payload(call, self._ocpp_version)

# Use a lock to prevent make sure that only 1 message can be send at a
# a time.
Expand All @@ -440,9 +435,7 @@ async def call(
raise response.to_exception()
elif not skip_schema_validation:
response.action = call.action
await asyncio.get_event_loop().run_in_executor(
None, validate_payload, response, self._ocpp_version
)
await validate_payload(response, self._ocpp_version)

snake_case_payload = camel_to_snake_case(response.payload)
# Create the correct Payload instance based on the received payload. If
Expand Down
15 changes: 14 additions & 1 deletion ocpp/messages.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
""" Module containing classes that model the several OCPP messages types. It
also contain some helper functions for packing and unpacking messages. """

from __future__ import annotations

import asyncio
import decimal
import json
import os
Expand All @@ -24,6 +26,8 @@

_validators: Dict[str, Draft4Validator] = {}

ASYNC_VALIDATION = True


class _DecimalEncoder(json.JSONEncoder):
"""Encode values of type `decimal.Decimal` using 1 decimal point.
Expand Down Expand Up @@ -169,8 +173,17 @@ def get_validator(
return _validators[cache_key]


def validate_payload(message: Union[Call, CallResult], ocpp_version: str) -> None:
async def validate_payload(message: Union[Call, CallResult], ocpp_version: str) -> None:
"""Validate the payload of the message using JSON schemas."""
if ASYNC_VALIDATION:
await asyncio.get_event_loop().run_in_executor(
None, _validate_payload, message, ocpp_version
)
else:
_validate_payload(message, ocpp_version)


def _validate_payload(message: Union[Call, CallResult], ocpp_version: str) -> None:
if type(message) not in [Call, CallResult]:
raise ValidationError(
"Payload can't be validated because message "
Expand Down
6 changes: 3 additions & 3 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
ProtocolError,
TypeConstraintViolationError,
)
from ocpp.messages import Call, validate_payload
from ocpp.messages import Call, _validate_payload


def test_exception_with_error_details():
Expand Down Expand Up @@ -36,7 +36,7 @@ def test_exception_show_triggered_message_type_constraint():
)

with pytest.raises(TypeConstraintViolationError) as exception_info:
validate_payload(call, "1.6")
_validate_payload(call, "1.6")
assert ocpp_message in str(exception_info.value)


Expand All @@ -54,5 +54,5 @@ def test_exception_show_triggered_message_format():
)

with pytest.raises(FormatViolationError) as exception_info:
validate_payload(call, "1.6")
_validate_payload(call, "1.6")
assert ocpp_message in str(exception_info.value)
46 changes: 35 additions & 11 deletions tests/test_messages.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import decimal
import json
import threading
from datetime import datetime

import pytest
from hypothesis import given
from hypothesis.strategies import binary

import ocpp
from ocpp.exceptions import (
FormatViolationError,
NotImplementedError,
Expand All @@ -21,6 +23,7 @@
CallResult,
MessageType,
_DecimalEncoder,
_validate_payload,
_validators,
get_validator,
pack,
Expand Down Expand Up @@ -137,7 +140,7 @@ def test_validate_set_charging_profile_payload():
},
)

validate_payload(message, ocpp_version="1.6")
_validate_payload(message, ocpp_version="1.6")


def test_validate_get_composite_profile_payload():
Expand All @@ -162,7 +165,7 @@ def test_validate_get_composite_profile_payload():
},
)

validate_payload(message, ocpp_version="1.6")
_validate_payload(message, ocpp_version="1.6")


@pytest.mark.parametrize("ocpp_version", ["1.6", "2.0.1"])
Expand All @@ -177,7 +180,7 @@ def test_validate_payload_with_valid_payload(ocpp_version):
payload={"currentTime": datetime.now().isoformat()},
)

validate_payload(message, ocpp_version=ocpp_version)
_validate_payload(message, ocpp_version=ocpp_version)


def test_validate_payload_with_invalid_additional_properties_payload():
Expand All @@ -192,7 +195,7 @@ def test_validate_payload_with_invalid_additional_properties_payload():
)

with pytest.raises(FormatViolationError):
validate_payload(message, ocpp_version="1.6")
_validate_payload(message, ocpp_version="1.6")


def test_validate_payload_with_invalid_type_payload():
Expand All @@ -212,7 +215,7 @@ def test_validate_payload_with_invalid_type_payload():
)

with pytest.raises(TypeConstraintViolationError):
validate_payload(message, ocpp_version="1.6")
_validate_payload(message, ocpp_version="1.6")


def test_validate_payload_with_invalid_missing_property_payload():
Expand All @@ -232,7 +235,7 @@ def test_validate_payload_with_invalid_missing_property_payload():
)

with pytest.raises(ProtocolError):
validate_payload(message, ocpp_version="1.6")
_validate_payload(message, ocpp_version="1.6")


def test_validate_payload_with_invalid_message_type_id():
Expand All @@ -241,7 +244,7 @@ def test_validate_payload_with_invalid_message_type_id():
a message type id other than 2, Call, or 3, CallError.
"""
with pytest.raises(ValidationError):
validate_payload(dict(), ocpp_version="1.6")
_validate_payload(dict(), ocpp_version="1.6")


def test_validate_payload_with_non_existing_schema():
Expand All @@ -256,7 +259,7 @@ def test_validate_payload_with_non_existing_schema():
)

with pytest.raises(NotImplementedError):
validate_payload(message, ocpp_version="1.6")
_validate_payload(message, ocpp_version="1.6")


def test_call_error_representation():
Expand Down Expand Up @@ -342,7 +345,7 @@ def test_serializing_custom_types():
)

try:
validate_payload(message, ocpp_version="1.6")
_validate_payload(message, ocpp_version="1.6")
except TypeConstraintViolationError as error:
# Before the fix, this call would fail with a TypError. Lack of any error
# makes this test pass.
Expand Down Expand Up @@ -376,7 +379,7 @@ def test_validate_meter_values_hertz():
},
)

validate_payload(message, ocpp_version="1.6")
_validate_payload(message, ocpp_version="1.6")


def test_validate_set_maxlength_violation_payload():
Expand All @@ -394,4 +397,25 @@ def test_validate_set_maxlength_violation_payload():
)

with pytest.raises(TypeConstraintViolationError):
validate_payload(message, ocpp_version="1.6")
_validate_payload(message, ocpp_version="1.6")


@pytest.mark.parametrize("use_threads", [False, True])
@pytest.mark.asyncio
async def test_validate_payload_threads(use_threads):
"""
Test that threads usage can be configured
"""
message = CallResult(
unique_id="1234",
action="Heartbeat",
payload={"currentTime": datetime.now().isoformat()},
)

assert threading.active_count() == 1
ocpp.messages.ASYNC_VALIDATION = use_threads
await validate_payload(message, ocpp_version="1.6")
if use_threads:
assert threading.active_count() > 1
else:
assert threading.active_count() == 1
Loading