diff --git a/docker/unittest.Dockerfile b/docker/unittest.Dockerfile index f16b18398..49e19da5d 100644 --- a/docker/unittest.Dockerfile +++ b/docker/unittest.Dockerfile @@ -23,6 +23,18 @@ RUN source python3.9-env/bin/activate && pip install --upgrade pip \ tqdm scipy numba pulp tensorstore prospector yapf coverage cmake \ pybind11 ray[default] matplotlib +# Install PyTorch dependencies +RUN git clone https://github.com/pytorch/functorch /functorch +RUN source python3.7-env/bin/activate \ + && pip install torch torchdistx --pre --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ + && pushd /functorch && python setup.py install && popd +RUN source python3.8-env/bin/activate \ + && pip install torch torchdistx --pre --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ + && pushd /functorch && python setup.py install && popd +RUN source python3.9-env/bin/activate \ + && pip install torch torchdistx --pre --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ + && pushd /functorch && python setup.py install && popd + # We determine the CUDA version at `docker build ...` phase ARG JAX_CUDA_VERSION=11.1 COPY scripts/install_cuda.sh /install_cuda.sh diff --git a/docs/install.rst b/docs/install.rst index 97e337b07..63e4e9b96 100644 --- a/docs/install.rst +++ b/docs/install.rst @@ -144,7 +144,7 @@ To enable Alpa for PyTorch, install the following dependencies: # Install nightly version of torch and torchdistx pip3 uninstall -y torch torchdistx - pip install torchdistx --pre --extra-index-url https://download.pytorch.org/whl/nightly/cpu + pip install torch torchdistx --pre --extra-index-url https://download.pytorch.org/whl/nightly/cpu # Build functorch from source git clone https://github.com/pytorch/functorch diff --git a/tests/torch/test_torch_dict_input.py b/tests/test_torch_dict_input.py similarity index 100% rename from tests/torch/test_torch_dict_input.py rename to tests/test_torch_dict_input.py diff --git a/tests/torch/test_torch_reshape.py b/tests/test_torch_reshape.py similarity index 100% rename from tests/torch/test_torch_reshape.py rename to tests/test_torch_reshape.py diff --git a/tests/torch/test_torch_simple.py b/tests/test_torch_simple.py similarity index 100% rename from tests/torch/test_torch_simple.py rename to tests/test_torch_simple.py diff --git a/tests/torch/test_torch_zhen.py b/tests/test_torch_zhen.py similarity index 100% rename from tests/torch/test_torch_zhen.py rename to tests/test_torch_zhen.py