From 4c122ce917c571f0639df61c718323742e8b9f15 Mon Sep 17 00:00:00 2001
From: Brian Park
Date: Tue, 14 Nov 2023 12:35:51 -0500
Subject: [PATCH 1/3] chore: add RTL layer with output monotonicities support
 for consecutive RTLs

---
 pyproject.toml                     |   3 +
 pytorch_lattice/layers/__init__.py |   1 +
 pytorch_lattice/layers/rtl.py      | 274 +++++++++++++++++++++
 tests/layers/test_rtl.py           | 374 +++++++++++++++++++++++++++++
 4 files changed, 652 insertions(+)
 create mode 100644 pytorch_lattice/layers/rtl.py
 create mode 100644 tests/layers/test_rtl.py

diff --git a/pyproject.toml b/pyproject.toml
index f20f41e..d163548 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -69,6 +69,9 @@ indent-style = "space"
 skip-magic-trailing-comma = false
 line-ending = "auto"
 
+[tool.mypy]
+exclude = ["venv"]
+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
diff --git a/pytorch_lattice/layers/__init__.py b/pytorch_lattice/layers/__init__.py
index 15c21e9..b97b156 100644
--- a/pytorch_lattice/layers/__init__.py
+++ b/pytorch_lattice/layers/__init__.py
@@ -3,3 +3,4 @@
 from .lattice import Lattice
 from .linear import Linear
 from .numerical_calibrator import NumericalCalibrator
+from .rtl import RTL
diff --git a/pytorch_lattice/layers/rtl.py b/pytorch_lattice/layers/rtl.py
new file mode 100644
index 0000000..e2751dd
--- /dev/null
+++ b/pytorch_lattice/layers/rtl.py
@@ -0,0 +1,274 @@
+"""Random Tiny Lattice module for use in calibrated modeling.
+
+PyTorch implementation of an RTL layer.
+This layer takes inputs that would be slow to run through a single large lattice
+and, as an optimization, runs random subsets of them through many Random Tiny Lattices.
+"""
+import logging
+from typing import List, Optional
+
+import numpy as np
+import torch
+
+from ..enums import Interpolation, LatticeInit, Monotonicity
+from .lattice import Lattice
+
+
+class RTL(torch.nn.Module):
+    """An RTL Module.
+
+    Layer takes a number of features that would otherwise be too many to assign to
+    a single lattice, and instead assigns small random subsets of the features to an
+    ensemble of smaller lattices. The features are shuffled and uniformly repeated
+    if there are more slots in the RTL than features.
+
+    Attributes:
+        - All `__init__` arguments.
+        - _lattice_layers: `dict` of form `{monotonic_count: (lattice, groups)}` which
+          keeps track of the RTL structure. Features are indexed then randomly grouped
+          together to be assigned to a lattice - groups with the same number of
+          monotonic features can be put into the same lattice for further optimization,
+          and are thus stored together in the dict according to `monotonic_count`.
+
+    Example:
+    ```python
+    inputs = torch.tensor(...)  # shape: (batch_size, D)
+    monotonicities = List[Monotonicity...]  # len: D
+    rtl1 = RTL(
+        monotonicities,
+        num_lattices=5,
+        lattice_rank=3,  # num_lattices * lattice_rank must be at least D
+    )
+    output1 = rtl1(inputs)
+
+    # If you want to pass through consecutive RTLs
+
+    rtl2 = RTL(
+        monotonicities=rtl1.output_monotonicities(),  # len: rtl1.num_lattices
+        ...
+    )
+    output2 = rtl2(output1)
+    ```
+    """
+
+    def __init__(
+        self,
+        monotonicities: List[Monotonicity],
+        num_lattices: int,
+        lattice_rank: int,
+        lattice_size: int = 2,
+        output_min: Optional[float] = None,
+        output_max: Optional[float] = None,
+        kernel_init: LatticeInit = LatticeInit.LINEAR,
+        clip_inputs: bool = True,
+        interpolation: Interpolation = Interpolation.HYPERCUBE,
+        average_outputs: bool = False,
+        random_seed: int = 42,
+    ) -> None:
+        """Initializes an instance of 'RTL'.
+
+        Args:
+            monotonicities: `List` of Monotonicity.INCREASING or Monotonicity.NONE
+                indicating monotonicities of input features, ordered respectively.
+            num_lattices: number of lattices in RTL structure.
+            lattice_rank: number of inputs for each lattice in RTL structure.
+            output_min: Minimum output of each lattice in RTL.
+            output_max: Maximum output of each lattice in RTL.
+            kernel_init: Initialization scheme to use for lattices.
+            clip_inputs: Whether input should be clipped to the range of each lattice.
+            interpolation: Interpolation scheme for each lattice in RTL.
+            average_outputs: Whether to average the outputs of all lattices in the RTL.
+            random_seed: seed used for shuffling.
+
+        Raises:
+            ValueError: if size of RTL, determined by `num_lattices * lattice_rank`, is
+                too small to support the number of input features.
+        """
+        super().__init__()
+
+        if len(monotonicities) > num_lattices * lattice_rank:
+            raise ValueError(
+                f"RTL with {num_lattices}x{lattice_rank}D structure cannot support "
+                + f"{len(monotonicities)} input features."
+            )
+        self.monotonicities = monotonicities
+        self.num_lattices = num_lattices
+        self.lattice_rank = lattice_rank
+        self.lattice_size = lattice_size
+        self.output_min = output_min
+        self.output_max = output_max
+        self.kernel_init = kernel_init
+        self.clip_inputs = clip_inputs
+        self.interpolation = interpolation
+        self.average_outputs = average_outputs
+        self.random_seed = random_seed
+
+        rtl_indices = np.array(
+            [i % len(self.monotonicities) for i in range(num_lattices * lattice_rank)]
+        )
+        np.random.seed(self.random_seed)
+        np.random.shuffle(rtl_indices)
+        # split_rtl_indices = np.split(rtl_indices, num_lattices)
+        split_rtl_indices = [list(arr) for arr in np.split(rtl_indices, num_lattices)]
+        swapped_rtl_indices = self._ensure_unique_sublattices(split_rtl_indices)
+        monotonicity_groupings = {}
+        for lattice_indices in swapped_rtl_indices:
+            monotonic_count = sum(
+                1
+                for idx in lattice_indices
+                if self.monotonicities[idx] == Monotonicity.INCREASING
+            )
+            if monotonic_count not in monotonicity_groupings:
+                monotonicity_groupings[monotonic_count] = [lattice_indices]
+            else:
+                monotonicity_groupings[monotonic_count].append(lattice_indices)
+        for monotonic_count, groups in monotonicity_groupings.items():
+            for i, lattice_indices in enumerate(groups):
+                sorted_indices = sorted(
+                    lattice_indices,
+                    key=lambda x: (self.monotonicities[x] != "increasing"),
+                    reverse=False,
+                )
+                groups[i] = sorted_indices
+
+        self._lattice_layers = {}
+        for monotonic_count, groups in monotonicity_groupings.items():
+            self._lattice_layers[monotonic_count] = (
+                Lattice(
+                    lattice_sizes=[self.lattice_size] * self.lattice_rank,
+                    output_min=self.output_min,
+                    output_max=self.output_max,
+                    kernel_init=self.kernel_init,
+                    monotonicities=[Monotonicity.INCREASING] * monotonic_count
+                    + [Monotonicity.NONE] * (lattice_rank - monotonic_count),
+                    clip_inputs=self.clip_inputs,
+                    interpolation=self.interpolation,
+                    units=len(groups),
+                ),
+                groups,
+            )
+
+    def forward(self, x: torch.Tensor) -> 
torch.Tensor: + """Forward method computed by using forward methods of each lattice in RTL. + + Args: + x: input tensor of feature values with shape `(batch_size, num_features)`. + + Returns: + torch.Tensor containing the outputs of each lattice within RTL structure. If + `average_outputs == True`, then all outputs are averaged into a tensor of + shape `(batch_size, 1)`. If `average_outputs == False`, shape of tensor is + `(batch_size, num_lattices)`. + """ + forward_results = [] + for _, (lattice, group) in sorted(self._lattice_layers.items()): + if len(group) > 1: + lattice_input = torch.stack([x[:, idx] for idx in group], dim=-2) + else: + lattice_input = x[:, group[0]] + forward_results.append(lattice.forward(lattice_input)) + result = torch.cat(forward_results, dim=-1) + if not self.average_outputs: + return result + result = torch.mean(result, dim=-1, keepdim=True) + + return result + + @torch.no_grad() + def output_monotonicities(self) -> List[Monotonicity]: + """Gives the monotonicities of the outputs of RTL. + + Returns: + List of `Monotonicity` corresponding to each output of the RTL layer, in the + same order as outputs. + """ + monotonicities = [] + for monotonic_count, (lattice, _) in sorted(self._lattice_layers.items()): + if monotonic_count: + monotonicity = Monotonicity.INCREASING + else: + monotonicity = Monotonicity.NONE + for _ in range(lattice.units): + monotonicities.append(monotonicity) + + return monotonicities + + @torch.no_grad() + def constrain(self) -> None: + """Enforces constraints for each lattice in RTL.""" + for lattice, _ in self._lattice_layers.values(): + lattice.constrain() + + @torch.no_grad() + def assert_constraints(self, eps=1e-6) -> List[List[str]]: + """Asserts that each Lattice in RTL satisfies all constraints. + + Args: + eps: allowed constraints violations. + + Returns: + List of lists, each with constraints violations for an individual Lattice. + """ + return list( + lattice.assert_constraints(eps=eps) + for lattice, _ in self._lattice_layers.values() + ) + + @staticmethod + def _ensure_unique_sublattices( + rtl_indices: List[List[int]], + max_swaps: int = 10000, + ) -> List[List[int]]: + """Attempts to ensure every lattice in RTL structure contains unique features. + + Args: + rtl_indices: list of lists where inner lists are groupings of + indices of input features to RTL layer. + max_swaps: maximum number of swaps to perform before giving up. + + Returns: + List of lists where elements between inner lists have been swapped in + an attempt to remove any duplicates from every grouping. 
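+
+        Example: given groups `[[0, 0], [1, 2]]`, the duplicated `0` can be swapped
+            with the `1` from the neighboring group, yielding `[[1, 0], [0, 2]]` so
+            that no lattice sees the same feature twice.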
+ """ + swaps = 0 + num_sublattices = len(rtl_indices) + + def find_swap_candidate(current_index, element): + """Helper function to find the next sublattice not containing element.""" + for offset in range(1, num_sublattices): + candidate_index = (current_index + offset) % num_sublattices + if element not in rtl_indices[candidate_index]: + return candidate_index + return None + + for i, sublattice in enumerate(rtl_indices): + unique_elements = set() + for element in sublattice: + if element in unique_elements: + swap_with = find_swap_candidate(i, element) + if swap_with is not None: + for swap_element in rtl_indices[swap_with]: + if swap_element not in sublattice: + # Perform the swap + idx_to_swap = rtl_indices[swap_with].index(swap_element) + idx_duplicate = sublattice.index(element) + ( + rtl_indices[swap_with][idx_to_swap], + sublattice[idx_duplicate], + ) = element, swap_element + swaps += 1 + break + else: + logging.info( + "Some lattices in RTL may use the same feature multiple " + "times." + ) + return rtl_indices + else: + unique_elements.add(element) + if swaps >= max_swaps: + logging.info( + "Some lattices in RTL may use the same feature multiple times." + ) + return rtl_indices + return rtl_indices diff --git a/tests/layers/test_rtl.py b/tests/layers/test_rtl.py new file mode 100644 index 0000000..1f9abb6 --- /dev/null +++ b/tests/layers/test_rtl.py @@ -0,0 +1,374 @@ +"""Tests for RTL layer.""" +from itertools import cycle +from unittest.mock import Mock, patch + +import pytest +import torch + +from pytorch_lattice import Interpolation, LatticeInit, Monotonicity +from pytorch_lattice.layers import RTL, Lattice + + +@pytest.mark.parametrize( + "monotonicities, num_lattices, lattice_rank, output_min, output_max, kernel_init," + "clip_inputs, interpolation, average_outputs", + [ + ( + [ + Monotonicity.NONE, + Monotonicity.NONE, + Monotonicity.NONE, + Monotonicity.NONE, + ], + 3, + 3, + None, + 2.0, + LatticeInit.LINEAR, + True, + Interpolation.HYPERCUBE, + True, + ), + ( + [ + Monotonicity.INCREASING, + Monotonicity.INCREASING, + Monotonicity.NONE, + Monotonicity.NONE, + ], + 3, + 3, + -1.0, + 4.0, + LatticeInit.LINEAR, + False, + Interpolation.SIMPLEX, + False, + ), + ( + [Monotonicity.INCREASING, Monotonicity.NONE] * 25, + 20, + 5, + None, + None, + LatticeInit.LINEAR, + True, + Interpolation.HYPERCUBE, + True, + ), + ], +) +def test_initialization( + monotonicities, + num_lattices, + lattice_rank, + output_min, + output_max, + kernel_init, + clip_inputs, + interpolation, + average_outputs, +): + """Tests that RTL Initialization works properly.""" + rtl = RTL( + monotonicities=monotonicities, + num_lattices=num_lattices, + lattice_rank=lattice_rank, + output_min=output_min, + output_max=output_max, + kernel_init=kernel_init, + clip_inputs=clip_inputs, + interpolation=interpolation, + average_outputs=average_outputs, + ) + assert rtl.monotonicities == monotonicities + assert rtl.num_lattices == num_lattices + assert rtl.lattice_rank == lattice_rank + assert rtl.output_min == output_min + assert rtl.output_max == output_max + assert rtl.kernel_init == kernel_init + assert rtl.interpolation == interpolation + assert rtl.average_outputs == average_outputs + + total_lattices = 0 + for monotonic_count, (lattice, group) in rtl._lattice_layers.items(): + # test monotonic features have been sorted to front of list for lattice indices + for single_lattice_indices in group: + for i in range(lattice_rank): + if i < monotonic_count: + assert ( + rtl.monotonicities[single_lattice_indices[i]] 
+ == Monotonicity.INCREASING + ) + else: + assert ( + rtl.monotonicities[single_lattice_indices[i]] + == Monotonicity.NONE + ) + + assert len(lattice.monotonicities) == len(lattice.lattice_sizes) + assert ( + sum(1 for _ in lattice.monotonicities if _ == Monotonicity.INCREASING) + == monotonic_count + ) + assert lattice.output_min == rtl.output_min + assert lattice.output_max == rtl.output_max + assert lattice.kernel_init == rtl.kernel_init + assert lattice.clip_inputs == rtl.clip_inputs + assert lattice.interpolation == rtl.interpolation + + # test number of lattices created is correct + total_lattices += lattice.units + + assert total_lattices == num_lattices + + +@pytest.mark.parametrize( + "monotonicities, num_lattices, lattice_rank", + [ + ([Monotonicity.NONE] * 9, 2, 2), + ([Monotonicity.INCREASING] * 10, 3, 3), + ], +) +def test_initialization_invalid( + monotonicities, + num_lattices, + lattice_rank, +): + """Tests that RTL Initialization raises error when RTL is too small.""" + with pytest.raises(ValueError) as exc_info: + RTL( + monotonicities=monotonicities, + num_lattices=num_lattices, + lattice_rank=lattice_rank, + ) + assert ( + str(exc_info.value) + == f"RTL with {num_lattices}x{lattice_rank}D structure cannot support " + + f"{len(monotonicities)} input features." + ) + + +@pytest.mark.parametrize( + "num_features, num_lattices, lattice_rank, units, expected_lattice_args," + "expected_result, expected_avg", + [ + ( + 6, + 6, + 3, + [3, 2, 1], + [ + torch.tensor([[[0.0, 0.1, 0.2], [0.3, 0.4, 0.5], [0.0, 0.1, 0.2]]]), + torch.tensor([[[0.3, 0.4, 0.5], [0.0, 0.1, 0.2]]]), + torch.tensor([[0.3, 0.4, 0.5]]), + ], + torch.tensor([[0.0, 0.0, 0.0, 1.0, 1.0, 2.0]]), + torch.tensor([[2 / 3]]), + ), + ( + 3, + 3, + 2, + [1, 1, 1], + [ + torch.tensor([[0.0, 0.1]]), + torch.tensor([[0.2, 0.0]]), + torch.tensor([[0.1, 0.2]]), + ], + torch.tensor([[0.0, 1.0, 2.0]]), + torch.tensor([[1.0]]), + ), + ( + 6, + 7, + 5, + [2, 3, 2], + [ + torch.tensor([[[0.0, 0.1, 0.2, 0.3, 0.4], [0.5, 0.0, 0.1, 0.2, 0.3]]]), + torch.tensor( + [ + [ + [0.4, 0.5, 0.0, 0.1, 0.2], + [0.3, 0.4, 0.5, 0.0, 0.1], + [0.2, 0.3, 0.4, 0.5, 0.0], + ] + ] + ), + torch.tensor([[[0.1, 0.2, 0.3, 0.4, 0.5], [0.0, 0.1, 0.2, 0.3, 0.4]]]), + ], + torch.tensor([[0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0]]), + torch.tensor([[1.0]]), + ), + ], +) +def test_forward( + num_features, + num_lattices, + lattice_rank, + units, + expected_lattice_args, + expected_result, + expected_avg, +): + """Tests forward function of RTL Lattice.""" + rtl = RTL( + monotonicities=[Monotonicity.NONE, Monotonicity.INCREASING] + * (num_features // 2), + num_lattices=num_lattices, + lattice_rank=lattice_rank, + ) + # generate indices for each lattice in cyclic fashion based off units + groups = [] + feature_indices = cycle(range(num_features)) + for lattice_units in units: + group = [ + [next(feature_indices) for _ in range(lattice_rank)] + for _ in range(lattice_units) + ] + groups.append(group) + lattice_indices = {i: groups[i % len(groups)] for i in range(len(units))} + rtl._lattice_layers = { + i: (Lattice(lattice_sizes=[2] * lattice_rank, units=unit), lattice_indices[i]) + for i, unit in enumerate(units) + } + + mock_forwards = [] + for monotonic_count, (lattice, _) in rtl._lattice_layers.items(): + mock_forward = Mock() + lattice.forward = mock_forward + mock_forward.return_value = torch.full( + (1, units[monotonic_count]), + float(monotonic_count), + dtype=torch.float32, + ) + mock_forwards.append(mock_forward) + + x = torch.arange(0, num_features * 0.1, 
0.1).unsqueeze(0) + result = rtl.forward(x) + + # Assert the calls and results for each mock_forward based on expected_outs + for i, mock_forward in enumerate(mock_forwards): + mock_forward.assert_called_once() + assert torch.allclose( + mock_forward.call_args[0][0], expected_lattice_args[i], atol=1e-6 + ) + assert torch.allclose(result, expected_result) + rtl.average_outputs = True + result = rtl.forward(x) + assert torch.allclose(result, expected_avg) + + +@pytest.mark.parametrize( + "monotonic_counts, units, expected_out", + [ + ( + [0, 1, 2, 3], + [2, 1, 1, 1], + [Monotonicity.NONE] * 2 + [Monotonicity.INCREASING] * 3, + ), + ( + [0, 4, 5, 7], + [1, 2, 3, 4], + [Monotonicity.NONE] + [Monotonicity.INCREASING] * 9, + ), + ([0], [3], [Monotonicity.NONE] * 3), + ([1, 2, 3], [1, 1, 1], [Monotonicity.INCREASING] * 3), + ], +) +def test_output_monotonicities( + monotonic_counts, + units, + expected_out, +): + """Tests output_monotonicities function.""" + rtl = RTL( + monotonicities=[Monotonicity.NONE, Monotonicity.INCREASING], + num_lattices=3, + lattice_rank=3, + ) + rtl._lattice_layers = { + monotonic_count: (Lattice(lattice_sizes=[2, 2], units=units[i]), []) + for i, monotonic_count in enumerate(monotonic_counts) + } + assert rtl.output_monotonicities() == expected_out + + +def test_constrain(): + """Tests RTL constrain function.""" + rtl = RTL( + monotonicities=[Monotonicity.NONE, Monotonicity.INCREASING], + num_lattices=3, + lattice_rank=3, + ) + mock_constrains = [] + for lattice, _ in rtl._lattice_layers.values(): + mock_constrain = Mock() + lattice.constrain = mock_constrain + mock_constrains.append(mock_constrain) + + rtl.constrain() + for mock_constrain in mock_constrains: + mock_constrain.assert_called_once() + + +def test_assert_constraints(): + """Tests RTL assert_constraints function.""" + rtl = RTL( + monotonicities=[Monotonicity.NONE, Monotonicity.INCREASING], + num_lattices=3, + lattice_rank=3, + ) + mock_asserts = [] + for lattice, _ in rtl._lattice_layers.values(): + mock_assert = Mock() + lattice.assert_constraints = mock_assert + mock_assert.return_value = "violation" + mock_asserts.append(mock_assert) + + violations = rtl.assert_constraints() + for mock_assert in mock_asserts: + mock_assert.assert_called_once() + + assert violations == ["violation"] * len(rtl._lattice_layers) + + +@pytest.mark.parametrize( + "rtl_indices", + [ + [[1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [6, 6]], + [[1, 1, 1], [2, 2, 2], [3, 3, 3]], + [[1, 1, 1], [2, 2, 2], [1, 2, 3], [3, 3, 3]], + [ + [1, 1, 2], + [2, 3, 4], + [1, 5, 5], + [4, 6, 7], + [1, 3, 4], + [2, 3, 3], + [4, 5, 6], + [6, 6, 6], + ], + ], +) +def test_ensure_unique_sublattices_possible(rtl_indices): + """Tests _ensure_unique_sublattices removes duplicates from groups when possible.""" + swapped_indices = RTL._ensure_unique_sublattices(rtl_indices) + for group in swapped_indices: + assert len(set(group)) == len(group) + + +@pytest.mark.parametrize( + "rtl_indices, max_swaps", + [ + ([[1, 1], [1, 2], [1, 3]], 100), + ([[1, 1], [2, 2], [3, 3], [4, 4]], 2), + ], +) +def test_ensure_unique_sublattices_impossible(rtl_indices, max_swaps): + """Tests _ensure_unique_sublattices logs when it can't remove duplicates.""" + with patch("logging.info") as mock_logging_info: + RTL._ensure_unique_sublattices(rtl_indices, max_swaps) + mock_logging_info.assert_called_with( + "Some lattices in RTL may use the same feature multiple times." 
+ ) From 26f47e103a46ff991d970d1495186c77058f1035 Mon Sep 17 00:00:00 2001 From: Brian Park Date: Tue, 14 Nov 2023 14:46:59 -0500 Subject: [PATCH 2/3] chore: fixed RTL for change in Monotonicity enum --- .vscode/settings.json | 4 +++- pytorch_lattice/enums.py | 1 - pytorch_lattice/layers/rtl.py | 10 +++++----- tests/layers/test_rtl.py | 36 ++++++++++++++++------------------- 4 files changed, 24 insertions(+), 27 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index fbb8488..808e049 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,6 +1,8 @@ { "files.autoSave": "onFocusChange", - "editor.rulers": [88], + "editor.rulers": [ + 88 + ], "editor.formatOnSaveMode": "file", "editor.formatOnSave": true, "files.insertFinalNewline": true, diff --git a/pytorch_lattice/enums.py b/pytorch_lattice/enums.py index cd676a2..3597064 100644 --- a/pytorch_lattice/enums.py +++ b/pytorch_lattice/enums.py @@ -67,7 +67,6 @@ class CategoricalCalibratorInit(_Enum): class Monotonicity(_Enum): """Type of monotonicity constraint. - - NONE: no monotonicity constraint. - INCREASING: increasing monotonicity i.e. increasing input increases output. - DECREASING: decreasing monotonicity i.e. increasing input decreases output. """ diff --git a/pytorch_lattice/layers/rtl.py b/pytorch_lattice/layers/rtl.py index e2751dd..21be9b1 100644 --- a/pytorch_lattice/layers/rtl.py +++ b/pytorch_lattice/layers/rtl.py @@ -68,8 +68,8 @@ def __init__( """Initializes an instance of 'RTL'. Args: - monotonicities: `List` of Monotonicity.INCREASING or Monotonicity.NONE - indicating monotonicities of input features, ordered respectively. + monotonicities: `List` of `Monotonicity.INCREASING` or `None` + indicating monotonicities of input features, ordered respectively. num_lattices: number of lattices in RTL structure. lattice_rank: number of inputs for each lattice in RTL structure. output_min: Minimum output of each lattice in RTL. 
@@ -126,7 +126,7 @@ def __init__( for i, lattice_indices in enumerate(groups): sorted_indices = sorted( lattice_indices, - key=lambda x: (self.monotonicities[x] != "increasing"), + key=lambda x: (self.monotonicities[x] is None), reverse=False, ) groups[i] = sorted_indices @@ -140,7 +140,7 @@ def __init__( output_max=self.output_max, kernel_init=self.kernel_init, monotonicities=[Monotonicity.INCREASING] * monotonic_count - + [Monotonicity.NONE] * (lattice_rank - monotonic_count), + + [None] * (lattice_rank - monotonic_count), clip_inputs=self.clip_inputs, interpolation=self.interpolation, units=len(groups), @@ -187,7 +187,7 @@ def output_monotonicities(self) -> List[Monotonicity]: if monotonic_count: monotonicity = Monotonicity.INCREASING else: - monotonicity = Monotonicity.NONE + monotonicity = None for _ in range(lattice.units): monotonicities.append(monotonicity) diff --git a/tests/layers/test_rtl.py b/tests/layers/test_rtl.py index 1f9abb6..53a13e4 100644 --- a/tests/layers/test_rtl.py +++ b/tests/layers/test_rtl.py @@ -15,10 +15,10 @@ [ ( [ - Monotonicity.NONE, - Monotonicity.NONE, - Monotonicity.NONE, - Monotonicity.NONE, + None, + None, + None, + None, ], 3, 3, @@ -33,8 +33,8 @@ [ Monotonicity.INCREASING, Monotonicity.INCREASING, - Monotonicity.NONE, - Monotonicity.NONE, + None, + None, ], 3, 3, @@ -46,7 +46,7 @@ False, ), ( - [Monotonicity.INCREASING, Monotonicity.NONE] * 25, + [Monotonicity.INCREASING, None] * 25, 20, 5, None, @@ -101,10 +101,7 @@ def test_initialization( == Monotonicity.INCREASING ) else: - assert ( - rtl.monotonicities[single_lattice_indices[i]] - == Monotonicity.NONE - ) + assert rtl.monotonicities[single_lattice_indices[i]] == None assert len(lattice.monotonicities) == len(lattice.lattice_sizes) assert ( @@ -126,7 +123,7 @@ def test_initialization( @pytest.mark.parametrize( "monotonicities, num_lattices, lattice_rank", [ - ([Monotonicity.NONE] * 9, 2, 2), + ([None] * 9, 2, 2), ([Monotonicity.INCREASING] * 10, 3, 3), ], ) @@ -213,8 +210,7 @@ def test_forward( ): """Tests forward function of RTL Lattice.""" rtl = RTL( - monotonicities=[Monotonicity.NONE, Monotonicity.INCREASING] - * (num_features // 2), + monotonicities=[None, Monotonicity.INCREASING] * (num_features // 2), num_lattices=num_lattices, lattice_rank=lattice_rank, ) @@ -265,14 +261,14 @@ def test_forward( ( [0, 1, 2, 3], [2, 1, 1, 1], - [Monotonicity.NONE] * 2 + [Monotonicity.INCREASING] * 3, + [None] * 2 + [Monotonicity.INCREASING] * 3, ), ( [0, 4, 5, 7], [1, 2, 3, 4], - [Monotonicity.NONE] + [Monotonicity.INCREASING] * 9, + [None] + [Monotonicity.INCREASING] * 9, ), - ([0], [3], [Monotonicity.NONE] * 3), + ([0], [3], [None] * 3), ([1, 2, 3], [1, 1, 1], [Monotonicity.INCREASING] * 3), ], ) @@ -283,7 +279,7 @@ def test_output_monotonicities( ): """Tests output_monotonicities function.""" rtl = RTL( - monotonicities=[Monotonicity.NONE, Monotonicity.INCREASING], + monotonicities=[None, Monotonicity.INCREASING], num_lattices=3, lattice_rank=3, ) @@ -297,7 +293,7 @@ def test_output_monotonicities( def test_constrain(): """Tests RTL constrain function.""" rtl = RTL( - monotonicities=[Monotonicity.NONE, Monotonicity.INCREASING], + monotonicities=[None, Monotonicity.INCREASING], num_lattices=3, lattice_rank=3, ) @@ -315,7 +311,7 @@ def test_constrain(): def test_assert_constraints(): """Tests RTL assert_constraints function.""" rtl = RTL( - monotonicities=[Monotonicity.NONE, Monotonicity.INCREASING], + monotonicities=[None, Monotonicity.INCREASING], num_lattices=3, lattice_rank=3, ) From 
2c0ba86c57cd5de59a5bf7c825c99c57bcd69e0c Mon Sep 17 00:00:00 2001 From: Brian Park Date: Wed, 6 Dec 2023 17:17:20 -0500 Subject: [PATCH 3/3] chore: settings.json missing line --- .vscode/settings.json | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 43c5aa1..45c2e19 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,6 +1,9 @@ { "files.autoSave": "onFocusChange", - "editor.rulers": [88], + "editor.rulers": [ + 88 + ], + "editor.formatOnSave": true, "editor.formatOnSaveMode": "file", "files.insertFinalNewline": true, "python.testing.unittestEnabled": false,
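
Below is a minimal usage sketch of the consecutive-RTL flow this series enables. It assumes the patches above are applied as-is; the feature count, lattice counts, and batch size are illustrative only, and the output shapes follow the `forward()` docstring.

```python
import torch

from pytorch_lattice import Monotonicity
from pytorch_lattice.layers import RTL

# Six illustrative input features, alternating increasing / unconstrained.
monotonicities = [Monotonicity.INCREASING, None] * 3

# First RTL: 4 lattices of rank 3 (4 * 3 = 12 slots >= 6 features).
rtl1 = RTL(monotonicities=monotonicities, num_lattices=4, lattice_rank=3)

# Second RTL consumes the first RTL's outputs; output_monotonicities() gives
# one entry per rtl1 lattice, so monotonicity is propagated through the chain.
rtl2 = RTL(
    monotonicities=rtl1.output_monotonicities(),
    num_lattices=2,
    lattice_rank=2,
    average_outputs=True,
)

x = torch.rand(8, 6)  # (batch_size, num_features)
out1 = rtl1(x)        # (8, 4): one output per rtl1 lattice
out2 = rtl2(out1)     # (8, 1): averaged because average_outputs=True
```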