Add MaskedVisionTransformerDecoder #1615
@guarin I would be open to working on this feature if I understood the requirements more specifically. I really like the open-source initiative at Lightly and would like to contribute more to this repo. So here are my questions:
Based on the repository, I feel like more could be done in terms of restructuring.
I can then start working on a PR based on your response. Since this is still in the ideation phase, I would start with a draft PR before moving on to a full-blown PR with reviews. Please do tell me what you think.
Hi!
Yes, the idea was to see if we can write a single `MaskedVisionTransformerDecoder` class.
We would then also move the prediction part out of the decoder.
Yes exactly, that is the concern. But TBH this was just an idea and I didn't really have time to investigate it further. Maybe there is a nice way to do it, or maybe it is better if we just keep two separate implementations.
Could you explain this further? I don't fully understand what you mean.
We have plans to refactor the package quite a bit and will take this into account as well.
It would be great if you could have a look at whether a shared decoder implementation is possible.
Hey @guarin, thanks for the swift response. I will look into the feasibility of the idea and detail my findings here. Is it possible for you to assign this issue to me?
I am sorry for not explaining this further, but I was under the impression that the team was closely following the factory-based pattern. For this, I think it is best to separate the abstract base classes into their own directory, rather than have everything in the modules subdirectory. This is simply for better readability. Additionally, one could also create decorators that register sub-classes with the base modules for an additional layer of safety, but I think this can be a future feature.
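For illustration, a minimal sketch of what such a registry decorator could look like. The names `MODULE_REGISTRY`, `register_module`, and `build_module` are hypothetical and do not exist in lightly:

```python
from typing import Callable, Dict, Type

from torch import nn

# Hypothetical registry for illustration only; lightly does not ship this.
MODULE_REGISTRY: Dict[str, Type[nn.Module]] = {}


def register_module(name: str) -> Callable[[Type[nn.Module]], Type[nn.Module]]:
    """Registers a module class under `name` so a factory can look it up later."""

    def decorator(cls: Type[nn.Module]) -> Type[nn.Module]:
        MODULE_REGISTRY[name] = cls
        return cls

    return decorator


@register_module("example_decoder")
class ExampleDecoder(nn.Module):
    """Placeholder subclass; a real decoder would live in its own module."""


def build_module(name: str, **kwargs) -> nn.Module:
    """Factory entry point: instantiate a registered module by name."""
    return MODULE_REGISTRY[name](**kwargs)
```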
Thanks for the extra info! We don't strictly follow the factory-based pattern. We have some plans to update the package structure, but they are still under discussion. Our goal is to make the package as easy as possible to use for research. This means that it should be straightforward to understand and adapt. We'll most likely try to reduce the usage of advanced design patterns to keep the code as simple as possible.
The findings from my initial analysis: it is worth considering an abstract class, because there are several conceptual and functional similarities.
But of course, it is not all the same. I-JEPA strictly uses and calls this block a predictor, not a decoder, which is what MAE calls it. The I-JEPA paper (https://arxiv.org/pdf/2301.08243) itself acknowledges the similarity of the predictor with the MAE decoder.
Another point to note is that both the I-JEPA predictor and the MAE decoder are shallow and lightweight. @guarin, you have already mentioned a rough idea of what the abstract class would look like. But based on the discussion we have had in this thread, it is clear that although we would have an abstract class for the decoder/predictor, we would still need individual classes for the I-JEPA predictor and the MAE decoder. Although not concrete, the base class for both the predictor and decoder could look like this:

```python
from abc import ABC, abstractmethod
from typing import Callable

import torch
from timm.models.vision_transformer import Block
from torch import nn


class MaskedVisionTransformerDecoder(nn.Module, ABC):
    """Shared base class for the MAE decoder and the I-JEPA predictor."""

    def __init__(
        self,
        num_patches: int,
        embed_dim: int,
        depth: int,
        num_heads: int,
        mlp_ratio: float,
        drop_path_rate: float,
        proj_drop_rate: float,
        attn_drop_rate: float,
        norm_layer: Callable[..., nn.Module],
    ):
        super().__init__()
        self.embed = nn.Linear(embed_dim, embed_dim, bias=True)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, embed_dim), requires_grad=False
        )
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=True,
                    drop_path=drop_path_rate,
                    proj_drop=proj_drop_rate,
                    attn_drop=attn_drop_rate,
                    norm_layer=norm_layer,
                )
                for _ in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Abstract forward method, to be implemented by subclasses."""

    def apply_transformer_blocks(self, x: torch.Tensor) -> torch.Tensor:
        """Runs the tokens through all transformer blocks and the final norm."""
        for blk in self.blocks:
            x = blk(x)
        return self.norm(x)


# The method-specific predictor/decoder classes
class IJEPAPredictorTIMM(MaskedVisionTransformerDecoder):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.embed(x)
        x = self.apply_transformer_blocks(x)
        # IJEPAPredictor-specific logic
        return x


class MAEDecoderTIMM(MaskedVisionTransformerDecoder):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.embed(x)
        x = x + self.pos_embed  # add positional embeddings
        x = self.apply_transformer_blocks(x)
        # MAEDecoder-specific logic for prediction
        return x
```

Pros and Cons:
We already have `MaskedVisionTransformer` classes for TIMM and torchvision that take images as input and output a token for every image patch. These classes are versatile as they can be used for many different methods (I-JEPA, DINOv2, MAE, etc.). Some of these methods also require a ViT decoder (I-JEPA and MAE), which we currently implement independently for every method (we have `IJEPAPredictor` and `MAEDecoder`). These decoders have very similar structures, and I think we could cover them in a single `MaskedVisionTransformerDecoder` class.

The class would have the following interface:
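Roughly something along these lines; the exact constructor arguments and the `idx_keep`/`idx_mask` names in `forward` are placeholders rather than a finalized signature:

```python
from typing import Optional

import torch
from torch import nn


class MaskedVisionTransformerDecoder(nn.Module):
    """Decodes patch tokens (not images) with a lightweight ViT.

    Positional embeddings, transformer blocks, and masking live inside this
    class; the `embed` projection and `prediction_head` stay outside so the
    decoder can be reused by different methods.
    """

    def __init__(
        self,
        num_patches: int,
        embed_dim: int,
        depth: int,
        num_heads: int,
    ) -> None:
        super().__init__()
        # ... positional embeddings and transformer blocks would be built here

    def forward(
        self,
        tokens: torch.Tensor,  # (batch_size, num_tokens, embed_dim)
        idx_keep: Optional[torch.Tensor] = None,  # indices of unmasked tokens
        idx_mask: Optional[torch.Tensor] = None,  # indices of masked tokens
    ) -> torch.Tensor:
        # ... returns one decoded token per patch
        raise NotImplementedError
```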
This would allow us to implement MAE like this:
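A rough sketch only; the wrapper class, the dimensions, and the `forward_decoder` helper below are illustrative assumptions, not existing lightly code:

```python
import torch
from torch import nn


class MAE(nn.Module):
    def __init__(self, backbone: nn.Module, decoder: nn.Module) -> None:
        super().__init__()
        self.backbone = backbone  # MaskedVisionTransformer encoder
        self.embed = nn.Linear(768, 512)  # project encoder tokens to decoder dim
        self.decoder = decoder  # a MaskedVisionTransformerDecoder
        self.prediction_head = nn.Linear(512, 16 * 16 * 3)  # decoder dim -> pixels per patch

    def forward_decoder(
        self, x_encoded: torch.Tensor, idx_keep: torch.Tensor, idx_mask: torch.Tensor
    ) -> torch.Tensor:
        x = self.embed(x_encoded)  # move the embed projection outside the decoder
        x = self.decoder(x, idx_keep=idx_keep, idx_mask=idx_mask)
        return self.prediction_head(x)  # reconstruct the masked patches
```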
And I-JEPA like this:
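Again just a sketch under the same assumptions, with the shared decoder acting as the I-JEPA predictor:

```python
import torch
from torch import nn


class IJEPA(nn.Module):
    def __init__(self, context_encoder: nn.Module, decoder: nn.Module) -> None:
        super().__init__()
        self.context_encoder = context_encoder  # MaskedVisionTransformer
        self.embed = nn.Linear(768, 384)  # project encoder tokens to predictor dim
        self.decoder = decoder  # a MaskedVisionTransformerDecoder used as predictor
        self.prediction_head = nn.Linear(384, 768)  # predictor dim -> target embedding dim

    def forward_predictor(
        self, x_context: torch.Tensor, idx_keep: torch.Tensor, idx_mask: torch.Tensor
    ) -> torch.Tensor:
        x = self.embed(x_context)
        x = self.decoder(x, idx_keep=idx_keep, idx_mask=idx_mask)
        return self.prediction_head(x)  # predict the target patch embeddings
```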
By sharing the `MaskedVisionTransformerDecoder` class we can deduplicate a lot of the code around positional embeddings, transformer blocks, and masking. And by moving the `embed` and `prediction_head` layers out of the decoder, the decoder class becomes more modular and easier to reuse. We cannot re-use `MaskedVisionTransformer` directly because it expects images instead of tokens as input.

I am not yet 100% sure that this is possible. Especially for I-JEPA there is some funky masking logic that might be hard to generalize. Could be worth a try though.