diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml index 7338714..fe54d29 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/python-test.yml @@ -28,7 +28,7 @@ jobs: python -m pip install --upgrade pip python -m pip install pytest python -m pip install wheel - python -m pip install torch==2.5.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu + python -m pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Test with pytest run: | diff --git a/vit_pytorch/na_vit_nested_tensor.py b/vit_pytorch/na_vit_nested_tensor.py index 6084e6c..04882c8 100644 --- a/vit_pytorch/na_vit_nested_tensor.py +++ b/vit_pytorch/na_vit_nested_tensor.py @@ -6,9 +6,6 @@ import torch import packaging.version as pkg_version -if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'): - print('nested tensor NaViT was tested on pytorch 2.5') - from torch import nn, Tensor import torch.nn.functional as F from torch.nn import Module, ModuleList @@ -152,6 +149,11 @@ def __init__( token_dropout_prob: float | None = None ): super().__init__() + + if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'): + print('nested tensor NaViT was tested on pytorch 2.5') + + image_height, image_width = pair(image_size) # what percent of tokens to dropout diff --git a/vit_pytorch/na_vit_nested_tensor_3d.py b/vit_pytorch/na_vit_nested_tensor_3d.py index 7722c0f..1f6ab59 100644 --- a/vit_pytorch/na_vit_nested_tensor_3d.py +++ b/vit_pytorch/na_vit_nested_tensor_3d.py @@ -6,9 +6,6 @@ import torch import packaging.version as pkg_version -if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'): - print('nested tensor NaViT was tested on pytorch 2.5') - from torch import nn, Tensor import torch.nn.functional as F from torch.nn import Module, ModuleList @@ -169,6 +166,9 @@ def __init__( super().__init__() image_height, image_width = pair(image_size) + if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'): + print('nested tensor NaViT was tested on pytorch 2.5') + # what percent of tokens to dropout # if int or float given, then assume constant dropout prob # otherwise accept a callback that in turn calculates dropout prob from height and width