WIP Port Gamma to CUDA #102

Open · wants to merge 8 commits into master
28 changes: 27 additions & 1 deletion .jenkins/build.sh
@@ -47,7 +47,33 @@ python --version

pip install -r requirements.txt || true

time python setup.py install
if [[ "$JOB_NAME" == *asan* ]]; then
export ASAN_OPTIONS=detect_leaks=0:symbolize=1
# Disable Valgrind tests in run_aten_tests.sh; otherwise
# we'll be valgrind'ing an ASAN'ed binary! ASANity.
export VALGRIND=0

sudo apt-get update
sudo apt-get install clang-5.0

export PATH="/usr/lib/llvm-5.0/bin:$PATH"

# TODO: Figure out how to avoid hard-coding these paths
LD_LIBRARY_PATH=/usr/lib/llvm-5.0/lib/clang/5.0.0/lib/linux \
CC="sccache clang" \
CXX="sccache clang++" \
LDSHARED="clang --shared" \
LDFLAGS="-stdlib=libstdc++" \
CFLAGS="-fsanitize=address -shared-libasan" \
NO_CUDA=1 \
python setup.py install

export LD_PRELOAD=/usr/lib/llvm-5.0/lib/clang/5.0.0/lib/linux/libclang_rt.asan-x86_64.so

else
python setup.py install

fi

if [[ "$JOB_NAME" != *cuda* ]]; then
echo "Testing ATen"
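Note on the ASAN hunk above: clang's `-shared-libasan` links the sanitizer runtime as a shared library, and because the instrumented code is loaded into an uninstrumented `python` binary as an extension module, the runtime has to be injected via `LD_PRELOAD` (and located via `LD_LIBRARY_PATH`). A minimal sketch of the kind of defect this job is meant to surface; the file name and build line below are illustrative, not part of this PR:

```cpp
// asan_smoke.cpp -- illustrative only, not part of the diff.
// Build roughly the way the CI job builds the extension:
//   clang++ -fsanitize=address -shared-libasan -g asan_smoke.cpp -o asan_smoke
// Running it should abort with a heap-buffer-overflow report at the marked line,
// provided libclang_rt.asan-x86_64.so can be found (LD_LIBRARY_PATH or LD_PRELOAD).
#include <cstdio>

int main() {
  int* buf = new int[4];
  buf[4] = 42;                  // write one element past the end of the allocation
  std::printf("%d\n", buf[0]);  // never reached under ASAN
  delete[] buf;
  return 0;
}
```

`ASAN_OPTIONS=detect_leaks=0:symbolize=1` disables LeakSanitizer and keeps stack traces symbolized; `VALGRIND=0` avoids running Valgrind over an already-instrumented binary.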
4 changes: 0 additions & 4 deletions .jenkins/disabled-configs.txt
@@ -3,7 +3,3 @@
# fail. You can use this to temporarily reserve a test name to
# turn on CI side before PyTorch repository supports it. This
# file has the same format as .jenkins/enabled-configs.txt

pytorch-linux-xenial-py3-clang5-asan
pytorch-linux-xenial-py3-clang5-asan-build
pytorch-linux-xenial-py3-clang5-asan-test
2 changes: 2 additions & 0 deletions .jenkins/enabled-configs.txt
@@ -11,6 +11,8 @@ pytorch-linux-xenial-cuda9-cudnn7-py2-build
pytorch-linux-xenial-cuda9-cudnn7-py2-test
pytorch-linux-xenial-cuda9-cudnn7-py3-build
pytorch-linux-xenial-cuda9-cudnn7-py3-test
pytorch-linux-xenial-py3-clang5-asan-build
pytorch-linux-xenial-py3-clang5-asan-test
pytorch-linux-trusty-py2.7.9-build
pytorch-linux-trusty-py2.7.9-test
pytorch-linux-trusty-py2.7-build
4 changes: 2 additions & 2 deletions .jenkins/perf_test/perf_test_numbers.json
@@ -22,8 +22,8 @@
},

"test_gpu_speed_word_language_model": {
"mean": "5.65807",
"sigma": "0.1132"
"mean": "5.9411499999999995",
"sigma": "0.02134777505971057"
},

"test_gpu_speed_cudnn_lstm": {
10 changes: 10 additions & 0 deletions .jenkins/test.sh
@@ -40,13 +40,23 @@ echo "Testing pytorch"
export OMP_NUM_THREADS=4
export MKL_NUM_THREADS=4

if [[ "$JOB_NAME" == *asan* ]]; then
export PATH="/usr/lib/llvm-5.0/bin:$PATH"
export ASAN_OPTIONS=detect_leaks=0:symbolize=1
export PYTORCH_TEST_WITH_ASAN=1
fi

# JIT C++ extensions require ninja.
git clone https://github.com/ninja-build/ninja --quiet
pushd ninja
python ./configure.py --bootstrap
export PATH="$PWD:$PATH"
popd

if [[ "$JOB_NAME" == *asan* ]]; then
export LD_PRELOAD=/usr/lib/llvm-5.0/lib/clang/5.0.0/lib/linux/libclang_rt.asan-x86_64.so
fi

time test/run_test.sh -- -v

rm -rf ninja
12 changes: 10 additions & 2 deletions aten/cmake/FindCuDNN.cmake
@@ -15,13 +15,21 @@ include(FindPackageHandleStandardArgs)

set(CUDNN_ROOT_DIR "" CACHE PATH "Folder contains NVIDIA cuDNN")

-find_path(CUDNN_INCLUDE_DIR cudnn.h
+if($ENV{CUDNN_INCLUDE_DIR})
+SET(CUDNN_INCLUDE_DIR $ENV{CUDNN_INCLUDE_DIR})
+else($ENV{CUDNN_INCLUDE_DIR})
+find_path(CUDNN_INCLUDE_DIR cudnn.h
 HINTS ${CUDNN_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}
 PATH_SUFFIXES cuda/include include)
+endif($ENV{CUDNN_INCLUDE_DIR})

-find_library(CUDNN_LIBRARY cudnn
+if($ENV{CUDNN_LIBRARY})
+SET(CUDNN_LIBRARY $ENV{CUDNN_LIBRARY})
+else($ENV{CUDNN_LIBRARY})
+find_library(CUDNN_LIBRARY cudnn
 HINTS ${CUDNN_LIB_DIR} ${CUDNN_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}
 PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64)
+endif($ENV{CUDNN_LIBRARY})

find_package_handle_standard_args(
CUDNN DEFAULT_MSG CUDNN_INCLUDE_DIR CUDNN_LIBRARY)
20 changes: 0 additions & 20 deletions aten/src/ATen/Declarations.cwrap
@@ -3966,26 +3966,6 @@
kwarg_only: True
- THTensor* self
]]
[[
name: _standard_gamma
types:
- floating_point
backends:
- CPU
return: argument 0
variants:
- method
- function
options:
- cname: standard_gamma
arguments:
- arg: THTensor* output
output: True
- arg: THGenerator* generator
default: nullptr
kwarg_only: True
- THTensor* self
]]
[[
name: _dirichlet_grad
types:
24 changes: 24 additions & 0 deletions aten/src/ATen/SharedDist.cu
@@ -0,0 +1,24 @@
#include "ATen/ATen.h"
#include "ATen/TensorUtils.h"
#include "ATen/NativeFunctions.h"
#include "ATen/Dispatch.h"
#include "ATen/Config.h"

#include <nvfunctional>

namespace at {
namespace native {
namespace dist {
template<typename precision_t>
struct baseSampler {
  nvstd::function<precision_t(void)> sampler;
  baseSampler(nvstd::function<precision_t(void)> sampler): sampler(sampler) {}
  precision_t sample() {
    return sampler();
  }
};
}
}
}

// this version is only linked if CUDA is enabled, so we can safely just use CUDA features here
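The `nvstd::function` member type-erases the random source, so the shared distribution code (see `sample_gamma` in `Distributions.cpp` below) only ever calls `.sample()` and never touches curand or THRandom directly. A minimal device-side sketch of the pattern, with a hypothetical kernel name (illustrative, not part of this PR):

```cuda
// Illustrative only. Assumes this .cu translation unit can see the
// at::native::dist::baseSampler definition above.
#include <curand_kernel.h>

__global__ void fill_uniform(float* out, int n,
                             unsigned long long seed, unsigned long long offset) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;

  // One Philox subsequence per thread, the same scheme Distributions.cu uses below.
  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, offset, &state);

  // Wrap the generator once; downstream code stays generic over the RNG backend.
  at::native::dist::baseSampler<float> uniform([&state]() {
    return curand_uniform(&state);
  });
  out[i] = uniform.sample();
}
```

On the CPU side, `Distributions.cpp` defines the same struct over `std::function`, so the templated sampler code compiles unchanged for both backends.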
80 changes: 77 additions & 3 deletions aten/src/ATen/native/Distributions.cpp
@@ -1,13 +1,16 @@
#include "ATen/ATen.h"
#include "ATen/CPUApplyUtils.h"
#include "ATen/Dispatch.h"
#include "ATen/Config.h"
#include "ATen/ExpandUtils.h"
#include "ATen/NativeFunctions.h"

#include "ATen/CPUGenerator.h"
#include "ATen/CheckGenerator.h"
#include "ATen/Generator.h"

#include <functional>

#include "TH/THRandom.h"

namespace at {
@@ -119,12 +122,23 @@ Tensor _standard_gamma_grad_cuda(const Tensor& self, const Tensor& output) {

/*
* This section is a counterpart to Distributions.cu
*
*/

namespace dist {
// The function `sample_poisson`
// is adapted from Numpy's distributions.c implementation.

#if !AT_CUDA_ENABLED()
template<typename precision_t>
struct baseSampler {
  std::function<precision_t(void)> sampler;
  baseSampler(std::function<precision_t(void)> sampler): sampler(sampler) {}
  precision_t sample() {
    return sampler();
  }
};
#endif

// The functions `sample_poisson`, `sample_gamma`
// are adapted from Numpy's distributions.c implementation.
// It is MIT licensed, so here is the copyright:

/* Copyright 2005 Robert Kern ([email protected])
@@ -149,12 +163,65 @@ namespace dist {
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/


template<typename precision_t>
#if AT_CUDA_ENABLED()
__host__ __device__
#endif
precision_t sample_gamma(precision_t alpha, baseSampler<precision_t>& standard_uniform, baseSampler<precision_t>& standard_normal) {

  precision_t scale = 1.0;

  // Boost alpha for higher acceptance probability.
  if (alpha < 1.0) {
    scale *= ::pow(1 - standard_uniform.sample(), 1.0 / alpha);
    alpha += 1.0;
  }

  // This implements the acceptance-rejection method of Marsaglia and Tsang (2000)
  // doi:10.1145/358407.358414
  const precision_t d = alpha - 1.0 / 3.0;
  const precision_t c = 1.0 / ::sqrt(9.0 * d);
  for (;;) {
    precision_t x, y;
    do {
      x = standard_normal.sample();
      y = 1.0 + c * x;
    } while (y <= 0);
    const precision_t v = y * y * y;
    const precision_t u = 1 - standard_uniform.sample();
    const precision_t xx = x * x;
    if (u < 1.0 - 0.0331 * xx * xx)
      return scale * d * v;
    if (::log(u) < 0.5 * xx + d * (1.0 - v + ::log(v)))
      return scale * d * v;
  }
}

THGenerator * get_generator(Generator *gen) {
  auto default_gen = &at::globalContext().defaultGenerator(Backend::CPU);
  auto gen_ = check_generator<CPUGenerator>(gen, default_gen);
  return gen_->generator;
}

template <typename scalar>
struct GammaOp {
  static void apply(Tensor& ret, const Tensor& alpha, THGenerator *generator) {
    CPU_tensor_apply2<scalar, double>(ret, alpha,
      [generator](scalar& ret_val, const double& alpha){
        dist::baseSampler<float> standard_uniform([generator] () {
          return THRandom_standard_uniform(generator);
        });
        dist::baseSampler<float> standard_normal([generator] () {
          return THRandom_normal(generator, 0.0, 1.0);
        });
        auto sample = dist::sample_gamma<float>(alpha, standard_uniform, standard_normal);
        ret_val = std::max(std::numeric_limits<scalar>::min(), (scalar) sample);
      }
    );
  }
};

template <typename scalar>
struct PoissonOp {
static int64_t sample_poisson(double lambda, THGenerator *generator) {
@@ -227,5 +294,12 @@ Tensor _s_poisson_cpu(const Tensor& lambda, Generator *gen) {
return ret;
}

Tensor _s_gamma_cpu(const Tensor& alpha, Generator *gen) {
  Tensor ret = alpha.type().zeros(alpha.sizes());
  auto alpha_ = alpha.toType(ScalarType::Double);
  dispatch_floating_types<void, dist::GammaOp>(ret.type(), "gamma", ret, alpha_, dist::get_generator(gen));
  return ret;
}

} // at::native
} // at
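Since the CPU path drives `sample_gamma` through `std::function`-backed samplers, the shared Marsaglia-Tsang code can be sanity-checked in isolation with a standard host RNG. A rough moment check, not part of this PR (the function name and sample count are illustrative): for Gamma(alpha, 1), both the mean and the variance should come out near alpha.

```cpp
// Illustrative only. Assumes a CPU-only build, where the std::function-based
// baseSampler above is in effect, and that this code can see
// at::native::dist::baseSampler and at::native::dist::sample_gamma.
#include <cstdio>
#include <random>

void check_gamma_moments(float alpha, int n) {
  std::mt19937 rng(0);
  std::uniform_real_distribution<float> unif(0.0f, 1.0f);
  std::normal_distribution<float> norm(0.0f, 1.0f);

  // Same wiring GammaOp uses, but with <random> in place of THRandom.
  at::native::dist::baseSampler<float> standard_uniform([&]() { return unif(rng); });
  at::native::dist::baseSampler<float> standard_normal([&]() { return norm(rng); });

  double sum = 0.0, sum_sq = 0.0;
  for (int i = 0; i < n; ++i) {
    double x = at::native::dist::sample_gamma<float>(alpha, standard_uniform, standard_normal);
    sum += x;
    sum_sq += x * x;
  }
  const double mean = sum / n;
  const double var = sum_sq / n - mean * mean;
  std::printf("alpha=%.3f  mean=%.3f  var=%.3f\n", alpha, mean, var);
}

// e.g. check_gamma_moments(0.5f, 200000) also exercises the alpha < 1 boosting branch.
```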
35 changes: 35 additions & 0 deletions aten/src/ATen/native/cuda/Distributions.cu
@@ -1,10 +1,18 @@
#include "ATen/ATen.h"
#include "ATen/TensorUtils.h"
#include "ATen/NativeFunctions.h"
#include "ATen/Dispatch.h"
#include "ATen/Config.h"
#include "ATen/cuda/CUDAApplyUtils.cuh"
#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
#include <utility>
#include <functional>
#include <nvfunctional>

#include "ATen/SharedDist.cu"
#include "ATen/native/Distributions.cpp"

#include <TH/THAtomic.h>

@@ -26,6 +34,26 @@ namespace dist {
return std::make_pair(gen_->initial_seed, offset);
}

template <typename scalar>
struct GammaOpCUDA {
  static void apply(Tensor& ret, const Tensor& alpha, std::pair<uint64_t, uint64_t> seeds) {
    at::cuda::CUDA_tensor_apply2<scalar, float>(ret, alpha,
      [seeds] __device__ (scalar& ret_val, const float& alpha, bool early_exit) {
        curandStatePhilox4_32_10_t state;
        curand_init(seeds.first, blockIdx.x * blockDim.x + threadIdx.x, seeds.second, &state);
        baseSampler<float> standard_uniform([&state] __device__ () {
          return curand_uniform(&state);
        });
        baseSampler<float> standard_normal([&state] __device__ () {
          return curand_normal(&state);
        });
        auto sample = scalar_cast<scalar>(sample_gamma<float>(alpha, standard_uniform, standard_normal));
        ret_val = ::max(THCNumerics<scalar>::min(), (scalar) sample);
      }
    );
  }
};

template <typename scalar>
struct PoissonOpCUDA {
static void apply(Tensor& ret, const Tensor& lambda, std::pair<uint64_t, uint64_t> seeds) {
@@ -48,5 +76,12 @@ Tensor _s_poisson_cuda(const Tensor& lambda, Generator* gen) {
return ret;
}

Tensor _s_gamma_cuda(const Tensor& alpha, Generator* gen) {
  Tensor ret = alpha.type().tensor(alpha.sizes());
  auto alpha_ = alpha.toType(ScalarType::Float);
  dispatch_floating_types<void, dist::GammaOpCUDA>(ret.type(), "gamma", ret, alpha_, dist::next_philox_seed(gen));
  return ret;
}

} // at::native
} // at
6 changes: 6 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -392,6 +392,12 @@
CPU: _s_poisson_cpu
CUDA: _s_poisson_cuda

- func: standard_gamma(Tensor self, Generator* generator=nullptr) -> Tensor
variants: function
dispatch:
CPU: _s_gamma_cpu
CUDA: _s_gamma_cuda

- func: _cudnn_rnn_flatten_weight(TensorList weight_arr, int64_t weight_stride0, int64_t input_size, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, bool bidirectional) -> Tensor
variants: function

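For context, the yaml entry above is what makes the ATen code generator emit a free function `at::standard_gamma` that dispatches to `_s_gamma_cpu` or `_s_gamma_cuda` based on the input tensor's backend. A hedged usage sketch against the ATen API of this era; the factory calls and the exact generated signature may differ:

```cpp
// Illustrative only, not part of the diff.
#include <ATen/ATen.h>

int main() {
  // 1000 shape parameters, all alpha = 2.5; standard_gamma draws from Gamma(alpha, 1),
  // so the sample mean should land near 2.5.
  at::Tensor alpha = at::CPU(at::kFloat).ones({1000}).mul(2.5);
  at::Tensor samples = at::standard_gamma(alpha);   // routes to _s_gamma_cpu

  // On a CUDA build, moving alpha to the GPU routes to the CUDA kernel instead:
  //   at::Tensor cuda_samples = at::standard_gamma(alpha.toType(at::CUDA(at::kFloat)));
  return 0;
}
```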
29 changes: 0 additions & 29 deletions aten/src/TH/THRandom.cpp
@@ -290,35 +290,6 @@ double THRandom_exponential(THGenerator *_generator, double lambda)
return(-1. / lambda * log(1-uniform_double(_generator)));
}

double THRandom_standard_gamma(THGenerator *_generator, double alpha) {
double scale = 1.0;

// Boost alpha for higher acceptance probability.
if(alpha < 1.0) {
scale *= pow(1 - uniform_double(_generator), 1.0 / alpha);
alpha += 1.0;
}

// This implements the acceptance-rejection method of Marsaglia and Tsang (2000)
// doi:10.1145/358407.358414
const double d = alpha - 1.0 / 3.0;
const double c = 1.0 / sqrt(9.0 * d);
for(;;) {
double x, y;
do {
x = THRandom_normal(_generator, 0.0, 1.0);
y = 1.0 + c * x;
} while(y <= 0);
const double v = y * y * y;
const double u = 1 - uniform_double(_generator);
const double xx = x * x;
if(u < 1.0 - 0.0331 * xx * xx)
return scale * d * v;
if(log(u) < 0.5 * xx + d * (1.0 - v + log(v)))
return scale * d * v;
}
}

double THRandom_cauchy(THGenerator *_generator, double median, double sigma)
{
return(median + sigma * tan(M_PI*(uniform_double(_generator)-0.5)));