Implementation of RTL layer #4

Merged: 5 commits, Dec 6, 2023
6 changes: 1 addition & 5 deletions .vscode/settings.json
@@ -1,10 +1,6 @@
{
"files.autoSave": "onFocusChange",
"editor.rulers": [
88
],
"editor.formatOnSaveMode": "file",
"editor.formatOnSave": true,
"editor.rulers": [88],
"editor.formatOnSaveMode": "file",
"files.insertFinalNewline": true,
"python.testing.unittestEnabled": false,
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -72,7 +72,7 @@ skip-magic-trailing-comma = false
line-ending = "auto"

[tool.mypy]
exclude = ["venv"]
exclude = ["examples", "venv"]

[build-system]
requires = ["poetry-core"]
79 changes: 36 additions & 43 deletions pytorch_lattice/layers/rtl.py
@@ -1,11 +1,12 @@
"""Random Tiny Lattice module for use in calibrated modeling.
"""A PyTorch module implementing a calibrated modeling layer for Random Tiny Lattices.

PyTorch implementation of a RTL layer.
This layer takes several inputs which would otherwise be slow to run on a single lattice
and runs random subsets on an assortment of Random Tiny Lattices as an optimization.
This module implements an ensemble of tiny lattices that each operate on a subset of the
inputs. It utilizes the multi-unit functionality of the Lattice module to better
optimize speed performance by putting feature subsets that have the same constraint
structure into the same Lattice module as multiple units.
"""
import logging
from typing import List, Optional
from typing import Optional, Union

import numpy as np
import torch
@@ -15,45 +16,38 @@


class RTL(torch.nn.Module):
"""An RTL Module.
"""A module that efficiently implements Random Tiny Lattices.

Layer takes a number of features that would otherwise be too many to assign to
a single lattice, and instead assigns small random subsets of the features to an
ensemble of smaller lattices. The features are shuffled and uniformly repeated
if there are more slots in the RTL than features.
This module creates an ensemble of lattices where each lattice in the ensemble takes
as input a subset of the input features. For further efficiency, input subsets with
the same constraint structure all go through the same lattice as multiple units in
parallel. When creating the ensemble structure, features are shuffled and uniformly
repeated if there are more available slots in the ensemble structure than there are
features.

Attributes:
- All `__init__` arguments.
- _lattice_layers: `dict` of form `{monotonic_count: (lattice, groups)}` which
keeps track of the RTL structure. Features are indexed then randomly grouped
together to be assigned to a lattice - groups with the same number of
monotonic features can be put into the same lattice for further optimization,
and are thus stored together in the dict according to `monotonic_count`.

Example:
`python
inputs=torch.tensor(...) # shape: (batch_size, D)
monotonicities = List[Monotonicity...] # len: D
rtl1=RTL(
```python
inputs=torch.tensor(...) # shape: (batch_size, D)
monotonicities = List[Monotonicity...] # len: D
random_tiny_lattices = RTL(
monotonicities,
num_lattices = 5
lattice_rank = 3, # num_lattices * lattice_rank must be greater than D
num_lattices=5,
lattice_rank=3, # num_lattices * lattice_rank must be greater than D
)
output1 = rtl1(inputs)
output1 = random_tiny_lattices(inputs)

# If you want to pass through consecutive RTLs

rtl2 = RTL(
monotonicities=rtl1.output_monotonicities() # len: rtl1.num_lattices
...
)
output2 = RTL2(output1)
`
# You can stack RTL modules based on the previous RTL's output monotonicities.
rtl2 = RTL(random_tiny_lattices.output_monotonicities(), ...)
outputs2 = rtl2(output1)
```
"""

def __init__(
self,
monotonicities: List[Monotonicity],
monotonicities: list[Monotonicity],
num_lattices: int,
lattice_rank: int,
lattice_size: int = 2,
@@ -68,7 +62,7 @@ def __init__(
"""Initializes an instance of 'RTL'.

Args:
monotonicities: `List` of `Monotonicity.INCREASING` or `None`
monotonicities: List of `Monotonicity.INCREASING` or `None`
indicating monotonicities of input features, ordered respectively.
num_lattices: number of lattices in RTL structure.
lattice_rank: number of inputs for each lattice in RTL structure.
@@ -81,7 +75,7 @@ def __init__(
random_seed: seed used for shuffling.

Raises:
ValueError: if size of RTL, determined by `num_lattices * lattice_rank`, is
ValueError: If size of RTL, determined by `num_lattices * lattice_rank`, is
too small to support the number of input features.
"""
super().__init__()
@@ -108,7 +102,6 @@ def __init__(
)
np.random.seed(self.random_seed)
np.random.shuffle(rtl_indices)
# split_rtl_indices = np.split(rtl_indices, num_lattices)
split_rtl_indices = [list(arr) for arr in np.split(rtl_indices, num_lattices)]
swapped_rtl_indices = self._ensure_unique_sublattices(split_rtl_indices)
monotonicity_groupings = {}
@@ -149,14 +142,14 @@ def __init__(
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward method computed by using forward methods of each lattice in RTL.
"""Forward method computed by using forward methods of each lattice in ensemble.

Args:
x: input tensor of feature values with shape `(batch_size, num_features)`.

Returns:
torch.Tensor containing the outputs of each lattice within RTL structure. If
`average_outputs == True`, then all outputs are averaged into a tensor of
`torch.Tensor` containing the outputs of each lattice within RTL structure.
If `average_outputs == True`, then all outputs are averaged into a tensor of
shape `(batch_size, 1)`. If `average_outputs == False`, shape of tensor is
`(batch_size, num_lattices)`.
"""
@@ -175,7 +168,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return result

@torch.no_grad()
def output_monotonicities(self) -> List[Monotonicity]:
def output_monotonicities(self) -> list[Union[Monotonicity, None]]:
"""Gives the monotonicities of the outputs of RTL.

Returns:
@@ -194,13 +187,13 @@ def output_monotonicities(self) -> List[Monotonicity]:
return monotonicities

@torch.no_grad()
def constrain(self) -> None:
def apply_constraints(self) -> None:
"""Enforces constraints for each lattice in RTL."""
for lattice, _ in self._lattice_layers.values():
lattice.constrain()
lattice.apply_constraints()

@torch.no_grad()
def assert_constraints(self, eps=1e-6) -> List[List[str]]:
def assert_constraints(self, eps: float = 1e-6) -> list[list[str]]:
"""Asserts that each Lattice in RTL satisfies all constraints.

Args:
@@ -216,9 +209,9 @@ def assert_constraints(self, eps=1e-6) -> List[List[str]]:

@staticmethod
def _ensure_unique_sublattices(
rtl_indices: List[List[int]],
rtl_indices: list[list[int]],
max_swaps: int = 10000,
) -> List[List[int]]:
) -> list[list[int]]:
"""Attempts to ensure every lattice in RTL structure contains unique features.

Args:
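
To make the grouping described in the new docstrings concrete: feature-index subsets drawn for the ensemble are bucketed by how many monotonic features they contain, and each bucket can then be served by a single multi-unit Lattice. The snippet below is a minimal sketch of that idea only; `group_by_monotonic_count` and the plain-string monotonicities are illustrative stand-ins, not part of the library API.

```python
from collections import defaultdict

# Illustrative stand-ins: "INCREASING" plays the role of Monotonicity.INCREASING.
monotonicities = ["INCREASING", None, "INCREASING", None, None, "INCREASING"]

# Toy ensemble structure: num_lattices=4, lattice_rank=3 gives 12 slots for 6
# features, so each feature appears twice after shuffling and repeating.
feature_subsets = [[0, 1, 4], [2, 3, 5], [1, 2, 4], [0, 3, 5]]


def group_by_monotonic_count(subsets, monos):
    """Bucket feature subsets by the number of monotonic features they contain.

    Subsets in the same bucket share a constraint structure, so one multi-unit
    Lattice can handle all of them instead of one Lattice per subset.
    """
    groups = defaultdict(list)
    for subset in subsets:
        count = sum(monos[i] is not None for i in subset)
        # Order monotonic features first so every unit in a bucket sees the
        # same constraint layout.
        ordered = sorted(subset, key=lambda i: monos[i] is None)
        groups[count].append(ordered)
    return dict(groups)


print(group_by_monotonic_count(feature_subsets, monotonicities))
# {1: [[0, 1, 4], [2, 1, 4]], 2: [[2, 5, 3], [0, 5, 3]]}
```
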
10 changes: 5 additions & 5 deletions tests/layers/test_rtl.py
@@ -101,7 +101,7 @@ def test_initialization(
== Monotonicity.INCREASING
)
else:
assert rtl.monotonicities[single_lattice_indices[i]] == None
assert rtl.monotonicities[single_lattice_indices[i]] is None

assert len(lattice.monotonicities) == len(lattice.lattice_sizes)
assert (
@@ -290,8 +290,8 @@ def test_output_monotonicities(
assert rtl.output_monotonicities() == expected_out


def test_constrain():
"""Tests RTL constrain function."""
def test_apply_constraints():
"""Tests RTL apply_constraints function."""
rtl = RTL(
monotonicities=[None, Monotonicity.INCREASING],
num_lattices=3,
@@ -300,10 +300,10 @@ def test_constrain():
mock_constrains = []
for lattice, _ in rtl._lattice_layers.values():
mock_constrain = Mock()
lattice.constrain = mock_constrain
lattice.apply_constraints = mock_constrain
mock_constrains.append(mock_constrain)

rtl.constrain()
rtl.apply_constraints()
for mock_constrain in mock_constrains:
mock_constrain.assert_called_once()

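
For context on how the renamed `apply_constraints` and the stacking pattern from the docstring fit together in practice, here is a rough usage sketch. Import paths, the enum location, and the default `average_outputs=False` behavior are assumptions based on this diff, not verified against the released package.

```python
import torch

# Assumed import paths based on the repository layout shown in this PR.
from pytorch_lattice.enums import Monotonicity
from pytorch_lattice.layers import RTL

batch_size, num_features = 32, 6
monotonicities = [
    Monotonicity.INCREASING, None, Monotonicity.INCREASING, None, None, None,
]

rtl = RTL(
    monotonicities,
    num_lattices=4,
    lattice_rank=3,  # 4 * 3 = 12 slots, enough for the 6 input features
)

x = torch.rand(batch_size, num_features)  # calibrated inputs in [0, 1]
outputs = rtl(x)  # (batch_size, num_lattices), assuming average_outputs=False

# After each optimizer step, project lattice weights back onto the constraint
# set via the method renamed in this PR (formerly `constrain`).
rtl.apply_constraints()

# A second RTL can be stacked on top, using the first layer's output
# monotonicities so the constraint structure is preserved.
rtl2 = RTL(rtl.output_monotonicities(), num_lattices=2, lattice_rank=2)
outputs2 = rtl2(outputs)
```
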