Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve YAML serialization. #193

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nrel/hive/config/hive_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
from typing import NamedTuple, Dict, Union, Tuple, Optional

import pkg_resources
import yaml

from nrel.hive.config.config_builder import ConfigBuilder
from nrel.hive.config.dispatcher_config import DispatcherConfig
from nrel.hive.config.global_config import GlobalConfig
from nrel.hive.config.input import Input
from nrel.hive.config.network import Network
from nrel.hive.config.sim import Sim
from nrel.hive.custom_yaml import custom_yaml as yaml
from nrel.hive.util import fs

log = logging.getLogger(__name__)
Expand Down
1 change: 1 addition & 0 deletions nrel/hive/custom_yaml/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from nrel.hive.custom_yaml.custom_yaml import custom_yaml
70 changes: 70 additions & 0 deletions nrel/hive/custom_yaml/custom_yaml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import logging
from pathlib import PurePath
from typing import Any, Union

import yaml

from nrel.hive.dispatcher.instruction_generator.charging_search_type import ChargingSearchType
from nrel.hive.model.sim_time import SimTime
from nrel.hive.model.vehicle.schedules.schedule_type import ScheduleType
from nrel.hive.reporting.report_type import ReportType

log = logging.getLogger(__name__)

custom_yaml = yaml

# This tag is not written to the file during serialization because interpretation is implicit during YAML deserialization.
YAML_STR_TAG = "tag:yaml.org,2002:str"


# Handling stdlib objects that should be represented by list(obj).
# Prefer to handle classes within their own definition and then register below.
def convert_to_unsorted_list(dumper: custom_yaml.Dumper, obj: tuple):
"""Patches PyYAML representation for an object so that it is treated as a YAML list. Avoids an explicit YAML tag."""
return dumper.represent_list(list(obj))


custom_yaml.add_representer(data_type=tuple, representer=convert_to_unsorted_list)


# Handling stdlib objects that should be represented by sorted(list(obj)).
# Prefer to handle classes within their own definition and then register below.
def convert_to_sorted_list(dumper: custom_yaml.Dumper, obj: set):
"""Patches PyYAML representation for an object so that it is treated as a YAML list. Avoids an explicit YAML tag."""
return dumper.represent_list(sorted(list(obj)))


custom_yaml.add_representer(data_type=set, representer=convert_to_sorted_list)


# Handling stdlib objects that should be represented as str(obj).
# Prefer to handle classes within their own definition and then register below.
def convert_to_str(dumper: custom_yaml.Dumper, path: PurePath):
"""Patches PyYAML representation for an object so that it is treated as a YAML str. Avoids an explicit YAML tag."""
return dumper.represent_scalar(tag=YAML_STR_TAG, value=str(path))


custom_yaml.add_multi_representer(data_type=PurePath, multi_representer=convert_to_str)


# Registering explicit/specific representers that are kept withing their classes.
custom_yaml.add_representer(
data_type=ChargingSearchType, representer=ChargingSearchType.yaml_representer
)
custom_yaml.add_representer(data_type=ReportType, representer=ReportType.yaml_representer)
custom_yaml.add_representer(data_type=ScheduleType, representer=ScheduleType.yaml_representer)
custom_yaml.add_representer(data_type=SimTime, representer=SimTime.yaml_representer)


# Fallback to str() representation for any child of `object`.
# Raise a warning to alert the user that implicit serialization was done.
# Does not appear to work for built-in types that PyYaml has special serializers for.
def generic_representer(dumper: custom_yaml.Dumper, obj: object):
"""Serializes arbitrary objects to strs."""
log.warning(f"{obj.__class__} object was implicity serialized with `str(obj)`.")
tag = YAML_STR_TAG
val = str(obj)
return dumper.represent_scalar(tag=tag, value=val)


custom_yaml.add_multi_representer(data_type=object, multi_representer=generic_representer)
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING

if TYPE_CHECKING:
import yaml

class ChargingSearchType(Enum):
NEAREST_SHORTEST_QUEUE = 1
Expand All @@ -26,3 +29,7 @@ def from_string(string: str) -> ChargingSearchType:
raise NameError(
f"charging search type {string} is not known, must be one of {valid_names}"
)

@staticmethod
def yaml_representer(dumper: yaml.Dumper, o: "ChargingSearchType"):
return dumper.represent_scalar(tag = 'tag:yaml.org,2002:str', value = o.name.lower())
8 changes: 7 additions & 1 deletion nrel/hive/model/sim_time.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

from datetime import datetime, time
from typing import Union
from typing import TYPE_CHECKING, Union

from nrel.hive.util.exception import TimeParseError

if TYPE_CHECKING:
import yaml

class SimTime(int):
ERROR_MSG = (
Expand Down Expand Up @@ -65,3 +67,7 @@ def as_epoch_time(self) -> int:

def as_iso_time(self) -> str:
return datetime.utcfromtimestamp(int(self)).isoformat()

@staticmethod
def yaml_representer(dumper: yaml.Dumper, o: "SimTime"):
return dumper.represent_scalar(tag = 'tag:yaml.org,2002:str', value = o.as_iso_time())
7 changes: 7 additions & 0 deletions nrel/hive/model/vehicle/schedules/schedule_type.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING

if TYPE_CHECKING:
import yaml

class ScheduleType(Enum):
TIME_RANGE = 0
Expand All @@ -24,3 +27,7 @@ def from_string(string: str) -> ScheduleType:
raise NameError(
f"schedule type '{string}' is not known, must be one of {{{valid_names}}}"
)

@staticmethod
def yaml_representer(dumper: yaml.Dumper, o: "ScheduleType"):
return dumper.represent_scalar(tag = 'tag:yaml.org,2002:str', value = o.name.lower())
16 changes: 16 additions & 0 deletions nrel/hive/reporting/report_type.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING

if TYPE_CHECKING:
import yaml


class ReportType(Enum):
Expand Down Expand Up @@ -43,3 +47,15 @@ def from_string(cls, s: str) -> ReportType:
return values[s]
except KeyError:
raise KeyError(f"{s} not a valid report type.")

@staticmethod
def yaml_representer(dumper: yaml.Dumper, o: "ReportType"):
return dumper.represent_scalar(tag="tag:yaml.org,2002:str", value=o.name.lower())

def __lt__(self, other: "ReportType"):
"""Allows sorting an iterable of ReportType, in particular for deterministic serialization of `set[ReportType]`"""
if not isinstance(other, ReportType):
raise TypeError(
f"'<' not supported between instances of {type(self)} and {type(other)}"
)
return self.name < other.name
13 changes: 13 additions & 0 deletions tests/test_charging_search_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from unittest import TestCase

from nrel.hive.custom_yaml import custom_yaml as yaml
from nrel.hive.dispatcher.instruction_generator.charging_search_type import ChargingSearchType


class TestChargingSearchType(TestCase):
def test_yaml_repr(self):
a = ChargingSearchType(1)
yaml.add_representer(
data_type=ChargingSearchType, representer=ChargingSearchType.yaml_representer
)
self.assertEqual(yaml.dump(a), "nearest_shortest_queue\n...\n")
136 changes: 136 additions & 0 deletions tests/test_custom_yaml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import logging
from pathlib import Path, PurePath
from unittest import TestCase

from nrel.hive.custom_yaml import custom_yaml as yaml
from nrel.hive.dispatcher.instruction_generator.charging_search_type import ChargingSearchType
from nrel.hive.model.sim_time import SimTime
from nrel.hive.model.vehicle.schedules.schedule_type import ScheduleType
from nrel.hive.reporting.report_type import ReportType

BASIC_TUPLE = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
BASIC_LIST = list(BASIC_TUPLE)

BASIC_SET = set([3, 2, 1, 0, 4, 6, 7, 5, 9, 8])


class TestClass(object):
"""An accessory class for use in tests without any YAML serialization methods."""

def __str__(self):
return "123abc123"


log = logging.getLogger(__name__)


class TestCustomYaml_Tuple(TestCase):
def test_not_tagged_YAML(self):
a = (1,)
self.assertNotEqual(yaml.dump(a)[0], "!", "YAML ! tag detected during serialization.")

def test_tuple_like_list(self):
self.assertEqual(
yaml.dump(BASIC_LIST),
yaml.dump(BASIC_TUPLE),
"The custom YAML serializer is not treating a tuple like a YAML list.",
)


class TestCustomYAML_Set(TestCase):
def test_not_tagged_YAML(self):
a = set([1])
self.assertNotEqual(yaml.dump(a)[0], "!", "YAML ! tag detected during serialization.")

def test_set_like_list(self):
self.assertEqual(
yaml.dump(BASIC_LIST),
yaml.dump(BASIC_SET),
"The custom YAML serializer is not treating a set like a YAML list.",
)


class TestCustomYAML_PathLib(TestCase):
def test_purepath(self):
ppath = PurePath("./test/")
str_path = str(ppath)
self.assertEqual(yaml.dump(str_path), yaml.dump(ppath))

def test_path(self):
path = Path("./test/")
str_path = str(path)
self.assertEqual(yaml.dump(str_path), yaml.dump(path))


class TestCustomYAML_ChargingSearchType(TestCase):
def test_not_tagged_YAML(self):
a = ChargingSearchType(1)
self.assertNotEqual(yaml.dump(a)[0], "!", "YAML ! tag detected during serialization.")

def test_like_str(self):
a = ChargingSearchType(1)
self.assertEqual(yaml.dump(a), "nearest_shortest_queue\n...\n")


class TestCustomYAML_SimTime(TestCase):
def test_not_tagged_YAML(self):
a = SimTime.build(0)
self.assertNotEqual(yaml.dump(a)[0], "!", "YAML ! tag detected during serialization.")

def test_like_iso_time_dynamic(self):
a = SimTime.build(0)
self.assertEqual(yaml.dump(a.as_iso_time()), yaml.dump(a))

def test_like_iso_time_static(self):
a = SimTime.build(0)
self.assertEqual(yaml.dump(a.as_iso_time()), "'1970-01-01T00:00:00'\n")


class TestCustomYAML_ScheduleType(TestCase):
def test_not_tagged_YAML(self):
a = ScheduleType(0)
self.assertNotEqual(yaml.dump(a)[0], "!", "YAML ! tag detected during serialization.")

def test_like_str(self):
a = ScheduleType(0)
self.assertEqual(yaml.dump(a), "time_range\n...\n")


class TestCustomYAML_ReportType(TestCase):
def test_not_tagged_YAML(self):
a = ReportType(1)
self.assertNotEqual(yaml.dump(a)[0], "!", "YAML ! tag detected during serialization.")

def test_like_str(self):
a = ReportType(1)
self.assertEqual(yaml.dump(a), "station_state\n...\n")


class TestCustomYAML_Generic(TestCase):
def test_not_tagged_YAML(self):
instance = TestClass()
self.assertNotEqual(
yaml.dump(instance)[0], "!", "YAML ! tag detected during serialization."
)

def test_warn_on_generic_serialization(self):
for c in (TestClass, lambda: range(10)):
instance = c()

with self.assertLogs(level=logging.WARNING) as log_cm:
a = yaml.dump(instance)
print(a)

warninglog = (
"WARNING:nrel.hive.custom_yaml.custom_yaml:"
+ f"{instance.__class__} object was implicity serialized with `str(obj)`."
)
self.assertIn(
warninglog,
log_cm.output,
msg="WARNING entry in log was not present when implicitly serializing an object.",
)

def test_generic_serialization_like_str(self):
instance = TestClass()
self.assertEqual(yaml.dump(str(instance)), yaml.dump(instance))
32 changes: 32 additions & 0 deletions tests/test_report_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from unittest import TestCase

from nrel.hive.custom_yaml import custom_yaml as yaml
from nrel.hive.reporting.report_type import ReportType


class TestReportType(TestCase):
def test_yaml_repr(self):
yaml.add_representer(data_type=ReportType, representer=ReportType.yaml_representer)
self.assertEqual("station_state\n...\n", yaml.dump(ReportType.from_string("station_state")))

def test_ReportType_ordering_dynamic(self):
members = [m for m in ReportType]
name_then_sort = sorted([m.name for m in members])
sort_then_name = [m.name for m in sorted(members)]
self.assertEqual(name_then_sort, sort_then_name, "ReportType sorting invalid.")

def test_ReportType_ordering_static(self):
a = ReportType.from_string("station_state")
b = ReportType.from_string("driver_state")
self.assertLess(b, a, "ReportType sorting invalid.")

def test_ReportType_lt_raise(self):
class GenericClass:
@property
def name(self):
return "abc"

a = ReportType.from_string("station_state")
b = GenericClass()
with self.assertRaises(TypeError):
x = a < b
10 changes: 10 additions & 0 deletions tests/test_schedule_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from unittest import TestCase

from nrel.hive.custom_yaml import custom_yaml as yaml
from nrel.hive.model.vehicle.schedules.schedule_type import ScheduleType


class TestSimTime(TestCase):
def test_yaml_repr(self):
yaml.add_representer(data_type=ScheduleType, representer=ScheduleType.yaml_representer)
self.assertEqual("time_range\n...\n", yaml.dump(ScheduleType(0)))
10 changes: 10 additions & 0 deletions tests/test_sim_time.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from unittest import TestCase

from nrel.hive.custom_yaml import custom_yaml as yaml
from nrel.hive.model.sim_time import SimTime


class TestSimTime(TestCase):
def test_yaml_repr(self):
yaml.add_representer(data_type=SimTime, representer=SimTime.yaml_representer)
self.assertEqual("'1970-01-01T00:00:00'\n", yaml.dump(SimTime.build(0)))