Skip to content

Commit

Permalink
update comment for navit 3d
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 8, 2024
1 parent 141239c commit 6693d47
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
2 changes: 2 additions & 0 deletions vit_pytorch/na_vit_nested_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,3 +323,5 @@ def forward(
]

assert v(images).shape == (5, 1000)

v(images).sum().backward()
4 changes: 3 additions & 1 deletion vit_pytorch/na_vit_nested_tensor_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def forward(

if __name__ == '__main__':

# works for torch 2.4
# works for torch 2.5

v = NaViT(
image_size = 256,
Expand All @@ -362,3 +362,5 @@ def forward(
]

assert v(volumes).shape == (5, 1000)

v(volumes).sum().backward()

0 comments on commit 6693d47

Please sign in to comment.