Skip to content

Commit 47a089a

Browse files
authored
Install torch from pip (#1058)
* Install torch packages from pip. Conda takes too long to converge. * Remove basemap Basemap has been deprecated in favor of cartopy: https://github.com/matplotlib/basemap#basemap * remove conda command for torch in gpu file * specify cudatoolkit version * Use * vars in torch version
1 parent b7d1851 commit 47a089a

File tree

3 files changed

+10
-14
lines changed

3 files changed

+10
-14
lines changed

Dockerfile

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,15 @@ ENV PROJ_LIB=/opt/conda/share/proj
4040
# Using the same global consistent ordered list of channels
4141
RUN conda config --add channels conda-forge && \
4242
conda config --add channels nvidia && \
43-
conda config --add channels pytorch && \
4443
conda config --add channels rapidsai && \
4544
# ^ rapidsai is the highest priority channel, default lowest, conda-forge 2nd lowest.
46-
# b/182405233 pyproj 3.x is not compatible with basemap 1.2.1
4745
# b/161473620#comment7 pin required to prevent resolver from picking pysal 1.x., pysal 2.2.x is also downloading data on import.
48-
conda install basemap cartopy imagemagick pyproj "pysal==2.1.0" && \
49-
conda install "pytorch=1.7" "torchvision=0.8" "torchaudio=0.7" "torchtext=0.8" cpuonly && \
46+
conda install cartopy=0.19 imagemagick=7.0 pyproj==3.1.0 pysal==2.1.0 && \
47+
/tmp/clean-layer.sh
48+
49+
RUN pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 torchtext==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html && \
5050
/tmp/clean-layer.sh
5151

52-
# The anaconda base image includes outdated versions of these packages. Update them to include the latest version.
5352
RUN pip install seaborn python-dateutil dask python-igraph && \
5453
pip install pyyaml joblib husl geopy ml_metrics mne pyshp && \
5554
pip install pandas && \

gpu.Dockerfile

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,12 @@ RUN apt-get install -y ocl-icd-libopencl1 clinfo libboost-all-dev && \
5555
# the remaining pip commands: https://www.anaconda.com/using-pip-in-a-conda-environment/
5656
# However, because this image is based on the CPU image, this isn't possible but better
5757
# to put them at the top of this file to minize conflicts.
58-
RUN conda remove --force -y pytorch torchvision torchaudio torchtext cpuonly && \
59-
conda install "pytorch=1.7" "torchvision=0.8" "torchaudio=0.7" "torchtext=0.8" cudatoolkit=$CUDA_VERSION && \
60-
conda install "cudf=21.06" "cuml=21.06" && \
58+
RUN conda install cudf=21.06 cuml=21.06 cudatoolkit=$CUDA_VERSION && \
59+
/tmp/clean-layer.sh
60+
61+
# Install Pytorch and torchvision with GPU support.
62+
# Note: torchtext and torchaudio do not require a separate GPU package.
63+
RUN pip install torch==1.7.1+cu$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION torchvision==0.8.2+cu$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION -f https://download.pytorch.org/whl/torch_stable.html && \
6164
/tmp/clean-layer.sh
6265

6366
# Install LightGBM with GPU

tests/test_matplotlib.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,9 @@
44
import matplotlib.pyplot as plt
55
import numpy as np
66

7-
from mpl_toolkits.basemap import Basemap
8-
97
class TestMatplotlib(unittest.TestCase):
108
def test_plot(self):
119
plt.plot(np.linspace(0,1,50), np.random.rand(50))
1210
plt.savefig("plot1.png")
1311

1412
self.assertTrue(os.path.isfile("plot1.png"))
15-
16-
def test_basemap(self):
17-
m = Basemap(width=100,height=100,projection='aeqd', lat_0=40,lon_0=-105)
18-
self.assertEqual(0, m.xmin)

0 commit comments

Comments
 (0)