Conditional instantiation of cuda calculator (#118)
Co-authored-by: frostedoyster <[email protected]>
nickjbrowning and frostedoyster authored Jul 17, 2024
1 parent 3960152 commit 40f277e
Showing 6 changed files with 76 additions and 57 deletions.
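The thread running through all six files: the torch bindings previously held the CUDA calculators as by-value members, so they were constructed (and touched the CUDA runtime) even on CPU-only machines. They are now std::unique_ptr members that are filled in only when torch::cuda::is_available() reports a usable GPU. A minimal sketch of that pattern, with a hypothetical CudaCalculator and gpu_available() standing in for the sphericart types and the torch query:

#include <memory>

// Hypothetical stand-ins for sphericart::cuda::SphericalHarmonics<T> and
// torch::cuda::is_available(); the names are illustrative, not the real API.
struct CudaCalculator {
    explicit CudaCalculator(int l_max) : l_max_(l_max) {}
    int l_max_;
};

bool gpu_available() { return false; } // assumption: some runtime capability probe

class Wrapper {
  public:
    explicit Wrapper(int l_max) {
        // Construct the CUDA calculator only when a GPU is usable; the pointer
        // stays null otherwise, so no CUDA code runs on CPU-only hosts.
        if (gpu_available()) {
            cuda_calc_ = std::make_unique<CudaCalculator>(l_max);
        }
    }

  private:
    std::unique_ptr<CudaCalculator> cuda_calc_; // null on CPU-only machines
};

int main() { Wrapper w(8); }

Because a default-constructed unique_ptr is null, CPU-only hosts pay nothing for the CUDA path and destruction needs no special casing.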
sphericart-torch/include/sphericart/torch.hpp (3 additions, 3 deletions)
@@ -1,7 +1,7 @@
 #ifndef SPHERICART_TORCH_HPP
 #define SPHERICART_TORCH_HPP

-#include <torch/script.h>
+#include <torch/torch.h>

 #include <mutex>

@@ -45,8 +45,8 @@ class SphericalHarmonics : public torch::CustomClassHolder {
     sphericart::SphericalHarmonics<float> calculator_float_;

     // CUDA implementation
-    sphericart::cuda::SphericalHarmonics<double> calculator_cuda_double_;
-    sphericart::cuda::SphericalHarmonics<float> calculator_cuda_float_;
+    std::unique_ptr<sphericart::cuda::SphericalHarmonics<double>> calculator_cuda_double_ptr;
+    std::unique_ptr<sphericart::cuda::SphericalHarmonics<float>> calculator_cuda_float_ptr;
 };

 } // namespace sphericart_torch
sphericart-torch/src/autograd.cpp (2 additions, 2 deletions)
@@ -233,7 +233,7 @@ torch::autograd::variable_list SphericalHarmonicsAutograd::forward(
         }

         if (xyz.dtype() == c10::kDouble) {
-            calculator.calculator_cuda_double_.compute(
+            calculator.calculator_cuda_double_ptr->compute(
                 xyz.data_ptr<double>(),
                 xyz.size(0),
                 requires_grad,
@@ -245,7 +245,7 @@ torch::autograd::variable_list SphericalHarmonicsAutograd::forward(
             );

         } else if (xyz.dtype() == c10::kFloat) {
-            calculator.calculator_cuda_float_.compute(
+            calculator.calculator_cuda_float_ptr->compute(
                 xyz.data_ptr<float>(),
                 xyz.size(0),
                 requires_grad,
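These call sites now dereference the pointers unconditionally. They are only reached when xyz is a CUDA tensor, which implies the constructor found a GPU and created the calculators; still, a guard of the following shape (hypothetical, not part of this commit) would turn a stray null dereference into a readable error:

#include <memory>
#include <stdexcept>

struct Calc {
    void compute() {} // stand-in for the real compute() overloads
};

// Hypothetical guard: fail loudly if the CUDA calculator was never created,
// e.g. because a CUDA tensor reached this code on a host without a usable GPU.
void compute_checked(const std::unique_ptr<Calc>& calc) {
    if (!calc) {
        throw std::runtime_error("CUDA calculator not initialized: no GPU was "
                                 "available when this object was constructed");
    }
    calc->compute();
}

int main() {
    std::unique_ptr<Calc> calc; // null, as on a CPU-only machine
    try {
        compute_checked(calc);
    } catch (const std::exception&) {
        // reaches here with a readable diagnostic instead of a null dereference
    }
}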
sphericart-torch/src/torch.cpp (12 additions, 7 deletions)
@@ -1,20 +1,25 @@

-#include <torch/script.h>
+#include <torch/torch.h>

 #include "sphericart/torch.hpp"
 #include "sphericart/autograd.hpp"

 using namespace torch;
 using namespace sphericart_torch;
 using namespace std;

 SphericalHarmonics::SphericalHarmonics(int64_t l_max, bool normalized, bool backward_second_derivatives)
     : l_max_(l_max), normalized_(normalized),
       backward_second_derivatives_(backward_second_derivatives),
-      calculator_double_(l_max_, normalized_), calculator_float_(l_max_, normalized_),
-
-      calculator_cuda_double_(l_max_, normalized_), calculator_cuda_float_(l_max_, normalized_) //,
-
-{
-    this->omp_num_threads_ = calculator_double_.get_omp_num_threads();
+      calculator_double_(l_max_, normalized_), calculator_float_(l_max_, normalized_) {
+    this->omp_num_threads_ = calculator_double_.get_omp_num_threads();
+
+    if (torch::cuda::is_available()) {
+        this->calculator_cuda_double_ptr =
+            std::make_unique<sphericart::cuda::SphericalHarmonics<double>>(l_max_, normalized_);
+
+        this->calculator_cuda_float_ptr =
+            std::make_unique<sphericart::cuda::SphericalHarmonics<float>>(l_max_, normalized_);
+    }
 }

 torch::Tensor SphericalHarmonics::compute(torch::Tensor xyz) {
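Moving the CUDA calculators out of the member-initializer list is what makes the instantiation conditional: every member named in the initializer list is constructed unconditionally, whereas a std::unique_ptr member default-initializes to null and can be assigned later in the constructor body. A compressed before/after illustration, with stand-in types:

#include <memory>

struct CudaThing {
    CudaThing() { /* the real class would touch the CUDA runtime here */ }
};

struct Before {
    CudaThing cuda_; // by-value member: constructed on every instantiation,
                     // even on machines without a GPU
};

struct After {
    std::unique_ptr<CudaThing> cuda_; // starts out null; nothing CUDA-related
                                      // happens unless the branch below runs
    explicit After(bool have_gpu) {
        if (have_gpu) {
            cuda_ = std::make_unique<CudaThing>();
        }
    }
};

int main() {
    After cpu_only(false); // never constructs CudaThing
    After with_gpu(true);  // constructs it on demand
}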
sphericart/include/sphericart_cuda.hpp (9 additions, 6 deletions)
@@ -41,6 +41,11 @@ template <typename T> class SphericalHarmonics {
      */
     SphericalHarmonics(size_t l_max, bool normalized = false);

+    /** Default constructor
+     *  Required so sphericart_torch can conditionally instantiate this class,
+     *  depending on whether CUDA is available.
+     */
+
     /* @cond */
     ~SphericalHarmonics();
     /* @endcond */

@@ -95,14 +100,12 @@ template <typename T> class SphericalHarmonics {
   private:
     size_t l_max; // maximum l value computed by this class
     size_t nprefactors;
-    bool normalized;     // should we normalize the input vectors?
-    T* prefactors_cpu;   // host prefactors buffer
-    T** prefactors_cuda; // storage space for prefactors
-    int device_count;    // number of visible GPU devices
-
+    bool normalized;               // should we normalize the input vectors?
+    T* prefactors_cpu = nullptr;   // host prefactors buffer
+    T** prefactors_cuda = nullptr; // storage space for prefactors
+    int device_count = 0;          // number of visible GPU devices
     int64_t CUDA_GRID_DIM_X_ = 8;
     int64_t CUDA_GRID_DIM_Y_ = 8;
-
     bool cached_compute_with_gradients = false;
     bool cached_compute_with_hessian = false;
     int64_t _current_shared_mem_allocation = 0;
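The new in-class initializers (nullptr and 0) are what make the destructor in sphericart_cuda.cu below safe to run on objects whose constructor never allocated anything, e.g. on a host with zero visible devices: cleanup can test the declared defaults instead of reading indeterminate values. A small sketch of the idiom, with hypothetical names:

#include <cstddef>

class Holder {
  public:
    explicit Holder(bool allocate) {
        if (allocate) {
            buf_ = new double[16];
            count_ = 16;
        }
        // when we don't allocate, buf_ and count_ keep their declared defaults
    }
    ~Holder() {
        // Safe even when the constructor never allocated: the members were
        // default-initialized at their declarations, not left indeterminate.
        if (buf_ != nullptr) {
            delete[] buf_;
            buf_ = nullptr;
        }
    }

  private:
    double* buf_ = nullptr; // declared default stands in for "not allocated"
    std::size_t count_ = 0; // likewise
};

int main() {
    Holder a(true);  // allocates; the destructor frees
    Holder b(false); // never allocates; the destructor is a no-op
}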
sphericart/src/cuda_base.cu (0 additions, 1 deletion)
@@ -900,7 +900,6 @@ void sphericart::cuda::spherical_harmonics_backward_cuda_base(
     scalar_t* __restrict__ xyz_grad,
     void* cuda_stream
 ) {
-
     dim3 grid_dim(4, 32);

     auto find_num_blocks = [](int x, int bdim) { return (x + bdim - 1) / bdim; };
sphericart/src/sphericart_cuda.cu (50 additions, 38 deletions)
@@ -30,66 +30,78 @@ template <typename T> SphericalHarmonics<T>::SphericalHarmonics(size_t l_max, bool normalized) {
     buffer space, compute prefactors, and sets the function pointers that are
     used for the actual calls
     */

     this->l_max = (int)l_max;
     this->nprefactors = (int)(l_max + 1) * (l_max + 2);
     this->normalized = normalized;
     this->prefactors_cpu = new T[this->nprefactors];

-    CUDA_CHECK(cudaGetDeviceCount(&this->device_count));
-
     // compute prefactors on host first
     compute_sph_prefactors<T>((int)l_max, this->prefactors_cpu);

-    int current_device;
-
-    CUDA_CHECK(cudaGetDevice(&current_device));
-
-    // allocate prefactors on every visible device and copy from host
-    this->prefactors_cuda = new T*[this->device_count];
-
-    for (int device = 0; device < this->device_count; device++) {
-        CUDA_CHECK(cudaSetDevice(device));
-        CUDA_CHECK(cudaMalloc((void**)&this->prefactors_cuda[device], this->nprefactors * sizeof(T))
-        );
-        CUDA_CHECK(cudaMemcpy(
-            this->prefactors_cuda[device],
-            this->prefactors_cpu,
-            this->nprefactors * sizeof(T),
-            cudaMemcpyHostToDevice
-        ));
-    }
-
-    // initialise the currently available amount of shared memory on all visible devices
-    this->_current_shared_mem_allocation = adjust_shared_memory(
-        sizeof(T),
-        this->l_max,
-        this->CUDA_GRID_DIM_X_,
-        this->CUDA_GRID_DIM_Y_,
-        false,
-        false,
-        this->_current_shared_mem_allocation
-    );
-
-    // set the context back to the current device
-    CUDA_CHECK(cudaSetDevice(current_device));
+    CUDA_CHECK(cudaGetDeviceCount(&this->device_count));
+    if (this->device_count) {
+        int current_device;
+
+        CUDA_CHECK(cudaGetDevice(&current_device));
+
+        // allocate prefactors on every visible device and copy from host
+        this->prefactors_cuda = new T*[this->device_count];
+        for (int device = 0; device < this->device_count; device++) {
+            CUDA_CHECK(cudaSetDevice(device));
+            CUDA_CHECK(
+                cudaMalloc((void**)&this->prefactors_cuda[device], this->nprefactors * sizeof(T))
+            );
+            CUDA_CHECK(cudaMemcpy(
+                this->prefactors_cuda[device],
+                this->prefactors_cpu,
+                this->nprefactors * sizeof(T),
+                cudaMemcpyHostToDevice
+            ));
+        }
+
+        // initialise the currently available amount of shared memory on all visible devices
+        this->_current_shared_mem_allocation = adjust_shared_memory(
+            sizeof(T),
+            this->l_max,
+            this->CUDA_GRID_DIM_X_,
+            this->CUDA_GRID_DIM_Y_,
+            false,
+            false,
+            this->_current_shared_mem_allocation
+        );
+
+        // set the context back to the current device
+        CUDA_CHECK(cudaSetDevice(current_device));
+    }
 }

 template <typename T> SphericalHarmonics<T>::~SphericalHarmonics() {
     // Destructor, frees the prefactors
-    delete[] (this->prefactors_cpu);
-
-    int current_device;
-
-    CUDA_CHECK(cudaGetDevice(&current_device));
-
-    for (int device = 0; device < this->device_count; device++) {
-        CUDA_CHECK(cudaSetDevice(device));
-        CUDA_CHECK(cudaDeviceSynchronize());
-        CUDA_CHECK(cudaFree(this->prefactors_cuda[device]));
-    }
-
-    CUDA_CHECK(cudaSetDevice(current_device));
+    if (this->prefactors_cpu != nullptr) {
+        delete[] (this->prefactors_cpu);
+        this->prefactors_cpu = nullptr;
+    }
+
+    if (this->device_count > 0) {
+        int current_device;
+
+        CUDA_CHECK(cudaGetDevice(&current_device));
+
+        for (int device = 0; device < this->device_count; device++) {
+            CUDA_CHECK(cudaSetDevice(device));
+            CUDA_CHECK(cudaDeviceSynchronize());
+            if (this->prefactors_cuda != nullptr && this->prefactors_cuda[device] != nullptr) {
+                CUDA_CHECK(cudaFree(this->prefactors_cuda[device]));
+                this->prefactors_cuda[device] = nullptr;
+            }
+        }
+        this->prefactors_cuda = nullptr;
+
+        CUDA_CHECK(cudaSetDevice(current_device));
+    }
 }

 template <typename T>
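The destructor's save/loop/restore shape is the usual way to visit every visible device without clobbering the caller's current device. A standalone sketch of that pattern against the CUDA runtime API, with error handling collapsed into a trivial logging macro rather than the library's CUDA_CHECK:

#include <cstdio>
#include <cuda_runtime.h>

// Minimal logging stand-in for the library's CUDA_CHECK macro.
#define CHECK(call)                                                             \
    do {                                                                        \
        cudaError_t err_ = (call);                                              \
        if (err_ != cudaSuccess) {                                              \
            std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err_)); \
        }                                                                       \
    } while (0)

// Free one per-device buffer on every visible GPU, then restore whichever
// device was current when we were called.
void free_on_all_devices(float** per_device, int device_count) {
    if (device_count <= 0 || per_device == nullptr) {
        return; // nothing was ever allocated
    }
    int current_device;
    CHECK(cudaGetDevice(&current_device)); // remember the caller's device

    for (int device = 0; device < device_count; device++) {
        CHECK(cudaSetDevice(device));
        CHECK(cudaDeviceSynchronize()); // ensure no kernel still uses the buffer
        if (per_device[device] != nullptr) {
            CHECK(cudaFree(per_device[device]));
            per_device[device] = nullptr;
        }
    }

    CHECK(cudaSetDevice(current_device)); // put the context back
}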
