Skip to content

Commit

Permalink
fix: jaxify codebase
Browse files Browse the repository at this point in the history
  • Loading branch information
SauravMaheshkar committed Sep 13, 2024
1 parent d894b3a commit 0414a02
Show file tree
Hide file tree
Showing 13 changed files with 485 additions and 465 deletions.
6 changes: 5 additions & 1 deletion .github/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,8 @@ $ uv sync

```shell
$ uv jflux
```
```

## References

* Original Implementation: [black-forest-labs/flux](https://github.com/black-forest-labs/flux)
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ models/
notebooks/

**/.DS_Store
.*
.*
157 changes: 129 additions & 28 deletions jflux/modules/autoencoder.py → jflux/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from flax import nnx
from einops import rearrange

from jflux.layers import DiagonalGaussian
from jflux.sampling import interpolate


Expand All @@ -23,13 +24,21 @@ class AutoEncoderParams:


class AttnBlock(nnx.Module):
def __init__(self, in_channels: int, rngs: nnx.Rngs) -> None:
self.in_channels = in_channels
"""
Attention Block for the Encoder and Decoder.
Args:
in_channels (int): Number of input channels.
rngs (nnx.Rngs): RNGs for the module.
"""

def __init__(self, in_channels: int, rngs: nnx.Rngs) -> None:
# Normalization Layer
self.norm = nnx.GroupNorm(
num_groups=32, num_features=in_channels, epsilon=1e-6, rngs=rngs
)

# Query, Key and Value Layers
self.query_layer = nnx.Conv(
in_features=in_channels,
out_features=in_channels,
Expand All @@ -48,6 +57,8 @@ def __init__(self, in_channels: int, rngs: nnx.Rngs) -> None:
kernel_size=(1, 1),
rngs=rngs,
)

# Output Projection Layer
self.projection = nnx.Conv(
in_features=in_channels,
out_features=in_channels,
Expand All @@ -64,6 +75,7 @@ def attention(self, input_tensor: Array) -> Array:
key = self.key_layer(input_tensor)
value = self.value_layer(input_tensor)

# TODO (ariG23498): incorporate the attention fn from jflux.math
# Reshape for JAX Attention impl
b, c, h, w = query.shape
query = rearrange(query, "b c h w -> b (h w) 1 c")
Expand All @@ -79,10 +91,18 @@ def __call__(self, x: Array) -> Array:


class ResnetBlock(nnx.Module):
"""
Residual Block for the Encoder and Decoder.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
rngs (nnx.Rngs): RNGs for the module.
"""

def __init__(self, in_channels: int, out_channels: int, rngs: nnx.Rngs) -> None:
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.out_channels = in_channels if out_channels is None else out_channels

self.norm1 = nnx.GroupNorm(
num_groups=32, num_features=in_channels, epsilon=1e-6, rngs=rngs
Expand Down Expand Up @@ -133,6 +153,17 @@ def __call__(self, input_tensor: Array) -> Array:


class Downsample(nnx.Module):
"""
Downsample Block for the Encoder.
Args:
in_channels (int): Number of input channels.
rngs (nnx.Rngs): RNGs for the module.
Returns:
Downsampled input tensor.
"""

def __init__(self, in_channels: int, rngs: nnx.Rngs) -> None:
self.conv = nnx.Conv(
in_features=in_channels,
Expand All @@ -150,6 +181,17 @@ def __call__(self, x: Array) -> Array:


class Upsample(nnx.Module):
"""
Upsample Block for the Decoder.
Args:
in_channels (int): Number of input channels.
rngs (nnx.Rngs): RNGs for the module.
Returns:
Upsampled input tensor.
"""

def __init__(self, in_channels: int, rngs: nnx.Rngs) -> None:
self.conv = nnx.Conv(
in_features=in_channels,
Expand All @@ -167,6 +209,19 @@ def __call__(self, x: Array) -> Array:


class Encoder(nnx.Module):
"""
Encoder module for the AutoEncoder.
Args:
resolution (int): Resolution of the input tensor.
in_channels (int): Number of input channels.
ch (int): Number of channels.
ch_mult (list[int]): List of channel multipliers.
num_res_blocks (int): Number of residual blocks.
z_channels (int): Number of latent channels.
rngs (nnx.Rngs): RNGs for the module.
"""

def __init__(
self,
resolution: int,
Expand All @@ -176,12 +231,13 @@ def __init__(
num_res_blocks: int,
z_channels: int,
rngs: nnx.Rngs,
):
) -> None:
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.rngs = rngs
# downsampling
self.conv_in = nnx.Conv(
in_features=in_channels,
Expand All @@ -203,13 +259,15 @@ def __init__(
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block.append(
ResnetBlock(in_channels=block_in, out_channels=block_out, rngs=rngs)
)
block_in = block_out
down = nnx.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in)
down.downsample = Downsample(in_channels=block_in, rngs=rngs)
curr_res = curr_res // 2
self.down.append(down)

Expand Down Expand Up @@ -261,6 +319,20 @@ def __call__(self, x: Array) -> Array:


class Decoder(nnx.Module):
"""
Decoder module for the AutoEncoder.
Args:
resolution (int): Resolution of the input tensor.
in_channels (int): Number of input channels.
ch (int): Number of channels.
out_ch (int): Number of output channels.
ch_mult (list[int]): List of channel multipliers.
num_res_blocks (int): Number of residual blocks.
z_channels (int): Number of latent channels.
rngs (nnx.Rngs): RNGs for the module.
"""

def __init__(
self,
ch: int,
Expand All @@ -271,7 +343,7 @@ def __init__(
resolution: int,
z_channels: int,
rngs: nnx.Rngs,
):
) -> None:
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
Expand All @@ -296,9 +368,13 @@ def __init__(

# middle
self.mid = nnx.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.block_1 = ResnetBlock(
in_channels=block_in, out_channels=block_in, rngs=rngs
)
self.mid.attn_1 = AttnBlock(in_channels=block_in, rngs=rngs)
self.mid.block_2 = ResnetBlock(
in_channels=block_in, out_channels=block_in, rngs=rngs
)

# upsampling
self.up = nnx.ModuleList()
Expand All @@ -307,13 +383,15 @@ def __init__(
attn = nnx.ModuleList()
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block.append(
ResnetBlock(in_channels=block_in, out_channels=block_out, rngs=rngs)
)
block_in = block_out
up = nnx.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in)
up.upsample = Upsample(in_channels=block_in, rngs=rngs)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order

Expand Down Expand Up @@ -355,29 +433,23 @@ def __call__(self, z: Array) -> Array:
return h


class DiagonalGaussian(nnx.Module):
def __init__(self, sample: bool = True, chunk_dim: int = 1):
self.sample = sample
self.chunk_dim = chunk_dim

def __call__(self, z: Array) -> Array:
mean, logvar = jnp.chunk(z, 2, dim=self.chunk_dim)
if self.sample:
std = jnp.exp(0.5 * logvar)
return mean + std * jnp.randn_like(mean)
else:
return mean
class AutoEncoder(nnx.Module):
"""
AutoEncoder module.
Args:
params (AutoEncoderParams): Parameters for the AutoEncoder.
"""

class AutoEncoder(nnx.Module):
def __init__(self, params: AutoEncoderParams):
def __init__(self, params: AutoEncoderParams, rngs: nnx.Rngs) -> None:
self.encoder = Encoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
rngs=rngs,
)
self.decoder = Decoder(
resolution=params.resolution,
Expand All @@ -387,20 +459,49 @@ def __init__(self, params: AutoEncoderParams):
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
rngs=rngs,
)
self.reg = DiagonalGaussian()
# FIXME: Provide a single key
self.reg = DiagonalGaussian(key=rngs) # noqa: ignore

self.scale_factor = params.scale_factor
self.shift_factor = params.shift_factor

def encode(self, x: Array) -> Array:
"""
Encodes the provided tensor.
Args:
x (Array): Input tensor.
Returns:
Array: Encoded tensor.
"""
z = self.reg(self.encoder(x))
z = self.scale_factor * (z - self.shift_factor)
return z

def decode(self, z: Array) -> Array:
"""
Decodes the provided tensor.
Args:
z (Array): Encoded tensor.
Returns:
Array: Decoded tensor.
"""
z = z / self.scale_factor + self.shift_factor
return self.decoder(z)

def __call__(self, x: Array) -> Array:
"""
Forward pass for the AutoEncoder Module.
Args:
x (Array): Input tensor.
Returns:
Array
"""
return self.decode(self.encode(x))
Loading

0 comments on commit 0414a02

Please sign in to comment.