update to work with the latest version of jax
thomas-meanwhile committed Oct 10, 2023
1 parent e1e6ebe commit e92cd10
Showing 5 changed files with 38 additions and 33 deletions.
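The change repeated across all five files is the annotation swap from `jnp.DeviceArray`, which newer JAX versions no longer provide, to the unified `jax.Array` type. A minimal sketch of the pattern, using an illustrative module that is not part of this repository:

```python
from flax import linen as nn
from jax import Array          # replaces the removed jnp.DeviceArray alias
from jax import numpy as jnp

class Identity(nn.Module):
    """Illustrative only: annotate with jax.Array instead of jnp.DeviceArray."""

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        # jax.Array is the runtime type of jax.numpy values in recent releases,
        # so it is the drop-in replacement for the old DeviceArray annotation.
        return jnp.asarray(inputs)
```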
13 changes: 7 additions & 6 deletions jvt/convpass.py
@@ -1,18 +1,19 @@
from flax import linen as nn
from jax import numpy as jnp
from jax import Array

class ResidualPreNorm(nn.Module):
func: nn.Module

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
return self.func(nn.LayerNorm()(inputs)) + inputs

class FeedForward(nn.Module):
dim: int

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
out = nn.Dense(self.dim)(inputs)
out = nn.gelu(out)
out = nn.Dense(inputs.shape[-1])(out)
@@ -27,7 +28,7 @@ def setup(self) -> None:
self.hidden_conv = nn.Conv(self.hidden_dim, (3, 3), padding='same')

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
B, L, E = inputs.shape
H, P = self.hidden_dim, int((L - 1) ** 0.5)
out = nn.Dense(features=H)(inputs)
@@ -42,10 +43,10 @@ class MHDPAttention(nn.Module):
num_heads: int

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
B, L, E = inputs.shape
out = nn.Dense(self.num_heads * E * 3, use_bias=False)(inputs)
q, k, v = (x.reshape(B, L, self.num_heads, E) for x in out.split(3, -1))
q, k, v = (x.reshape(B, L, self.num_heads, E) for x in out._split(3, -1))
attn = nn.softmax(jnp.einsum('bqhd,bkhd->bhqk', q, k, optimize=True) * (1 / jnp.sqrt(E)))
out = jnp.einsum('bhwd,bdhv->bwhv', attn, v, optimize=True)
out = nn.Dense(E, use_bias=False)(out.reshape(B, L, -1))
@@ -62,7 +63,7 @@ class ConvPassViT(nn.Module):
convp_coef: float = 1.0

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
P = self.patch_size
out = nn.Conv(self.width, kernel_size=(P, P), strides=P, use_bias=False)(inputs)
out = out.reshape(out.shape[0], -1, out.shape[3]) # [B, H, W, E] -> [B, L, E]
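The attention modules in this diff fuse the Q/K/V projection into a single `nn.Dense` and split the result along the last axis. A standalone shape check of that step, written with the public `jnp.split` function purely to keep it self-contained; the batch size, token count, width and head count are made-up values, not taken from the repository:

```python
import jax
from jax import numpy as jnp

B, L, E, H = 2, 16, 64, 4                       # batch, tokens, embed dim, heads
out = jnp.zeros((B, L, H * E * 3))              # fused QKV projection output
q, k, v = (x.reshape(B, L, H, E) for x in jnp.split(out, 3, axis=-1))
attn = jax.nn.softmax(jnp.einsum('bqhd,bkhd->bhqk', q, k) / jnp.sqrt(E))
print(attn.shape)                               # (2, 4, 16, 16) == (B, H, L, L)
```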
17 changes: 9 additions & 8 deletions jvt/deit.py
@@ -4,19 +4,20 @@

from flax import linen as nn
from jax import numpy as jnp
from jax import Array

class ResidualPreNorm(nn.Module):
func: nn.Module

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
return self.func(nn.LayerNorm()(inputs)) + inputs

class FeedForward(nn.Module):
dim: int

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
out = nn.Dense(self.dim)(inputs)
out = nn.gelu(out)
out = nn.Dense(inputs.shape[-1])(out)
@@ -26,10 +27,10 @@ class MHDPAttention(nn.Module):
num_heads: int

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
B, L, E = inputs.shape
out = nn.Dense(self.num_heads * E * 3, use_bias=True)(inputs)
q, k, v = (x.reshape(B, L, self.num_heads, E) for x in out.split(3, -1))
q, k, v = (x.reshape(B, L, self.num_heads, E) for x in out._split(3, -1))
attn = nn.softmax(jnp.einsum('bqhd,bkhd->bhqk', q, k, optimize=True) * (1 / jnp.sqrt(E)))
out = jnp.einsum('bhwd,bdhv->bwhv', attn, v, optimize=True)
out = nn.Dense(E, use_bias=False)(out.reshape(B, L, -1))
@@ -44,7 +45,7 @@ class DeiT(nn.Module):
dim_ffn: int

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
P = self.patch_size
out = nn.Conv(self.width, kernel_size=(P, P), strides=P, use_bias=False)(inputs)
out = out.reshape(out.shape[0], -1, out.shape[3]) # [B, H, W, E] -> [B, L, E]
@@ -66,14 +67,14 @@ def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:


@functools.partial(jax.jit)
def kl_divergence(y: jnp.DeviceArray, y_hat: jnp.DeviceArray) -> jnp.DeviceArray:
def kl_divergence(y: Array, y_hat: Array) -> Array:
return jnp.sum(jnp.exp(y) * (y - y_hat), -1)

@functools.partial(jax.jit, static_argnames=('temp', 'alpha'))
def soft_distillation_loss(y: jnp.DeviceArray, y_s: jnp.DeviceArray, y_t: jnp.DeviceArray, temp: float, alpha: float) -> jnp.DeviceArray:
def soft_distillation_loss(y: Array, y_s: Array, y_t: Array, temp: float, alpha: float) -> Array:
div = kl_divergence(jax.nn.softmax(y_t / temp), jax.nn.softmax(y_s / temp)) * (temp ** 2)
return ((1 - alpha) * optax.softmax_cross_entropy(y_s, y)) + (div * alpha)

@functools.partial(jax.jit)
def hard_distillation_loss(y: jnp.DeviceArray, y_s: jnp.DeviceArray, y_t: jnp.DeviceArray) -> jnp.DeviceArray:
def hard_distillation_loss(y: Array, y_s: Array, y_t: Array) -> Array:
return (optax.softmax_cross_entropy(y_s, y) + optax.softmax_cross_entropy(y_s, y_t)) / 2
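The two losses at the bottom of `deit.py` take one-hot targets plus student and teacher logits, with `temp` and `alpha` marked static for `jax.jit`. A usage sketch, assuming the functions are importable as `jvt.deit.soft_distillation_loss` and `jvt.deit.hard_distillation_loss` and using dummy shapes:

```python
import jax
from jax import numpy as jnp
from jvt.deit import hard_distillation_loss, soft_distillation_loss  # assumed import path

# Dummy targets and logits, only to exercise the API.
y   = jax.nn.one_hot(jnp.array([1, 3]), num_classes=10)   # one-hot ground truth
y_s = jax.random.normal(jax.random.PRNGKey(0), (2, 10))   # student logits
y_t = jax.random.normal(jax.random.PRNGKey(1), (2, 10))   # teacher logits

# temp and alpha are passed by keyword to match static_argnames.
soft = soft_distillation_loss(y, y_s, y_t, temp=3.0, alpha=0.1)
hard = hard_distillation_loss(y, y_s, y_t)
print(soft.shape, hard.shape)                              # both (2,): one loss per example
```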
15 changes: 8 additions & 7 deletions jvt/levit.py
@@ -1,20 +1,21 @@
from flax import linen as nn
from jax import numpy as jnp
from jax import Array
from typing import Sequence

class Residual(nn.Module):
func: nn.Module

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
return self.func(inputs) + inputs

class FeedForward(nn.Module):
scale_factor: int = 2
training: bool = False

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
out = nn.Dense(inputs.shape[-1] * self.scale_factor, use_bias=False)(inputs)
out = nn.BatchNorm(self.training, bias_init=nn.initializers.zeros)(out)
out = nn.hard_swish(out)
@@ -28,13 +29,13 @@ class LeViT_MHDPAttention(nn.Module):
training: bool = False

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
B, H, W, C = inputs.shape
D = self.dim * self.num_heads
attn_b = self.param('bias', nn.initializers.zeros, (1, 1, H * W, H * W)) # !
out = nn.Dense(D * (3 + 1), use_bias=False)(inputs)
out = nn.BatchNorm(self.training, scale_init=nn.initializers.ones, use_bias=False)(out)
q, k, v = (x.reshape(B, H * W, self.num_heads, -1).swapaxes(1, 2) for x in out.split((D, D * 2), -1))
q, k, v = (x.reshape(B, H * W, self.num_heads, -1).swapaxes(1, 2) for x in out._split((D, D * 2), -1))
out = nn.softmax((jnp.matmul(q, k.swapaxes(2, 3)) / jnp.sqrt(self.dim)) + attn_b)
out = jnp.matmul(out, v).swapaxes(1, 2).reshape(B, H, W, -1)
out = nn.Dense(C, use_bias=False)(nn.hard_swish(out))
@@ -47,14 +48,14 @@ class LeViT_SubsampleAttention(nn.Module):
training: bool = False

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
B, H, W, _ = inputs.shape
D = self.dim * self.num_heads
out_hw = (((H - 1) // 2) + 1)
attn_b = self.param('bias', nn.initializers.zeros, (1, 1, out_hw ** 2, H * W)) # !
out = nn.Dense(D * (2 + 3), use_bias=False)(inputs)
out = nn.BatchNorm(self.training, scale_init=nn.initializers.ones, use_bias=False)(out)
k, v = (x.reshape(B, H * W, self.num_heads, -1).swapaxes(1, 2) for x in out.split((D,), -1))
k, v = (x.reshape(B, H * W, self.num_heads, -1).swapaxes(1, 2) for x in out._split((D,), -1))
q = nn.avg_pool(inputs, window_shape=(1, 1), strides=(2, 2))
q = nn.Dense(D, use_bias=False)(q)
q = nn.BatchNorm(self.training, scale_init=nn.initializers.ones, use_bias=False)(q)
@@ -76,7 +77,7 @@ class LeViT(nn.Module):
training: bool = False

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
B, H, W, C = inputs.shape # [B, 224, 224, 3]
out = nn.Sequential([*(nn.Conv(C, (3, 3), strides=2) for C in self.filters)])(inputs)
out = nn.Sequential([*(nn.Sequential([ # stage 1
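In `LeViT_SubsampleAttention`, queries are subsampled with a stride-2, window-(1, 1) average pool, which is why the output side length is `((H - 1) // 2) + 1`. A quick shape check with an assumed input resolution:

```python
from flax import linen as nn
from jax import numpy as jnp

H = 14                                           # assumed spatial resolution
x = jnp.zeros((1, H, H, 32))                     # dummy feature map
q = nn.avg_pool(x, window_shape=(1, 1), strides=(2, 2))
print(q.shape, ((H - 1) // 2) + 1)               # (1, 7, 7, 32) and 7
```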
15 changes: 8 additions & 7 deletions jvt/mae.py
@@ -1,20 +1,21 @@
from jax import random as jrnd
from jax import numpy as jnp
from jax import Array
from flax import linen as nn
from typing import Sequence

class ResidualPreNorm(nn.Module):
func: nn.Module

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
return self.func(nn.LayerNorm()(inputs)) + inputs

class FeedForward(nn.Module):
dim: int

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
out = nn.Dense(self.dim)(inputs)
out = nn.gelu(out)
out = nn.Dense(inputs.shape[-1])(out)
@@ -24,10 +25,10 @@ class MHDPAttention(nn.Module):
num_heads: int

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
B, L, E = inputs.shape
out = nn.Dense(self.num_heads * E * 3, use_bias=False)(inputs)
q, k, v = (x.reshape(B, L, self.num_heads, E) for x in out.split(3, -1))
q, k, v = (x.reshape(B, L, self.num_heads, E) for x in out._split(3, -1))
attn = nn.softmax(jnp.einsum('bqhd,bkhd->bhqk', q, k, optimize=True) * (1 / jnp.sqrt(E)))
out = jnp.einsum('bhwd,bdhv->bwhv', attn, v, optimize=True)
out = nn.Dense(E, use_bias=False)(out.reshape(B, L, -1))
@@ -43,7 +44,7 @@ class MaskedEncoder(nn.Module):
dim_ffn: int

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
P = self.patch_size
out = nn.Conv(self.width, (P, P), strides=P, use_bias=False)(inputs)
out = out.reshape(out.shape[0], -1, out.shape[3]) # [B, H, W, E] -> [B, P, E]
@@ -70,7 +71,7 @@ class MaskedDecoder(nn.Module):
dim_ffn: int

@nn.compact
def __call__(self, patches: jnp.DeviceArray, mask: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, patches: Array, mask: Array) -> Array:
B, N, E = self.embeddings_shape
mt = self.param('mt', nn.initializers.normal(0.02), (1, 1, E)) # mask token
pe = self.param('pe', nn.initializers.normal(0.02), (1, N, E))
@@ -101,7 +102,7 @@ class MAE(nn.Module):
dec_ffn_dim: int

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
tokens, mask, embeddings_shape = MaskedEncoder(
self.key,
self.patch_size,
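`MaskedDecoder` declares a learned mask token `mt` of shape (1, 1, E) and positional embeddings `pe` of shape (1, N, E); the body that consumes them is truncated in this diff. Purely as a generic illustration of how parameters with those shapes broadcast against a batch of tokens, not a reconstruction of the repository's elided code:

```python
from jax import numpy as jnp

B, N, E = 2, 16, 8                               # assumed shapes
mt = jnp.zeros((1, 1, E))                        # stand-in for the learned mask token
pe = jnp.zeros((1, N, E))                        # stand-in for positional embeddings
patches = jnp.ones((B, N, E))                    # stand-in for decoder inputs
mask = jnp.broadcast_to(jnp.arange(N) % 2 == 0, (B, N))   # True marks a masked position

# Masked slots take the mask token, visible slots keep their patch embedding,
# then positional embeddings are added to every slot.
tokens = jnp.where(mask[..., None], mt, patches) + pe
print(tokens.shape)                              # (2, 16, 8)
```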
11 changes: 6 additions & 5 deletions jvt/vit.py
@@ -1,18 +1,19 @@
from flax import linen as nn
from jax import numpy as jnp
from jax import Array

class ResidualPreNorm(nn.Module):
func: nn.Module

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
return self.func(nn.LayerNorm(1e-9)(inputs)) + inputs

class FeedForward(nn.Module):
dim: int

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
out = nn.Dense(self.dim)(inputs)
out = nn.gelu(out)
out = nn.Dense(inputs.shape[-1])(out)
@@ -24,10 +25,10 @@ class MHDPAttention(nn.Module):
dropout_rate: float

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
B, L, E = inputs.shape
out = nn.Dense(self.num_heads * E * 3, use_bias=False)(inputs)
q, k, v = (x.reshape(B, L, self.num_heads, E) for x in out.split(3, -1))
q, k, v = (x.reshape(B, L, self.num_heads, E) for x in out._split(3, -1))
attn = nn.softmax(jnp.einsum('bqhd,bkhd->bhqk', q, k, optimize=True) * (E ** -0.5))
attn = nn.Dropout(self.dropout_rate, (), (not self.enable_dropout))(attn)
out = jnp.einsum('bhwd,bdhv->bwhv', attn, v, optimize=True)
@@ -45,7 +46,7 @@ class ViT(nn.Module):
dropout_rate: float = 0.1

@nn.compact
def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
def __call__(self, inputs: Array) -> Array:
P = self.patch_size
out = nn.Conv(self.width, kernel_size=(P, P), strides=P, use_bias=False)(inputs)
out = out.reshape(out.shape[0], -1, out.shape[3]) # [B, H, W, E] -> [B, L, E]
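Every backbone in this commit patchifies images the same way: a strided `nn.Conv` whose kernel and stride equal the patch size, followed by a reshape from `[B, H, W, E]` to `[B, L, E]`. A self-contained shape check of that step; the image size, patch size and width below are assumptions, not values from the repository:

```python
import jax
from flax import linen as nn
from jax import numpy as jnp

class PatchEmbed(nn.Module):
    width: int = 192           # assumed embedding width
    patch_size: int = 16       # assumed patch size

    @nn.compact
    def __call__(self, inputs):
        P = self.patch_size
        out = nn.Conv(self.width, kernel_size=(P, P), strides=P, use_bias=False)(inputs)
        return out.reshape(out.shape[0], -1, out.shape[3])  # [B, H, W, E] -> [B, L, E]

x = jnp.zeros((1, 224, 224, 3))
variables = PatchEmbed().init(jax.random.PRNGKey(0), x)
y = PatchEmbed().apply(variables, x)
print(y.shape)                 # (1, 196, 192): 14 x 14 patches, width 192
```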
