Commit

Update
Signed-off-by: Emanuele Ballarin <[email protected]>
emaballarin committed Jun 17, 2024
1 parent e7eaa07 commit 2137534
Showing 7 changed files with 58 additions and 27 deletions.
1 change: 1 addition & 0 deletions ebtorch/__init__.py
@@ -96,6 +96,7 @@
 del Concatenate
 del DuplexLinearNeck
 del SharedDuplexLinearNeck
+del GenerAct
 del WideResNet
 del beta_reco_bce
 del beta_reco_bce_splitout
2 changes: 2 additions & 0 deletions ebtorch/nn/__init__.py
@@ -27,6 +27,7 @@
 from .architectures import FCBlock
 from .architectures import GaussianReparameterizerSampler
 from .architectures import GaussianReparameterizerSamplerLegacy
+from .architectures import GenerAct
 from .architectures import InnerProduct
 from .architectures import lexsemble
 from .architectures import pixelwise_bce_mean
@@ -112,6 +113,7 @@
 del eval_model_on_test
 del extract_conv_filters
 del argser_f
+del fxfx2module
 del argsink
 del no_op
 del download_gdrive
24 changes: 24 additions & 0 deletions ebtorch/nn/architectures.py
@@ -36,6 +36,7 @@

 from .functional import silhouette_score
 from .penalties import beta_gaussian_kldiv
+from .utils import fxfx2module

 __all__ = [
     "pixelwise_bce_sum",
@@ -65,6 +66,7 @@
     "SharedDuplexLinearNeck",
     "GaussianReparameterizerSamplerLegacy",
     "lexsemble",
+    "GenerAct",
 ]

 # CUSTOM TYPES
@@ -954,3 +956,25 @@ def forward(
         cxc: torch.Tensor = torch.cat(xc, dim=1)
         # noinspection PyTypeChecker
         return torch.chunk(self.shared_layer(cxc), 2, dim=1)
+
+
+class GenerAct(nn.Module):
+    def __init__(
+        self,
+        act: Union[Callable[[Tensor], Tensor], nn.Module],
+        subv: Optional[float] = None,
+        maxv: Optional[float] = None,
+        minv: Optional[float] = None,
+    ):
+        super().__init__()
+        self.act: nn.Module = fxfx2module(act)
+        self.subv: Optional[float] = subv
+        self.maxv: Optional[float] = maxv
+        self.minv: Optional[float] = minv
+
+    def forward(self, x: Tensor) -> Tensor:
+        x: Tensor = self.act(x)
+        x: Tensor = x - self.subv if self.subv is not None else x
+        x: Tensor = x.clamp_max(self.maxv) if self.maxv is not None else x
+        x: Tensor = x.clamp_min(self.minv) if self.minv is not None else x
+        return x
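For reference, a minimal usage sketch of the new GenerAct wrapper (not part of the commit; the chosen activation and threshold values are illustrative assumptions). GenerAct applies the wrapped activation, optionally subtracts a constant, then clamps the result from above and below:

import torch
from ebtorch.nn import GenerAct

# Hypothetical example: wrap torch.tanh, shift its output down by 0.1,
# then clamp the result to the interval [-0.5, 0.5].
act = GenerAct(torch.tanh, subv=0.1, maxv=0.5, minv=-0.5)
y = act(torch.randn(4, 3))  # same shape as the input; values lie within [-0.5, 0.5]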
28 changes: 2 additions & 26 deletions ebtorch/nn/convstems.py
@@ -17,6 +17,8 @@
 from torch import Tensor
 from torch.nn import functional as F

+from .utils import fxfx2module
+
 # ──────────────────────────────────────────────────────────────────────────────
 __all__ = [
     "ConvStem",
@@ -46,32 +48,6 @@ def parse(x: Union[Any, Iterable[Any]]) -> Tuple[Any, ...]:
 # ──────────────────────────────────────────────────────────────────────────────


-class _FxToFxobj:  # NOSONAR
-    __slots__ = ("fx",)
-
-    def __init__(self, fx: Callable[[Tensor], Tensor]):
-        self.fx: Callable[[Tensor], Tensor] = fx
-
-    def __call__(self, x: Tensor) -> Tensor:
-        return self.fx(x)
-
-
-class _FxToModule(nn.Module):
-    def __init__(self, fx: Callable[[Tensor], Tensor]):
-        super().__init__()
-        self.fx: _FxToFxobj = _FxToFxobj(fx)
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self.fx(x)
-
-
-def fxfx2module(fx: Union[Callable[[Tensor], Tensor], nn.Module]) -> nn.Module:
-    return fx if isinstance(fx, nn.Module) else _FxToModule(fx)
-
-
-# ──────────────────────────────────────────────────────────────────────────────
-
-
 class ConvStem(nn.Module):
     """
     ConvStem for Vision Transformer (ViT) models.
1 change: 1 addition & 0 deletions ebtorch/nn/utils/__init__.py
@@ -26,6 +26,7 @@
 from .onlyutils import argsink
 from .onlyutils import download_gdrive
 from .onlyutils import emplace_kv
+from .onlyutils import fxfx2module
 from .onlyutils import no_op
 from .onlyutils import subset_state_dict
 from .palettes import petroff_2021_cmap
27 changes: 27 additions & 0 deletions ebtorch/nn/utils/onlyutils.py
@@ -22,12 +22,15 @@
 # SPDX-License-Identifier: Apache-2.0
 #
 # Imports
+from collections.abc import Callable
 from functools import partial as fpartial
 from typing import Any
 from typing import Tuple
 from typing import Union

 import requests
+from torch import nn
+from torch import Tensor

 __all__ = [
     "argser_f",
@@ -36,6 +39,7 @@
     "argsink",
     "no_op",
     "subset_state_dict",
+    "fxfx2module",
 ]


@@ -124,3 +128,26 @@ def no_op() -> None:
     A function that does nothing, by design.
     """
     pass
+
+
+def fxfx2module(fx: Union[Callable[[Tensor], Tensor], nn.Module]) -> nn.Module:
+    return fx if isinstance(fx, nn.Module) else _FxToModule(fx)
+
+
+class _FxToFxobj:  # NOSONAR
+    __slots__ = ("fx",)
+
+    def __init__(self, fx: Callable[[Tensor], Tensor]):
+        self.fx: Callable[[Tensor], Tensor] = fx
+
+    def __call__(self, x: Tensor) -> Tensor:
+        return self.fx(x)
+
+
+class _FxToModule(nn.Module):
+    def __init__(self, fx: Callable[[Tensor], Tensor]):
+        super().__init__()
+        self.fx: _FxToFxobj = _FxToFxobj(fx)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.fx(x)
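For reference, a minimal usage sketch of the relocated fxfx2module helper (not part of the commit; the names below are illustrative, and the import path assumes fxfx2module stays re-exported from ebtorch.nn.utils as in this diff). The helper returns nn.Module instances unchanged and wraps plain callables into modules, so arbitrary functions can be placed inside module containers such as nn.Sequential:

import torch
from torch import nn
from ebtorch.nn.utils import fxfx2module

relu_mod = fxfx2module(nn.ReLU())           # already a module: returned as-is
square_mod = fxfx2module(lambda t: t * t)   # plain callable: wrapped into a module

net = nn.Sequential(nn.Linear(8, 8), square_mod)
out = net(torch.randn(2, 8))  # the wrapped callable runs as a regular layer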
2 changes: 1 addition & 1 deletion setup.py
@@ -26,7 +26,7 @@ def read(fname):

 setup(
     name=PACKAGENAME,
-    version="0.25.4",
+    version="0.25.5",
     author="Emanuele Ballarin",
     author_email="[email protected]",
     url="https://github.com/emaballarin/ebtorch",
