update SAM model readme
leondgarse committed Nov 28, 2023
1 parent 455a297 commit a5b3c23
Showing 3 changed files with 183 additions and 95 deletions.
23 changes: 23 additions & 0 deletions keras_cv_attention_models/segment_anything/README.md
@@ -0,0 +1,23 @@
# ___Keras Segment Anything___
***

## Summary
- [Github facebookresearch/segment-anything](https://github.com/facebookresearch/segment-anything)
- [Github ChaoningZhang/MobileSAM](https://github.com/ChaoningZhang/MobileSAM)
- Model weights ported from [Github ChaoningZhang/MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
## Models
| Model | Params | FLOPs | Input | Download | |
| ----- | ------ | ----- | ----- | -------- | --- |
| | | | | | |
## Usage
- **Basic [Still rough, but working]**
```py
import numpy as np
from keras_cv_attention_models import test_images
from keras_cv_attention_models.segment_anything import sam
mm = sam.SAM()
image = test_images.dog_cat()
points, labels = np.array([[400, 400]]), np.array([1])
masks, iou_predictions, low_res_masks = mm(image, points, labels)
fig = mm.show(image, masks, iou_predictions, points=points, labels=labels, save_path='aa.jpg')
```
![segment_anything](https://github.com/leondgarse/keras_cv_attention_models/assets/5744524/e3013d4e-1c28-426a-bb88-66144c8413ac)
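- **Multiple point prompts [a hedged sketch, not part of this commit]** Assuming `mm(image, points, labels)` accepts one point per row as in the basic example above, and following the original SAM convention that label `1` marks a foreground point and `0` a background point:
```py
import numpy as np
from keras_cv_attention_models import test_images
from keras_cv_attention_models.segment_anything import sam

mm = sam.SAM()
image = test_images.dog_cat()
points = np.array([[400, 400], [150, 150]])  # one point per row, same coordinate convention as the basic example
labels = np.array([1, 0])  # 1 = foreground prompt, 0 = background prompt
masks, iou_predictions, low_res_masks = mm(image, points, labels)
fig = mm.show(image, masks, iou_predictions, points=points, labels=labels, save_path='bb.jpg')
```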
64 changes: 30 additions & 34 deletions keras_cv_attention_models/segment_anything/mask_decoder.py
@@ -1,28 +1,30 @@
import math
import numpy as np
from keras_cv_attention_models import backend
-from keras_cv_attention_models.backend import layers, models, functional, image_data_format, initializers
+from keras_cv_attention_models.backend import layers, models, functional, image_data_format
from keras_cv_attention_models.models import register_model
from keras_cv_attention_models.attention_layers import (
    activation_by_name,
    add_with_layer_scale_and_drop_block,
    batchnorm_with_activation,
    conv2d_no_bias,
    depthwise_conv2d_no_bias,
    inverted_residual_block,
    layer_norm,
    mlp_block,
    multi_head_self_attention,
    mhsa_with_multi_head_position,
    qkv_to_multi_head_channels_last_format,
    scaled_dot_product_attention,
    window_attention,
    ClassToken,
    MultiHeadPositionalEmbedding,
    add_pre_post_process,
)
+from keras_cv_attention_models.download_and_load import reload_model_weights

+LAYER_NORM_EPSILON = 1e-6  # Only used in up_1_ln
+PRETRAINED_DICT = {
+    "mask_decoder": {"mobile_sam_5m": "212d83fc04a1250d68db83ba9a33e2e2"},
+}


+def mlp_block_multi(inputs, hidden_dim, output_channel=-1, num_blocks=2, use_bias=True, activation="gelu", name=None):
+    output_channel = output_channel if output_channel > 0 else inputs.shape[-1]
+    nn = inputs
+    for id in range(num_blocks - 1):
+        nn = layers.Dense(hidden_dim, use_bias=use_bias, name=name and name + "dense_{}".format(id + 1))(nn)
+        nn = activation_by_name(nn, activation, name=name and name + "{}_".format(id + 1))
+    nn = layers.Dense(output_channel, use_bias=use_bias, name=name and name + "dense_{}".format(num_blocks))(nn)
+    return nn


def attention(query, key=None, value=None, num_heads=8, head_dim=0, name=""):
    key = query if key is None else key
@@ -42,7 +44,7 @@ def attention(query, key=None, value=None, num_heads=8, head_dim=0, name=""):


def two_way_attention_block(
-    query, key, query_position, key_position, num_heads=8, downsample_rate=2, skip_first_layer_pe=False, mlp_ratio=8, activation="relu", name=""
+    query, key, query_position, key_position, num_heads=8, downsample_rate=2, skip_first_layer_pe=False, mlp_ratio=8, activation="gelu", name=""
):
    if skip_first_layer_pe:
        query = attention(query, name=name + "query_")
@@ -60,7 +62,7 @@ def two_way_attention_block(
    query = layer_norm(query, axis=-1, name=name + "cross_embedding_")

    # MLP block
-    mlp_out = mlp_block(query, hidden_dim=int(query.shape[-1] * mlp_ratio), activation=activation, name=name + "mlp_")
+    mlp_out = mlp_block_multi(query, hidden_dim=int(query.shape[-1] * mlp_ratio), num_blocks=2, activation=activation, name=name + "mlp_")
    query = query + mlp_out
    query = layer_norm(query, axis=-1, name=name + "mlp_")

@@ -73,7 +75,7 @@


def two_way_transformer(
-    image_embedding, image_position, point_embedding, depth=2, num_heads=8, mlp_dim=2048, downsample_rate=2, mlp_ratio=8, activation="relu", name=""
+    image_embedding, image_position, point_embedding, depth=2, num_heads=8, mlp_dim=2048, downsample_rate=2, mlp_ratio=8, activation="gelu", name=""
):
    query, query_position, key, key_position = point_embedding, point_embedding, image_embedding, image_position

@@ -99,29 +101,20 @@
    return query, key


-def mlp_block_3(inputs, hidden_dim, output_channel=-1, use_bias=True, activation="relu", name=None):
-    output_channel = output_channel if output_channel > 0 else inputs.shape[-1]
-    nn = layers.Dense(hidden_dim, use_bias=use_bias, name=name and name + "dense_1")(inputs)
-    nn = activation_by_name(nn, activation, name=name + "1_")
-    nn = layers.Dense(hidden_dim, use_bias=use_bias, name=name and name + "dense_2")(nn)
-    nn = activation_by_name(nn, activation, name=name + "2_")
-    nn = layers.Dense(output_channel, use_bias=use_bias, name=name and name + "dense_3")(nn)
-    return nn


-def MaskDecoder(embed_dims=256, num_mask_tokens=4, activation="relu", name="mask_decoder"):
+@register_model
+def MaskDecoder(embed_dims=256, num_mask_tokens=4, activation="gelu", pretrained="mobile_sam_5m", name="mask_decoder"):
    image_embedding = layers.Input([None, None, embed_dims], batch_size=1, name="image_embedding")  # Inputs is channels_last also for PyTorch backend
    point_embedding = layers.Input([None, embed_dims], batch_size=1, name="point_embedding")
    image_position = layers.Input([None, None, embed_dims], batch_size=1, name="image_position")

    point_embedding_with_tokens = ClassToken(num_tokens=num_mask_tokens + 1, name="cls_token")(point_embedding)
-    iou_masks, encoded_image_embedding = two_way_transformer(image_embedding, image_position, point_embedding_with_tokens, name="attn_")
+    iou_masks, encoded_image_embedding = two_way_transformer(image_embedding, image_position, point_embedding_with_tokens, activation="relu", name="attn_")
    # print(f"{iou_masks.shape = }, {encoded_image_embedding.shape = }")

    # output_upscaling
    nn = encoded_image_embedding if backend.image_data_format() == "channels_last" else layers.Permute([3, 1, 2])(encoded_image_embedding)
    nn = layers.Conv2DTranspose(embed_dims // 4, kernel_size=2, strides=2, name="up_1_conv_transpose")(nn)
-    nn = layer_norm(nn, name="up_1_")
+    nn = layer_norm(nn, epsilon=LAYER_NORM_EPSILON, name="up_1_")
    nn = activation_by_name(nn, activation=activation, name="up_1_")
    nn = layers.Conv2DTranspose(embed_dims // 8, kernel_size=2, strides=2, name="up_2_conv_transpose")(nn)
    nn = activation_by_name(nn, activation=activation, name="up_2_")
@@ -131,11 +124,11 @@ def MaskDecoder(embed_dims=256, num_mask_tokens=4, activation="relu", name="mask

    iou_masks = functional.split(iou_masks, [5, -1], axis=1)[0]
    iou_token_out, masks_top, masks_left, masks_bottom, masks_right = functional.unstack(iou_masks, axis=1)
-    iou_pred = mlp_block_3(iou_token_out, hidden_dim=embed_dims, output_channel=num_mask_tokens, activation=activation, name="iou_pred_")
+    iou_pred = mlp_block_multi(iou_token_out, embed_dims, output_channel=num_mask_tokens, num_blocks=3, activation="relu", name="iou_pred_")

    hyper_in = []
-    for id, (ii, name) in enumerate(zip([masks_top, masks_left, masks_bottom, masks_right], ["top", "left", "bottom", "right"])):
-        hyper_in.append(mlp_block_3(ii, hidden_dim=embed_dims, output_channel=embed_dims // 8, activation=activation, name="masks_" + name + "_"))
+    for id, (ii, sub_name) in enumerate(zip([masks_top, masks_left, masks_bottom, masks_right], ["top", "left", "bottom", "right"])):
+        hyper_in.append(mlp_block_multi(ii, embed_dims, output_channel=embed_dims // 8, num_blocks=3, activation="relu", name="masks_" + sub_name + "_"))
    # print(f"{[ii.shape for ii in hyper_in] = }")
    hyper_in = functional.stack(hyper_in, axis=1)

@@ -144,4 +137,7 @@ def MaskDecoder(embed_dims=256, num_mask_tokens=4, activation="relu", name="mask
# print(f"{masks.shape = }")
# [batch, 4, height * width] -> [batch, 4, height, width] -> [batch, height, width, 4], outputs channels_last also for PyTorch backend
masks = layers.Permute([2, 3, 1])(functional.reshape(masks, [-1, masks.shape[1], *pre_shape]))
return models.Model([image_embedding, point_embedding, image_position], [masks, iou_pred], name=name)

model = models.Model([image_embedding, point_embedding, image_position], [masks, iou_pred], name=name)
reload_model_weights(model, PRETRAINED_DICT, "segment_anything", pretrained)
return model
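For reference, a minimal standalone sketch of the pretrained decoder wired up above; this is not part of the commit, the random arrays are stand-ins for real image-encoder and prompt-encoder outputs, and the 64x64 embedding grid is an assumption:
```py
import numpy as np
from keras_cv_attention_models.segment_anything import mask_decoder

# Builds the registered MaskDecoder and reloads the "mobile_sam_5m" weights via PRETRAINED_DICT.
dd = mask_decoder.MaskDecoder(pretrained="mobile_sam_5m")
# Shapes follow the Input layers: [1, height, width, 256] image embedding / position, [1, num_prompts, 256] points.
image_embedding = np.random.uniform(size=[1, 64, 64, 256]).astype("float32")
image_position = np.random.uniform(size=[1, 64, 64, 256]).astype("float32")
point_embedding = np.random.uniform(size=[1, 2, 256]).astype("float32")
masks, iou_pred = dd([image_embedding, point_embedding, image_position])
print(masks.shape, iou_pred.shape)  # expecting 4x-upscaled [1, 256, 256, 4] masks and [1, 4] iou predictions
```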