Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding torch.nn.ConvTranspose2d #3

Merged
merged 10 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions tests/test_all_the_things.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,36 @@ def test_vit_b16():

# Models use different convolution backends and are too deep to compare gradients programmatically. But they line up
# to reasonable expectations.


def test_conv_transpose2d():
for in_channels in [1, 2]:
for out_channels in [1, 2]:
for kernel_size in [1, 2, (1, 2)]:
for stride in [1, 2, (1, 2)]:
for padding in [(0, 0), 1, 2, (1, 2)]:
for output_padding in [0, 1, 2, (1, 2)]:
for bias in [False, True]:
for dilation in [1, 2, (1, 2)]:
model = torch.nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
bias=bias,
dilation=dilation,
)
params = {k: random.normal(random.PRNGKey(123), v.shape) for k, v in model.named_parameters()}
model.load_state_dict({k: j2t(v) for k, v in params.items()})

input_batch = random.normal(random.PRNGKey(123), (3, in_channels, 16, 16))
try:
res_torch = model(j2t(input_batch))
except RuntimeError:
# RuntimeError: output padding must be smaller than either stride or dilation
continue

res_jax = t2j(model)(input_batch, state_dict=params)
aac(res_jax, res_torch.numpy(force=True), atol=1e-4)
238 changes: 237 additions & 1 deletion torch2jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import functools
import math
from typing import Optional
from typing import Optional, Sequence, Tuple, Union

import jax
import jax.dlpack
Expand Down Expand Up @@ -230,6 +230,242 @@ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
return res


@implements(torch.nn.functional.conv_transpose2d)
def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
# This implementation is taken from this PR https://github.com/google/jax/pull/5772
assert input.ndim == 4, "TODO: implement non-batched input"
assert groups == 1, "TODO: implement groups != 1"

ph, pw = (padding, padding) if isinstance(padding, int) else padding
res = gradient_based_conv_transpose(
lhs=coerce(input),
rhs=coerce(weight),
strides=stride,
padding=[(ph, ph), (pw, pw)],
output_padding=output_padding,
dilation=dilation,
dimension_numbers=("NCHW", "OIHW", "NCHW"),
)
if bias is not None:
res += coerce(bias)[jnp.newaxis, :, jnp.newaxis, jnp.newaxis]
return res


def _deconv_output_length(input_length, filter_size, padding, output_padding=None, stride=0, dilation=1):
"""Taken from https://github.com/google/jax/pull/5772
Determines the output length of a transposed convolution given the input length.
Function modified from Keras.
Arguments:
input_length: Integer.
filter_size: Integer.
padding: one of `"SAME"`, `"VALID"`, or a 2-integer tuple.
output_padding: Integer, amount of padding along the output dimension. Can
be set to `None` in which case the output length is inferred.
stride: Integer.
dilation: Integer.
Returns:
The output length (integer).
"""
if input_length is None:
return None

# Get the dilated kernel size
filter_size = filter_size + (filter_size - 1) * (dilation - 1)

# Infer length if output padding is None, else compute the exact length
if output_padding is None:
if padding == "VALID":
length = input_length * stride + jax.lax.max(filter_size - stride, 0)
elif padding == "SAME":
length = input_length * stride
else:
length = (input_length - 1) * stride + filter_size - padding[0] - padding[1]

else:
if padding == "SAME":
pad = filter_size // 2
total_pad = pad * 2
elif padding == "VALID":
total_pad = 0
else:
total_pad = padding[0] + padding[1]

length = (input_length - 1) * stride + filter_size - total_pad + output_padding

return length


def _compute_adjusted_padding(
input_size: int,
output_size: int,
kernel_size: int,
stride: int,
padding: Union[str, Tuple[int, int]],
dilation: int = 1,
) -> Tuple[int, int]:
"""
Taken from https://github.com/google/jax/pull/5772
Computes adjusted padding for desired ConvTranspose `output_size`.
Ported from DeepMind Haiku.
"""
kernel_size = (kernel_size - 1) * dilation + 1

if padding == "VALID":
expected_input_size = (output_size - kernel_size + stride) // stride
if input_size != expected_input_size:
raise ValueError(
f"The expected input size with the current set of input "
f"parameters is {expected_input_size} which doesn't "
f"match the actual input size {input_size}."
)
padding_before = 0
elif padding == "SAME":
expected_input_size = (output_size + stride - 1) // stride
if input_size != expected_input_size:
raise ValueError(
f"The expected input size with the current set of input "
f"parameters is {expected_input_size} which doesn't "
f"match the actual input size {input_size}."
)
padding_needed = jax.lax.max(0, (input_size - 1) * stride + kernel_size - output_size)
padding_before = padding_needed // 2
else:
padding_before = padding[0] # type: ignore[assignment]

expanded_input_size = (input_size - 1) * stride + 1
padded_out_size = output_size + kernel_size - 1
pad_before = kernel_size - 1 - padding_before
pad_after = padded_out_size - expanded_input_size - pad_before
return (pad_before, pad_after)


def gradient_based_conv_transpose(
lhs,
rhs,
strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]],
output_padding: Optional[Sequence[int]] = None,
output_shape: Optional[Sequence[int]] = None,
dilation: Optional[Sequence[int]] = None,
dimension_numbers: jax.lax.ConvGeneralDilatedDimensionNumbers = None,
transpose_kernel: bool = True,
precision=None,
):
"""
Taken from https://github.com/google/jax/pull/5772
Convenience wrapper for calculating the N-d transposed convolution.
Much like `conv_transpose`, this function calculates transposed convolutions
via fractionally strided convolution rather than calculating the gradient
(transpose) of a forward convolution. However, the latter is more common
among deep learning frameworks, such as TensorFlow, PyTorch, and Keras.
This function provides the same set of APIs to help reproduce results in these frameworks.
Args:
lhs: a rank `n+2` dimensional input array.
rhs: a rank `n+2` dimensional array of kernel weights.
strides: sequence of `n` integers, amounts to strides of the corresponding forward convolution.
padding: `"SAME"`, `"VALID"`, or a sequence of `n` integer 2-tuples that controls
the before-and-after padding for each `n` spatial dimension of
the corresponding forward convolution.
output_padding: A sequence of integers specifying the amount of padding along
each spacial dimension of the output tensor, used to disambiguate the output shape of
transposed convolutions when the stride is larger than 1.
(see a detailed description at
1https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html)
The amount of output padding along a given dimension must
be lower than the stride along that same dimension.
If set to `None` (default), the output shape is inferred.
If both `output_padding` and `output_shape` are specified, they have to be mutually compatible.
output_shape: Output shape of the spatial dimensions of a transpose
convolution. Can be `None` or an iterable of `n` integers. If a `None` value is given (default),
the shape is automatically calculated.
Similar to `output_padding`, `output_shape` is also for disambiguating the output shape
when stride > 1 (see also
https://www.tensorflow.org/api_docs/python/tf/nn/conv2d_transpose)
If both `output_padding` and `output_shape` are specified, they have to be mutually compatible.
dilation: `None`, or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of `rhs`. Dilated convolution
is also known as atrous convolution.
dimension_numbers: tuple of dimension descriptors as in
lax.conv_general_dilated. Defaults to tensorflow convention.
transpose_kernel: if `True` flips spatial axes and swaps the input/output
channel axes of the kernel. This makes the output of this function identical
to the gradient-derived functions like keras.layers.Conv2DTranspose and
torch.nn.ConvTranspose2d applied to the same kernel.
Although for typical use in neural nets this is unnecessary
and makes input/output channel specification confusing, you need to set this to `True`
in order to match the behavior in many deep learning frameworks, such as TensorFlow, Keras, and PyTorch.
precision: Optional. Either ``None``, which means the default precision for
the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
``lax.Precision`` enums indicating precision of ``lhs``` and ``rhs``.
Returns:
Transposed N-d convolution.
"""
assert len(lhs.shape) == len(rhs.shape) and len(lhs.shape) >= 2
ndims = len(lhs.shape)
one = (1,) * (ndims - 2)
# Set dimensional layout defaults if not specified.
if dimension_numbers is None:
if ndims == 2:
dimension_numbers = ("NC", "IO", "NC")
elif ndims == 3:
dimension_numbers = ("NHC", "HIO", "NHC")
elif ndims == 4:
dimension_numbers = ("NHWC", "HWIO", "NHWC")
elif ndims == 5:
dimension_numbers = ("NHWDC", "HWDIO", "NHWDC")
else:
raise ValueError("No 4+ dimensional dimension_number defaults.")
dn = jax.lax.conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
k_shape = jnp.take(jnp.array(rhs.shape), jnp.array(dn.rhs_spec))
k_sdims = k_shape[2:] # type: ignore[index]
i_shape = jnp.take(jnp.array(lhs.shape), jnp.array(dn.lhs_spec))
i_sdims = i_shape[2:] # type: ignore[index]

# Calculate correct output shape given padding and strides.
if dilation is None:
dilation = (1,) * (rhs.ndim - 2)

if output_padding is None:
output_padding = [None] * (rhs.ndim - 2) # type: ignore[list-item]

if isinstance(padding, str):
if padding in {"SAME", "VALID"}:
padding = [padding] * (rhs.ndim - 2) # type: ignore[list-item]
else:
raise ValueError(f"`padding` must be 'VALID' or 'SAME'. Passed: {padding}.")

inferred_output_shape = tuple(
map(_deconv_output_length, i_sdims, k_sdims, padding, output_padding, strides, dilation)
)
if output_shape is None:
output_shape = inferred_output_shape # type: ignore[assignment]
else:
if not output_shape == inferred_output_shape:
raise ValueError(
f"`output_padding` and `output_shape` are not compatible."
f"Inferred output shape from `output_padding`: {inferred_output_shape}, "
f"but got `output_shape` {output_shape}"
)

pads = tuple(map(_compute_adjusted_padding, i_sdims, output_shape, k_sdims, strides, padding, dilation))

if transpose_kernel:
# flip spatial dims and swap input / output channel axes
rhs = _flip_axes(rhs, dn.rhs_spec[2:])
rhs = jnp.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1])
return jax.lax.conv_general_dilated(lhs, rhs, one, pads, strides, dilation, dn, precision=precision)


def _flip_axes(x, axes):
"""
Taken from https://github.com/google/jax/pull/5772
Flip ndarray 'x' along each axis specified in axes tuple."""
for axis in axes:
x = jnp.flip(x, axis)
return x


@implements(torch.nn.functional.dropout)
def dropout(input, p=0.5, training=True, inplace=False):
assert not training, "TODO: implement dropout=True"
Expand Down