From e1e6ebe67a16101c31dfc0a49d14d2f8859bf47e Mon Sep 17 00:00:00 2001
From: thomas
Date: Tue, 10 Oct 2023 17:35:00 +0200
Subject: [PATCH 1/2] use the lion optimizer in MNIST example

---
 examples/mnist.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/examples/mnist.py b/examples/mnist.py
index 1f8fa2a..5e50b88 100644
--- a/examples/mnist.py
+++ b/examples/mnist.py
@@ -1,3 +1,5 @@
+# You need to have tensorflow and tensorflow-datasets installed to run this example
+
 import tensorflow_datasets as tfds
 import optax
 import time
@@ -10,8 +12,7 @@
 from jvt import ViT
 
 
-LEARNING_RATE = 2e-3
-MOMENTUM = 0.9
+LEARNING_RATE = 2e-4
 MAX_ITER = 8
 BATCH_SIZE = 256
 CKPT_DIR = 'checkpoints'
@@ -36,9 +37,9 @@ def accuracy(parameters, infer_fn) -> float:
     return jnp.mean(jnp.argmax(infer_fn(parameters, images), -1) == labels)
 
 
-def create_train_state(rng: jax.random.KeyArray, f: nn.Module):
+def create_train_state(rng: jax.Array, f: nn.Module):
     parameters = jax.jit(f.init)(rng, jnp.ones((1, 28, 28, 1)))
-    optimizer = optax.sgd(LEARNING_RATE, momentum=MOMENTUM)
+    optimizer = optax.lion(LEARNING_RATE)
     return train_state.TrainState.create(
         apply_fn=jax.jit(f.apply),
         params=parameters,
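Note (not part of the patch): for context, this is roughly how the new optimizer line fits into a Flax train state. It is a minimal sketch that assumes a recent optax release providing optax.lion; the small nn.Dense model is only a stand-in for the ViT used in examples/mnist.py. The lowered learning rate (2e-3 -> 2e-4) follows the common recommendation to use a smaller step size with Lion.

    # Minimal sketch: wiring optax.lion into a TrainState (stand-in model, not the jvt ViT).
    import jax
    import jax.numpy as jnp
    import optax
    from flax import linen as nn
    from flax.training import train_state

    LEARNING_RATE = 2e-4

    model = nn.Dense(10)
    variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 28 * 28)))
    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=optax.lion(LEARNING_RATE),  # unlike optax.sgd, lion takes no momentum argument
    )

    # One dummy gradient step to check that the optimizer state and update path work.
    grads = jax.tree_util.tree_map(jnp.ones_like, state.params)
    state = state.apply_gradients(grads=grads)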
From e92cd107a4843ce33b92be331f8b619a3a051a66 Mon Sep 17 00:00:00 2001
From: thomas
Date: Tue, 10 Oct 2023 17:39:45 +0200
Subject: [PATCH 2/2] update to work with the latest version of jax

---
 jvt/convpass.py | 13 +++++++------
 jvt/deit.py     | 17 +++++++++--------
 jvt/levit.py    | 15 ++++++++-------
 jvt/mae.py      | 15 ++++++++-------
 jvt/vit.py      | 11 ++++++-----
 5 files changed, 38 insertions(+), 33 deletions(-)

diff --git a/jvt/convpass.py b/jvt/convpass.py
index 24f69d5..0d921fd 100644
--- a/jvt/convpass.py
+++ b/jvt/convpass.py
@@ -1,18 +1,19 @@
 from flax import linen as nn
 from jax import numpy as jnp
+from jax import Array
 
 class ResidualPreNorm(nn.Module):
     func: nn.Module
 
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         return self.func(nn.LayerNorm()(inputs)) + inputs
 
 class FeedForward(nn.Module):
     dim: int
 
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         out = nn.Dense(self.dim)(inputs)
         out = nn.gelu(out)
         out = nn.Dense(inputs.shape[-1])(out)
@@ -27,7 +28,7 @@ def setup(self) -> None:
         self.hidden_conv = nn.Conv(self.hidden_dim, (3, 3), padding='same')
 
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         B, L, E = inputs.shape
         H, P = self.hidden_dim, int((L - 1) ** 0.5)
         out = nn.Dense(features=H)(inputs)
@@ -42,10 +43,10 @@ class MHDPAttention(nn.Module):
     num_heads: int
 
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         B, L, E = inputs.shape
         out = nn.Dense(self.num_heads * E * 3, use_bias=False)(inputs)
-        q, k, v = (x.reshape(B, L, self.num_heads, E) for x in out.split(3, -1))
+        q, k, v = (x.reshape(B, L, self.num_heads, E) for x in jnp.split(out, 3, -1))
         attn = nn.softmax(jnp.einsum('bqhd,bkhd->bhqk', q, k, optimize=True) * (1 / jnp.sqrt(E)))
         out = jnp.einsum('bhwd,bdhv->bwhv', attn, v, optimize=True)
         out = nn.Dense(E, use_bias=False)(out.reshape(B, L, -1))
@@ -62,7 +63,7 @@ class ConvPassViT(nn.Module):
     convp_coef: float = 1.0
 
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         P = self.patch_size
         out = nn.Conv(self.width, kernel_size=(P, P), strides=P, use_bias=False)(inputs)
         out = out.reshape(out.shape[0], -1, out.shape[3]) # [B, H, W, E] -> [B, L, E]
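Note (not part of the patch): in the attention blocks above, a single Dense layer produces the fused q/k/v projection, which is then split along the last axis. The hunks use the jnp.split function (array first, then the number of sections or the split indices), which only depends on the public jnp.split API rather than an array method. A minimal sketch with made-up shapes:

    # Minimal sketch: splitting a fused q/k/v projection with jnp.split.
    import jax.numpy as jnp

    B, L, H, E = 2, 16, 4, 32                  # batch, tokens, heads, embedding dim
    out = jnp.ones((B, L, H * E * 3))          # output of the fused q/k/v Dense layer
    q, k, v = (x.reshape(B, L, H, E) for x in jnp.split(out, 3, axis=-1))
    assert q.shape == k.shape == v.shape == (B, L, H, E)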
diff --git a/jvt/deit.py b/jvt/deit.py
index b182a7e..ffdd6a4 100644
--- a/jvt/deit.py
+++ b/jvt/deit.py
@@ -4,19 +4,20 @@
 from flax import linen as nn
 from jax import numpy as jnp
+from jax import Array
 
 
 class ResidualPreNorm(nn.Module):
     func: nn.Module
 
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         return self.func(nn.LayerNorm()(inputs)) + inputs
 
 class FeedForward(nn.Module):
     dim: int
 
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         out = nn.Dense(self.dim)(inputs)
         out = nn.gelu(out)
         out = nn.Dense(inputs.shape[-1])(out)
@@ -26,10 +27,10 @@ class MHDPAttention(nn.Module):
     num_heads: int
 
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         B, L, E = inputs.shape
         out = nn.Dense(self.num_heads * E * 3, use_bias=True)(inputs)
-        q, k, v = (x.reshape(B, L, self.num_heads, E) for x in out.split(3, -1))
+        q, k, v = (x.reshape(B, L, self.num_heads, E) for x in jnp.split(out, 3, -1))
         attn = nn.softmax(jnp.einsum('bqhd,bkhd->bhqk', q, k, optimize=True) * (1 / jnp.sqrt(E)))
         out = jnp.einsum('bhwd,bdhv->bwhv', attn, v, optimize=True)
         out = nn.Dense(E, use_bias=False)(out.reshape(B, L, -1))
@@ -44,7 +45,7 @@ class DeiT(nn.Module):
     dim_ffn: int
 
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         P = self.patch_size
         out = nn.Conv(self.width, kernel_size=(P, P), strides=P, use_bias=False)(inputs)
         out = out.reshape(out.shape[0], -1, out.shape[3]) # [B, H, W, E] -> [B, L, E]
@@ -66,14 +67,14 @@ def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
 
 
 @functools.partial(jax.jit)
-def kl_divergence(y: jnp.DeviceArray, y_hat: jnp.DeviceArray) -> jnp.DeviceArray:
+def kl_divergence(y: Array, y_hat: Array) -> Array:
     return jnp.sum(jnp.exp(y) * (y - y_hat), -1)
 
 @functools.partial(jax.jit, static_argnames=('temp', 'alpha'))
-def soft_distillation_loss(y: jnp.DeviceArray, y_s: jnp.DeviceArray, y_t: jnp.DeviceArray, temp: float, alpha: float) -> jnp.DeviceArray:
+def soft_distillation_loss(y: Array, y_s: Array, y_t: Array, temp: float, alpha: float) -> Array:
     div = kl_divergence(jax.nn.softmax(y_t / temp), jax.nn.softmax(y_s / temp)) * (temp ** 2)
     return ((1 - alpha) * optax.softmax_cross_entropy(y_s, y)) + (div * alpha)
 
 @functools.partial(jax.jit)
-def hard_distillation_loss(y: jnp.DeviceArray, y_s: jnp.DeviceArray, y_t: jnp.DeviceArray) -> jnp.DeviceArray:
+def hard_distillation_loss(y: Array, y_s: Array, y_t: Array) -> Array:
     return (optax.softmax_cross_entropy(y_s, y) + optax.softmax_cross_entropy(y_s, y_t)) / 2
diff --git a/jvt/levit.py b/jvt/levit.py
index 0d2d63a..46161d6 100644
--- a/jvt/levit.py
+++ b/jvt/levit.py
@@ -1,12 +1,13 @@
 from flax import linen as nn
 from jax import numpy as jnp
+from jax import Array
 from typing import Sequence
 
 class Residual(nn.Module):
     func: nn.Module
 
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         return self.func(inputs) + inputs
 
 class FeedForward(nn.Module):
@@ -14,7 +15,7 @@ class FeedForward(nn.Module):
     training: bool = False
 
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         out = nn.Dense(inputs.shape[-1] * self.scale_factor, use_bias=False)(inputs)
         out = nn.BatchNorm(self.training, bias_init=nn.initializers.zeros)(out)
         out = nn.hard_swish(out)
@@ -28,13 +29,13 @@ class LeViT_MHDPAttention(nn.Module):
     training: bool = False
 
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         B, H, W, C = inputs.shape
         D = self.dim * self.num_heads
         attn_b = self.param('bias', nn.initializers.zeros, (1, 1, H * W, H * W)) # !
         out = nn.Dense(D * (3 + 1), use_bias=False)(inputs)
         out = nn.BatchNorm(self.training, scale_init=nn.initializers.ones, use_bias=False)(out)
-        q, k, v = (x.reshape(B, H * W, self.num_heads, -1).swapaxes(1, 2) for x in out.split((D, D * 2), -1))
+        q, k, v = (x.reshape(B, H * W, self.num_heads, -1).swapaxes(1, 2) for x in jnp.split(out, (D, D * 2), -1))
         out = nn.softmax((jnp.matmul(q, k.swapaxes(2, 3)) / jnp.sqrt(self.dim)) + attn_b)
         out = jnp.matmul(out, v).swapaxes(1, 2).reshape(B, H, W, -1)
         out = nn.Dense(C, use_bias=False)(nn.hard_swish(out))
@@ -47,14 +48,14 @@ class LeViT_SubsampleAttention(nn.Module):
     training: bool = False
 
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         B, H, W, _ = inputs.shape
         D = self.dim * self.num_heads
         out_hw = (((H - 1) // 2) + 1)
         attn_b = self.param('bias', nn.initializers.zeros, (1, 1, out_hw ** 2, H * W)) # !
         out = nn.Dense(D * (2 + 3), use_bias=False)(inputs)
         out = nn.BatchNorm(self.training, scale_init=nn.initializers.ones, use_bias=False)(out)
-        k, v = (x.reshape(B, H * W, self.num_heads, -1).swapaxes(1, 2) for x in out.split((D,), -1))
+        k, v = (x.reshape(B, H * W, self.num_heads, -1).swapaxes(1, 2) for x in jnp.split(out, (D,), -1))
         q = nn.avg_pool(inputs, window_shape=(1, 1), strides=(2, 2))
         q = nn.Dense(D, use_bias=False)(q)
         q = nn.BatchNorm(self.training, scale_init=nn.initializers.ones, use_bias=False)(q)
@@ -76,7 +77,7 @@ class LeViT(nn.Module):
     training: bool = False
 
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         B, H, W, C = inputs.shape # [B, 224, 224, 3]
         out = nn.Sequential([*(nn.Conv(C, (3, 3), strides=2) for C in self.filters)])(inputs)
         out = nn.Sequential([*(nn.Sequential([ # stage 1
diff --git a/jvt/mae.py b/jvt/mae.py
index 7b00e8d..2f4385b 100644
--- a/jvt/mae.py
+++ b/jvt/mae.py
@@ -1,5 +1,6 @@
 from jax import random as jrnd
 from jax import numpy as jnp
+from jax import Array
 from flax import linen as nn
 from typing import Sequence
 
@@ -7,14 +8,14 @@ class ResidualPreNorm(nn.Module):
     func: nn.Module
 
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         return self.func(nn.LayerNorm()(inputs)) + inputs
 
 class FeedForward(nn.Module):
     dim: int
 
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         out = nn.Dense(self.dim)(inputs)
         out = nn.gelu(out)
         out = nn.Dense(inputs.shape[-1])(out)
@@ -24,10 +25,10 @@ class MHDPAttention(nn.Module):
     num_heads: int
 
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         B, L, E = inputs.shape
         out = nn.Dense(self.num_heads * E * 3, use_bias=False)(inputs)
-        q, k, v = (x.reshape(B, L, self.num_heads, E) for x in out.split(3, -1))
+        q, k, v = (x.reshape(B, L, self.num_heads, E) for x in jnp.split(out, 3, -1))
         attn = nn.softmax(jnp.einsum('bqhd,bkhd->bhqk', q, k, optimize=True) * (1 / jnp.sqrt(E)))
         out = jnp.einsum('bhwd,bdhv->bwhv', attn, v, optimize=True)
         out = nn.Dense(E, use_bias=False)(out.reshape(B, L, -1))
@@ -43,7 +44,7 @@ class MaskedEncoder(nn.Module):
     dim_ffn: int
 
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         P = self.patch_size
         out = nn.Conv(self.width, (P, P), strides=P, use_bias=False)(inputs)
         out = out.reshape(out.shape[0], -1, out.shape[3]) # [B, H, W, E] -> [B, P, E]
@@ -70,7 +71,7 @@ class MaskedDecoder(nn.Module):
     dim_ffn: int
 
     @nn.compact
-    def __call__(self, patches: jnp.DeviceArray, mask: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, patches: Array, mask: Array) -> Array:
         B, N, E = self.embeddings_shape
         mt = self.param('mt', nn.initializers.normal(0.02), (1, 1, E)) # mask token
         pe = self.param('pe', nn.initializers.normal(0.02), (1, N, E))
@@ -101,7 +102,7 @@ class MAE(nn.Module):
     dec_ffn_dim: int
 
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         tokens, mask, embeddings_shape = MaskedEncoder(
             self.key,
             self.patch_size,
diff --git a/jvt/vit.py b/jvt/vit.py
index 4f124ea..72b9c0d 100644
--- a/jvt/vit.py
+++ b/jvt/vit.py
@@ -1,18 +1,19 @@
 from flax import linen as nn
 from jax import numpy as jnp
+from jax import Array
 
 class ResidualPreNorm(nn.Module):
     func: nn.Module
 
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         return self.func(nn.LayerNorm(1e-9)(inputs)) + inputs
 
 class FeedForward(nn.Module):
     dim: int
 
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         out = nn.Dense(self.dim)(inputs)
         out = nn.gelu(out)
         out = nn.Dense(inputs.shape[-1])(out)
@@ -24,10 +25,10 @@ class MHDPAttention(nn.Module):
     num_heads: int
     dropout_rate: float
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         B, L, E = inputs.shape
         out = nn.Dense(self.num_heads * E * 3, use_bias=False)(inputs)
-        q, k, v = (x.reshape(B, L, self.num_heads, E) for x in out.split(3, -1))
+        q, k, v = (x.reshape(B, L, self.num_heads, E) for x in jnp.split(out, 3, -1))
         attn = nn.softmax(jnp.einsum('bqhd,bkhd->bhqk', q, k, optimize=True) * (E ** -0.5))
         attn = nn.Dropout(self.dropout_rate, (), (not self.enable_dropout))(attn)
         out = jnp.einsum('bhwd,bdhv->bwhv', attn, v, optimize=True)
@@ -45,7 +46,7 @@ class ViT(nn.Module):
     dropout_rate: float = 0.1
 
     @nn.compact
-    def __call__(self, inputs: jnp.DeviceArray) -> jnp.DeviceArray:
+    def __call__(self, inputs: Array) -> Array:
         P = self.patch_size
         out = nn.Conv(self.width, kernel_size=(P, P), strides=P, use_bias=False)(inputs)
         out = out.reshape(out.shape[0], -1, out.shape[3]) # [B, H, W, E] -> [B, L, E]
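Note (not part of the patch series): the annotation change above tracks the deprecation and removal of jnp.DeviceArray in recent JAX releases; jax.Array is the public array type and can be imported directly, as these hunks do. A minimal, self-contained sketch of the resulting style (the Scale module is made up for illustration):

    # Minimal sketch: annotating a Flax module with jax.Array instead of jnp.DeviceArray.
    import jax
    import jax.numpy as jnp
    from jax import Array
    from flax import linen as nn

    class Scale(nn.Module):
        factor: float = 2.0

        @nn.compact
        def __call__(self, inputs: Array) -> Array:  # previously annotated as jnp.DeviceArray
            return inputs * self.factor

    x = jnp.ones((4, 8))
    variables = Scale().init(jax.random.PRNGKey(0), x)
    y = Scale().apply(variables, x)
    assert isinstance(y, jax.Array)  # every JAX array is an instance of jax.Array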