diff --git a/sphericart-torch/include/sphericart/torch.hpp b/sphericart-torch/include/sphericart/torch.hpp index 6bc814e8..e4fb51c0 100644 --- a/sphericart-torch/include/sphericart/torch.hpp +++ b/sphericart-torch/include/sphericart/torch.hpp @@ -1,7 +1,7 @@ #ifndef SPHERICART_TORCH_HPP #define SPHERICART_TORCH_HPP -#include +#include #include @@ -45,8 +45,8 @@ class SphericalHarmonics : public torch::CustomClassHolder { sphericart::SphericalHarmonics calculator_float_; // CUDA implementation - sphericart::cuda::SphericalHarmonics calculator_cuda_double_; - sphericart::cuda::SphericalHarmonics calculator_cuda_float_; + std::unique_ptr> calculator_cuda_double_ptr; + std::unique_ptr> calculator_cuda_float_ptr; }; } // namespace sphericart_torch diff --git a/sphericart-torch/src/autograd.cpp b/sphericart-torch/src/autograd.cpp index 74498703..eb89e337 100644 --- a/sphericart-torch/src/autograd.cpp +++ b/sphericart-torch/src/autograd.cpp @@ -233,7 +233,7 @@ torch::autograd::variable_list SphericalHarmonicsAutograd::forward( } if (xyz.dtype() == c10::kDouble) { - calculator.calculator_cuda_double_.compute( + calculator.calculator_cuda_double_ptr->compute( xyz.data_ptr(), xyz.size(0), requires_grad, @@ -245,7 +245,7 @@ torch::autograd::variable_list SphericalHarmonicsAutograd::forward( ); } else if (xyz.dtype() == c10::kFloat) { - calculator.calculator_cuda_float_.compute( + calculator.calculator_cuda_float_ptr->compute( xyz.data_ptr(), xyz.size(0), requires_grad, diff --git a/sphericart-torch/src/torch.cpp b/sphericart-torch/src/torch.cpp index cc734b51..ba832544 100644 --- a/sphericart-torch/src/torch.cpp +++ b/sphericart-torch/src/torch.cpp @@ -1,20 +1,25 @@ - -#include +#include #include "sphericart/torch.hpp" #include "sphericart/autograd.hpp" +using namespace torch; using namespace sphericart_torch; +using namespace std; + SphericalHarmonics::SphericalHarmonics(int64_t l_max, bool normalized, bool backward_second_derivatives) : l_max_(l_max), 
normalized_(normalized), backward_second_derivatives_(backward_second_derivatives), + calculator_double_(l_max_, normalized_), calculator_float_(l_max_, normalized_) { + this->omp_num_threads_ = calculator_double_.get_omp_num_threads(); - calculator_double_(l_max_, normalized_), calculator_float_(l_max_, normalized_), - - calculator_cuda_double_(l_max_, normalized_), calculator_cuda_float_(l_max_, normalized_) //, + if (torch::cuda::is_available()) { + this->calculator_cuda_double_ptr = + std::make_unique>(l_max_, normalized_); -{ - this->omp_num_threads_ = calculator_double_.get_omp_num_threads(); + this->calculator_cuda_float_ptr = + std::make_unique>(l_max_, normalized_); + } } torch::Tensor SphericalHarmonics::compute(torch::Tensor xyz) { diff --git a/sphericart/include/sphericart_cuda.hpp b/sphericart/include/sphericart_cuda.hpp index 63d6d046..a9e47e9f 100644 --- a/sphericart/include/sphericart_cuda.hpp +++ b/sphericart/include/sphericart_cuda.hpp @@ -41,6 +41,11 @@ template class SphericalHarmonics { */ SphericalHarmonics(size_t l_max, bool normalized = false); + /** Default constructor + * Required so sphericart_torch can conditionally instantiate this class + * depending on if cuda is available. + */ + /* @cond */ ~SphericalHarmonics(); /* @endcond */ @@ -95,14 +100,12 @@ template class SphericalHarmonics { private: size_t l_max; // maximum l value computed by this class size_t nprefactors; - bool normalized; // should we normalize the input vectors? - T* prefactors_cpu; // host prefactors buffer - T** prefactors_cuda; // storage space for prefactors - int device_count; // number of visible GPU devices - + bool normalized; // should we normalize the input vectors? 
+ T* prefactors_cpu = nullptr; // host prefactors buffer + T** prefactors_cuda = nullptr; // storage space for prefactors + int device_count = 0; // number of visible GPU devices int64_t CUDA_GRID_DIM_X_ = 8; int64_t CUDA_GRID_DIM_Y_ = 8; - bool cached_compute_with_gradients = false; bool cached_compute_with_hessian = false; int64_t _current_shared_mem_allocation = 0; diff --git a/sphericart/src/cuda_base.cu b/sphericart/src/cuda_base.cu index 1c56186c..1e70c19d 100644 --- a/sphericart/src/cuda_base.cu +++ b/sphericart/src/cuda_base.cu @@ -900,7 +900,6 @@ void sphericart::cuda::spherical_harmonics_backward_cuda_base( scalar_t* __restrict__ xyz_grad, void* cuda_stream ) { - dim3 grid_dim(4, 32); auto find_num_blocks = [](int x, int bdim) { return (x + bdim - 1) / bdim; }; diff --git a/sphericart/src/sphericart_cuda.cu b/sphericart/src/sphericart_cuda.cu index d05b7a28..a894d7c9 100644 --- a/sphericart/src/sphericart_cuda.cu +++ b/sphericart/src/sphericart_cuda.cu @@ -30,66 +30,78 @@ template SphericalHarmonics::SphericalHarmonics(size_t l_max, bo buffer space, compute prefactors, and sets the function pointers that are used for the actual calls */ - this->l_max = (int)l_max; this->nprefactors = (int)(l_max + 1) * (l_max + 2); this->normalized = normalized; this->prefactors_cpu = new T[this->nprefactors]; + CUDA_CHECK(cudaGetDeviceCount(&this->device_count)); + // compute prefactors on host first compute_sph_prefactors((int)l_max, this->prefactors_cpu); - CUDA_CHECK(cudaGetDeviceCount(&this->device_count)); + if (this->device_count) { + int current_device; - int current_device; + CUDA_CHECK(cudaGetDevice(&current_device)); - CUDA_CHECK(cudaGetDevice(&current_device)); + // allocate prefactorts on every visible device and copy from host + this->prefactors_cuda = new T*[this->device_count]; - // allocate prefactorts on every visible device and copy from host - this->prefactors_cuda = new T*[this->device_count]; + for (int device = 0; device < this->device_count; device++) { + 
CUDA_CHECK(cudaSetDevice(device)); + CUDA_CHECK( + cudaMalloc((void**)&this->prefactors_cuda[device], this->nprefactors * sizeof(T)) + ); + CUDA_CHECK(cudaMemcpy( + this->prefactors_cuda[device], + this->prefactors_cpu, + this->nprefactors * sizeof(T), + cudaMemcpyHostToDevice + )); + } - for (int device = 0; device < this->device_count; device++) { - CUDA_CHECK(cudaSetDevice(device)); - CUDA_CHECK(cudaMalloc((void**)&this->prefactors_cuda[device], this->nprefactors * sizeof(T)) + // initialise the currently available amount of shared memory on all visible devices + this->_current_shared_mem_allocation = adjust_shared_memory( + sizeof(T), + this->l_max, + this->CUDA_GRID_DIM_X_, + this->CUDA_GRID_DIM_Y_, + false, + false, + this->_current_shared_mem_allocation ); - CUDA_CHECK(cudaMemcpy( - this->prefactors_cuda[device], - this->prefactors_cpu, - this->nprefactors * sizeof(T), - cudaMemcpyHostToDevice - )); - } - // initialise the currently available amount of shared memory on all visible devices - this->_current_shared_mem_allocation = adjust_shared_memory( - sizeof(T), - this->l_max, - this->CUDA_GRID_DIM_X_, - this->CUDA_GRID_DIM_Y_, - false, - false, - this->_current_shared_mem_allocation - ); - - // set the context back to the current device - CUDA_CHECK(cudaSetDevice(current_device)); + // set the context back to the current device + CUDA_CHECK(cudaSetDevice(current_device)); + } } template SphericalHarmonics::~SphericalHarmonics() { // Destructor, frees the prefactors - delete[] (this->prefactors_cpu); + if (this->prefactors_cpu != nullptr) { + delete[] (this->prefactors_cpu); + this->prefactors_cpu = nullptr; + } - int current_device; + if (this->device_count > 0) { - CUDA_CHECK(cudaGetDevice(&current_device)); + int current_device; - for (int device = 0; device < this->device_count; device++) { - CUDA_CHECK(cudaSetDevice(device)); - CUDA_CHECK(cudaDeviceSynchronize()); - CUDA_CHECK(cudaFree(this->prefactors_cuda[device])); - } + 
CUDA_CHECK(cudaGetDevice(&current_device)); - CUDA_CHECK(cudaSetDevice(current_device)); + for (int device = 0; device < this->device_count; device++) { + CUDA_CHECK(cudaSetDevice(device)); + CUDA_CHECK(cudaDeviceSynchronize()); + if (this->prefactors_cuda != nullptr && this->prefactors_cuda[device] != nullptr) { + CUDA_CHECK(cudaFree(this->prefactors_cuda[device])); + this->prefactors_cuda[device] = nullptr; + } + } + this->prefactors_cuda = nullptr; + + CUDA_CHECK(cudaSetDevice(current_device)); + } } template