Skip to content

Commit 5e04fe8

Browse files
committed
Preparing for PyTorch v1.0.0
1 parent 2ba2eaf commit 5e04fe8

File tree

8 files changed

+66
-38
lines changed

8 files changed

+66
-38
lines changed

.travis.yml

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
dist: trusty
2-
sudo: false
1+
dist: xenial
2+
sudo: true
33
language: python
44

55
os:
@@ -14,9 +14,10 @@ addons:
1414
- g++-4.9
1515

1616
python:
17-
- '2.7'
18-
- '3.5'
19-
- '3.6'
17+
- 2.7
18+
- 3.5
19+
- 3.6
20+
- 3.7
2021

2122
env:
2223
- CC=gcc-4.9 CXX=g++-4.9
@@ -25,7 +26,7 @@ before_install:
2526
# Upgrade PIP to latest version, in order to support --progress-bar
2627
- pip install -U pip
2728
# Manually install torch, and pybind11 since they are required in the setup
28-
- pip install torch==0.4.1 --progress-bar off
29+
- pip install torch==1.0.0 --progress-bar off
2930
- pip install pybind11 --progress-bar off
3031

3132
install:

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ loss4_2 = ctc(x, y, xs, ys)
6666

6767
- C++11 compiler (tested with GCC 4.9).
6868
- Python: 2.7, 3.5, 3.6, 3.7 (tested with versions 2.7, 3.5 and 3.6).
69-
- [PyTorch](http://pytorch.org/) >= 0.4.1 (tested with version 0.4.1).
69+
- [PyTorch](http://pytorch.org/) >= 1.0.0 (tested with version 1.0.0).
7070
- For GPU support: [CUDA Toolkit](https://developer.nvidia.com/cuda-zone).
7171

7272
## Installation

setup.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def get_cuda_compile_archs(nvcc_flags=None):
5757
"third-party/warp-ctc/src/ctc_entrypoint.cu",
5858
"third-party/warp-ctc/src/reduce.cu",
5959
]
60+
extra_compile_args["cxx"].append("-DWITH_CUDA")
61+
extra_compile_args["nvcc"].append("-DWITH_CUDA")
6062
extra_compile_args["nvcc"].extend(get_cuda_compile_archs())
6163
Extension = CUDAExtension
6264
else:
@@ -66,7 +68,7 @@ def get_cuda_compile_archs(nvcc_flags=None):
6668

6769
setup(
6870
name="torch-baidu-ctc",
69-
version="0.1.1",
71+
version="0.2.0",
7072
description="PyTorch bindings for Baidu Warp-CTC",
7173
long_description=io.open("README.md", "r").read(),
7274
long_description_content_type="text/markdown",
@@ -102,6 +104,6 @@ def get_cuda_compile_archs(nvcc_flags=None):
102104
"Topic :: Software Development :: Libraries",
103105
"Topic :: Software Development :: Libraries :: Python Modules",
104106
],
105-
setup_requires=["pybind11", "torch>=0.4.1"],
106-
install_requires=["pybind11", "torch>=0.4.1"],
107+
setup_requires=["pybind11", "torch>=1.0.0"],
108+
install_requires=["pybind11", "torch>=1.0.0"],
107109
)

src/binding.cc

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,29 @@
44
#include <sstream>
55
#include <string>
66

7-
#include <torch/torch.h>
7+
#include <torch/extension.h>
88
#include <ctc.h>
99

1010
#include <ATen/Context.h>
1111
#include <ATen/CPUGeneral.h>
1212
#include <ATen/Device.h>
1313
#include <ATen/DeviceGuard.h>
1414

15+
#ifdef WITH_CUDA
16+
#include <THC/THC.h>
17+
#endif
1518

16-
#define CHECK_CONTIGUOUS(x) \
19+
20+
#define CHECK_CONTIGUOUS(x) \
1721
AT_CHECK((x).is_contiguous(), #x " must be contiguous")
1822

1923
#define CHECK_CPU(x) \
20-
AT_CHECK((x).device().type() == at::Device::Type::CPU, \
24+
AT_CHECK((x).device().type() == c10::Device::Type::CPU, \
2125
#x " must be located in the CPU")
2226

2327
#define CHECK_CPU_OR_CUDA(x) \
24-
AT_CHECK(((x).device().type() == at::Device::Type::CPU || \
25-
(x).device().type() == at::Device::Type::CUDA), \
28+
AT_CHECK(((x).device().type() == c10::Device::Type::CPU || \
29+
(x).device().type() == c10::Device::Type::CUDA), \
2630
#x " must be located in the CPU or a CUDA device")
2731

2832
#define CHECK_FLOAT(x) \
@@ -87,14 +91,17 @@ std::tuple<at::Tensor, at::Tensor> ctc_loss(
8791
ctcOptions ctc_opts;
8892
memset(&ctc_opts, 0, sizeof(ctcOptions));
8993
ctc_opts.blank_label = blank_label;
90-
if (x.device().type() == at::Device::Type::CPU) {
94+
if (x.device().type() == c10::Device::Type::CPU) {
9195
ctc_opts.loc = CTC_CPU;
9296
ctc_opts.num_threads = std::max<unsigned int>(at::get_num_threads(), 0);
93-
} else {
97+
#ifdef WITH_CUDA
98+
} else if (x.device().type() == c10::Device::Type::CUDA) {
9499
ctc_opts.loc = CTC_GPU;
95-
const auto index = x.device().index();
96100
ctc_opts.stream =
97-
at::globalContext().getCurrentCUDAStreamOnDevice(index).stream();
101+
THCState_getCurrentStream(at::globalContext().getTHCState());
102+
#endif
103+
} else {
104+
AT_ERROR("ctc_loss not implemented for the given device type");
98105
}
99106

100107
// Allocate workspace memory
@@ -105,11 +112,11 @@ std::tuple<at::Tensor, at::Tensor> ctc_loss(
105112
&workspace_size));
106113

107114
at::TensorOptions workspace_opts(x.device());
108-
workspace_opts.dtype(at::ScalarType::Byte);
115+
workspace_opts = workspace_opts.dtype(at::ScalarType::Byte);
109116
at::Tensor workspace =
110117
at::zeros({static_cast<int64_t>(workspace_size * 10)}, workspace_opts);
111118

112-
at::DeviceGuard device_guard(x.device());
119+
c10::DeviceGuard device_guard(x.device());
113120
CHECK_WARP_CTC_CALL(
114121
compute_ctc_loss(
115122
x.data<float>(),

wheels/create_wheels_cpu.sh

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,13 @@ done;
5252
set +x;
5353
ODIR="/host/tmp/pytorch_baidu_ctc/whl/cpu";
5454
mkdir -p "$ODIR";
55-
cp /tmp/src/dist/*.whl "$ODIR/";
55+
readarray -t wheels < <(find /tmp/src/dist -name "*.whl");
56+
for whl in "${wheels[@]}"; do
57+
whl_name="$(basename "$whl")";
58+
whl_name="${whl_name/-linux/-manylinux1}";
59+
cp "$whl" "${ODIR}/${whl_name}";
60+
done;
61+
5662
echo "================================================================";
57-
printf "=== %-56s ===\n" "Copied wheels to ${ODIR:5}";
63+
printf "=== %-56s ===\n" "Copied ${#wheels[@]} wheels to ${ODIR:5}";
5864
echo "================================================================";

wheels/create_wheels_cuda.sh

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ SOURCE_DIR=$(cd $SDIR/.. && pwd);
1010

1111
if [ ! -f /.dockerenv ]; then
1212
DOCKER_IMAGES=(
13-
soumith/manylinux-cuda80
14-
soumith/manylinux-cuda90
15-
soumith/manylinux-cuda92
13+
# soumith/manylinux-cuda80
14+
# soumith/manylinux-cuda90
15+
soumith/manylinux-cuda100
1616
);
1717
for image in "${DOCKER_IMAGES[@]}"; do
1818
docker run --runtime=nvidia --rm --log-driver none \
@@ -61,7 +61,13 @@ wheels/fix_deps.sh \
6161
"libcudart.so.${CUDA_VERSION}" \
6262
"/usr/local/cuda-${CUDA_VERSION}/lib64/libcudart.so.${CUDA_VERSION}";
6363

64-
rm -rf /opt/rh /usr/local/cuda*;
64+
# Remove CUDA, since all dependencies should be included.
65+
# TODO: pip package of PyTorch 1.0.0 for CUDA 10 is not well built, we
66+
# need CUDA installed!
67+
if [ ${CUDA_VERSION} != "10.0" ]; then
68+
rm -rf /opt/rh /usr/local/cuda*;
69+
fi;
70+
6571
for py in cp27-cp27mu cp35-cp35m cp36-cp36m cp37-cp37m; do
6672
echo "=== Testing wheel for $py with CUDA ${CUDA_VERSION} ===";
6773
export PYTHON=/opt/python/$py/bin/python;
@@ -75,7 +81,13 @@ done;
7581
set +x;
7682
ODIR="/host/tmp/pytorch_baidu_ctc/whl/${CUDA_VERSION_S}";
7783
mkdir -p "$ODIR";
78-
cp /tmp/src/dist/*.whl "$ODIR/";
84+
readarray -t wheels < <(find /tmp/src/dist -name "*.whl");
85+
for whl in "${wheels[@]}"; do
86+
whl_name="$(basename "$whl")";
87+
whl_name="${whl_name/-linux/-manylinux1}";
88+
cp "$whl" "${ODIR}/${whl_name}";
89+
done;
90+
7991
echo "================================================================";
80-
printf "=== %-56s ===\n" "Copied wheels to ${ODIR:5}";
92+
printf "=== %-56s ===\n" "Copied ${#wheels[@]} wheels to ${ODIR:5}";
8193
echo "================================================================";

wheels/install_pytorch_cpu.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ PYTHON_VERSIONS=(
88
cp37-cp37m
99
);
1010
PYTORCH_WHEELS=(
11-
http://download.pytorch.org/whl/cpu/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl
12-
http://download.pytorch.org/whl/cpu/torch-0.4.1-cp35-cp35m-linux_x86_64.whl
13-
http://download.pytorch.org/whl/cpu/torch-0.4.1-cp36-cp36m-linux_x86_64.whl
14-
http://download.pytorch.org/whl/cpu/torch-0.4.1.post2-cp37-cp37m-linux_x86_64.whl
11+
http://download.pytorch.org/whl/cpu/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl
12+
http://download.pytorch.org/whl/cpu/torch-1.0.0-cp35-cp35m-linux_x86_64.whl
13+
http://download.pytorch.org/whl/cpu/torch-1.0.0-cp36-cp36m-linux_x86_64.whl
14+
http://download.pytorch.org/whl/cpu/torch-1.0.0-cp37-cp37m-linux_x86_64.whl
1515
);
1616

1717
for i in $(seq ${#PYTHON_VERSIONS[@]}); do

wheels/install_pytorch_cuda.sh

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/bin/bash
2-
set -ex;
2+
set -e;
33

44
PYTORCH_WHL_PREFIX="http://download.pytorch.org/whl/${CUDA_VERSION_S}";
55
PYTHON_VERSIONS=(
@@ -9,10 +9,10 @@ PYTHON_VERSIONS=(
99
cp37-cp37m
1010
);
1111
PYTORCH_WHEELS=(
12-
${PYTORCH_WHL_PREFIX}/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl
13-
${PYTORCH_WHL_PREFIX}/torch-0.4.1-cp35-cp35m-linux_x86_64.whl
14-
${PYTORCH_WHL_PREFIX}/torch-0.4.1-cp36-cp36m-linux_x86_64.whl
15-
${PYTORCH_WHL_PREFIX}/torch-0.4.1.post2-cp37-cp37m-linux_x86_64.whl
12+
${PYTORCH_WHL_PREFIX}/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl
13+
${PYTORCH_WHL_PREFIX}/torch-1.0.0-cp35-cp35m-linux_x86_64.whl
14+
${PYTORCH_WHL_PREFIX}/torch-1.0.0-cp36-cp36m-linux_x86_64.whl
15+
${PYTORCH_WHL_PREFIX}/torch-1.0.0-cp37-cp37m-linux_x86_64.whl
1616
);
1717

1818
for i in $(seq ${#PYTHON_VERSIONS[@]}); do

0 commit comments

Comments (0)