Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 10, 2024
1 parent 0449865 commit d47c57e
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
python -m pip install --upgrade pip
python -m pip install pytest
python -m pip install wheel
python -m pip install torch==2.5.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu
python -m pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
run: |
Expand Down
8 changes: 5 additions & 3 deletions vit_pytorch/na_vit_nested_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
import torch
import packaging.version as pkg_version

if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
print('nested tensor NaViT was tested on pytorch 2.5')

from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList
Expand Down Expand Up @@ -152,6 +149,11 @@ def __init__(
token_dropout_prob: float | None = None
):
super().__init__()

if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
print('nested tensor NaViT was tested on pytorch 2.5')


image_height, image_width = pair(image_size)

# what percent of tokens to dropout
Expand Down
6 changes: 3 additions & 3 deletions vit_pytorch/na_vit_nested_tensor_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
import torch
import packaging.version as pkg_version

if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
print('nested tensor NaViT was tested on pytorch 2.5')

from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList
Expand Down Expand Up @@ -169,6 +166,9 @@ def __init__(
super().__init__()
image_height, image_width = pair(image_size)

if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
print('nested tensor NaViT was tested on pytorch 2.5')

# what percent of tokens to dropout
# if int or float given, then assume constant dropout prob
# otherwise accept a callback that in turn calculates dropout prob from height and width
Expand Down

0 comments on commit d47c57e

Please sign in to comment.