Fix for #114 (PR #118)

Open

wants to merge 6 commits into base: main

Changes from 4 commits
5 changes: 3 additions & 2 deletions sphericart-torch/include/sphericart/torch.hpp
@@ -18,6 +18,7 @@ class SphericalHarmonics : public torch::CustomClassHolder {
     SphericalHarmonics(
         int64_t l_max, bool normalized = false, bool backward_second_derivatives = false
     );
+    ~SphericalHarmonics();
 
     // Actual calculation, with autograd support
     torch::Tensor compute(torch::Tensor xyz);
@@ -45,8 +46,8 @@ class SphericalHarmonics : public torch::CustomClassHolder {
     sphericart::SphericalHarmonics<float> calculator_float_;
 
     // CUDA implementation
-    sphericart::cuda::SphericalHarmonics<double> calculator_cuda_double_;
-    sphericart::cuda::SphericalHarmonics<float> calculator_cuda_float_;
+    sphericart::cuda::SphericalHarmonics<double>* calculator_cuda_double_ptr = nullptr;
+    sphericart::cuda::SphericalHarmonics<float>* calculator_cuda_float_ptr = nullptr;
 };
 
 } // namespace sphericart_torch
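
Holding the CUDA calculators behind pointers lets the torch wrapper construct them only when a GPU is actually present, at the cost of a hand-written destructor. Below is a minimal sketch of the same conditional-lifetime pattern using std::unique_ptr instead of raw pointers; the CudaCalculator type and the cuda_available flag are placeholders for illustration, not sphericart's API. With unique_ptr members, the compiler-generated destructor releases the calculators, so no manual delete is needed.

```cpp
#include <cstdint>
#include <memory>

// Hypothetical stand-in for sphericart::cuda::SphericalHarmonics<T>.
template <typename T> struct CudaCalculator {
    CudaCalculator(int64_t /*l_max*/, bool /*normalized*/) { /* allocate GPU state */ }
};

class Holder {
  public:
    Holder(int64_t l_max, bool normalized, bool cuda_available) {
        if (cuda_available) { // the PR checks torch::cuda::is_available() here
            calc_double_ = std::make_unique<CudaCalculator<double>>(l_max, normalized);
            calc_float_ = std::make_unique<CudaCalculator<float>>(l_max, normalized);
        }
    }
    // No user-defined destructor: each unique_ptr frees its calculator, or
    // does nothing if the calculator was never created.
  private:
    std::unique_ptr<CudaCalculator<double>> calc_double_;
    std::unique_ptr<CudaCalculator<float>> calc_float_;
};
```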
4 changes: 2 additions & 2 deletions sphericart-torch/src/autograd.cpp
@@ -233,7 +233,7 @@ torch::autograd::variable_list SphericalHarmonicsAutograd::forward(
     }
 
     if (xyz.dtype() == c10::kDouble) {
-        calculator.calculator_cuda_double_.compute(
+        calculator.calculator_cuda_double_ptr->compute(
             xyz.data_ptr<double>(),
             xyz.size(0),
             requires_grad,
@@ -245,7 +245,7 @@ torch::autograd::variable_list SphericalHarmonicsAutograd::forward(
         );
 
     } else if (xyz.dtype() == c10::kFloat) {
-        calculator.calculator_cuda_float_.compute(
+        calculator.calculator_cuda_float_ptr->compute(
             xyz.data_ptr<float>(),
             xyz.size(0),
             requires_grad,
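
The pointer members are only assigned when torch::cuda::is_available() returns true (see torch.cpp below), so these dereferences assume the CUDA path is only reached on a CUDA-enabled build. A sketch of the dtype dispatch with an explicit guard added; all names here are hypothetical stand-ins, not the PR's code:

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical stand-in for the CUDA calculators held by the torch wrapper.
template <typename T> struct CudaCalc {
    void compute(const T* /*xyz*/, int64_t /*n_samples*/) { /* launch kernels */ }
};

void dispatch(bool is_double, const void* data, int64_t n_samples,
              CudaCalc<double>* calc_double, CudaCalc<float>* calc_float) {
    if (is_double) {
        if (calc_double == nullptr) {
            throw std::runtime_error("CUDA calculator was never initialised");
        }
        calc_double->compute(static_cast<const double*>(data), n_samples);
    } else {
        if (calc_float == nullptr) {
            throw std::runtime_error("CUDA calculator was never initialised");
        }
        calc_float->compute(static_cast<const float*>(data), n_samples);
    }
}
```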
23 changes: 19 additions & 4 deletions sphericart-torch/src/torch.cpp
@@ -6,17 +6,32 @@
 #include "sphericart/autograd.hpp"
 #include "sphericart/torch_cuda_wrapper.hpp"
 
 using namespace torch;
 using namespace sphericart_torch;
 using namespace std;
 
 SphericalHarmonics::SphericalHarmonics(int64_t l_max, bool normalized, bool backward_second_derivatives)
     : l_max_(l_max), normalized_(normalized),
       backward_second_derivatives_(backward_second_derivatives),
-      calculator_double_(l_max_, normalized_), calculator_float_(l_max_, normalized_),
-      calculator_cuda_double_(l_max_, normalized_), calculator_cuda_float_(l_max_, normalized_) //,
-{
-    this->omp_num_threads_ = calculator_double_.get_omp_num_threads();
+      calculator_double_(l_max_, normalized_), calculator_float_(l_max_, normalized_) {
+    this->omp_num_threads_ = calculator_double_.get_omp_num_threads();
+
+    if (torch::cuda::is_available()) {
+        this->calculator_cuda_double_ptr =
+            new sphericart::cuda::SphericalHarmonics<double>(l_max_, normalized_);
+        this->calculator_cuda_float_ptr =
+            new sphericart::cuda::SphericalHarmonics<float>(l_max_, normalized_);
+    }
+}
+
+SphericalHarmonics::~SphericalHarmonics() {
+    if (this->calculator_cuda_double_ptr != nullptr) {
+        delete this->calculator_cuda_double_ptr;
+    }
+
+    if (this->calculator_cuda_float_ptr != nullptr) {
+        delete this->calculator_cuda_float_ptr;
+    }
 }
 
 torch::Tensor SphericalHarmonics::compute(torch::Tensor xyz) {
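
One small point about the new destructor: `delete` on a null pointer is a no-op in standard C++, so the null checks are defensive rather than strictly required. A trivial self-contained demonstration:

```cpp
#include <iostream>

int main() {
    int* p = nullptr;
    delete p; // well-defined in C++: deleting a null pointer does nothing
    std::cout << "delete on nullptr is safe\n";
    return 0;
}
```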
15 changes: 9 additions & 6 deletions sphericart/include/sphericart_cuda.hpp
@@ -41,6 +41,11 @@ template <typename T> class SphericalHarmonics {
      */
     SphericalHarmonics(size_t l_max, bool normalized = false);
 
+    /** Default constructor
+     * Required so sphericart_torch can conditionally instantiate this class
+     * depending on whether CUDA is available.
+     */
+
     /* @cond */
     ~SphericalHarmonics();
     /* @endcond */
@@ -95,14 +100,12 @@ template <typename T> class SphericalHarmonics {
   private:
     size_t l_max; // maximum l value computed by this class
     size_t nprefactors;
-    bool normalized;     // should we normalize the input vectors?
-    T* prefactors_cpu;   // host prefactors buffer
-    T** prefactors_cuda; // storage space for prefactors
-    int device_count;    // number of visible GPU devices
-
+    bool normalized;               // should we normalize the input vectors?
+    T* prefactors_cpu = nullptr;   // host prefactors buffer
+    T** prefactors_cuda = nullptr; // storage space for prefactors
+    int device_count = 0;          // number of visible GPU devices
     int64_t CUDA_GRID_DIM_X_ = 8;
     int64_t CUDA_GRID_DIM_Y_ = 8;
-
     bool cached_compute_with_gradients = false;
     bool cached_compute_with_hessian = false;
     int64_t _current_shared_mem_allocation = 0;
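
The in-class default initialisers matter here because the calculator can now be constructed and destroyed on a machine with no GPU: the destructor's null and zero checks are only meaningful if the members start out in a known state. A minimal sketch of the idiom, with shortened hypothetical names:

```cpp
template <typename T> class DeviceBuffers {
  private:
    T* host_buf = nullptr;     // stays null until explicitly allocated
    T** device_bufs = nullptr; // one pointer per visible GPU
    int device_count = 0;      // zero when no GPU is present
  public:
    ~DeviceBuffers() {
        // Safe even if the constructor never allocated anything:
        delete[] host_buf;
        delete[] device_bufs;
    }
};
```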
88 changes: 50 additions & 38 deletions sphericart/src/sphericart_cuda.cu
@@ -30,66 +30,78 @@ template <typename T> SphericalHarmonics<T>::SphericalHarmonics(size_t l_max, bo
     buffer space, compute prefactors, and sets the function pointers that are
     used for the actual calls
     */
 
     this->l_max = (int)l_max;
     this->nprefactors = (int)(l_max + 1) * (l_max + 2);
     this->normalized = normalized;
     this->prefactors_cpu = new T[this->nprefactors];
 
-    CUDA_CHECK(cudaGetDeviceCount(&this->device_count));
-
     // compute prefactors on host first
     compute_sph_prefactors<T>((int)l_max, this->prefactors_cpu);
 
-    int current_device;
-
-    CUDA_CHECK(cudaGetDevice(&current_device));
-
-    // allocate prefactorts on every visible device and copy from host
-    this->prefactors_cuda = new T*[this->device_count];
-
-    for (int device = 0; device < this->device_count; device++) {
-        CUDA_CHECK(cudaSetDevice(device));
-        CUDA_CHECK(cudaMalloc((void**)&this->prefactors_cuda[device], this->nprefactors * sizeof(T))
-        );
-        CUDA_CHECK(cudaMemcpy(
-            this->prefactors_cuda[device],
-            this->prefactors_cpu,
-            this->nprefactors * sizeof(T),
-            cudaMemcpyHostToDevice
-        ));
-    }
-
-    // initialise the currently available amount of shared memory on all visible devices
-    this->_current_shared_mem_allocation = adjust_shared_memory(
-        sizeof(T),
-        this->l_max,
-        this->CUDA_GRID_DIM_X_,
-        this->CUDA_GRID_DIM_Y_,
-        false,
-        false,
-        this->_current_shared_mem_allocation
-    );
-
-    // set the context back to the current device
-    CUDA_CHECK(cudaSetDevice(current_device));
+    CUDA_CHECK(cudaGetDeviceCount(&this->device_count));
+    if (this->device_count) {
+        int current_device;
+
+        CUDA_CHECK(cudaGetDevice(&current_device));
+
+        // allocate prefactors on every visible device and copy from host
+        this->prefactors_cuda = new T*[this->device_count];
+
+        for (int device = 0; device < this->device_count; device++) {
+            CUDA_CHECK(cudaSetDevice(device));
+            CUDA_CHECK(
+                cudaMalloc((void**)&this->prefactors_cuda[device], this->nprefactors * sizeof(T))
+            );
+            CUDA_CHECK(cudaMemcpy(
+                this->prefactors_cuda[device],
+                this->prefactors_cpu,
+                this->nprefactors * sizeof(T),
+                cudaMemcpyHostToDevice
+            ));
+        }
+
+        // initialise the currently available amount of shared memory on all visible devices
+        this->_current_shared_mem_allocation = adjust_shared_memory(
+            sizeof(T),
+            this->l_max,
+            this->CUDA_GRID_DIM_X_,
+            this->CUDA_GRID_DIM_Y_,
+            false,
+            false,
+            this->_current_shared_mem_allocation
+        );
+
+        // set the context back to the current device
+        CUDA_CHECK(cudaSetDevice(current_device));
+    }
 }
 
 template <typename T> SphericalHarmonics<T>::~SphericalHarmonics() {
     // Destructor, frees the prefactors
-    delete[] (this->prefactors_cpu);
+    if (this->prefactors_cpu != nullptr) {
+        delete[] (this->prefactors_cpu);
+        this->prefactors_cpu = nullptr;
+    }
 
-    int current_device;
-
-    CUDA_CHECK(cudaGetDevice(&current_device));
-
-    for (int device = 0; device < this->device_count; device++) {
-        CUDA_CHECK(cudaSetDevice(device));
-        CUDA_CHECK(cudaDeviceSynchronize());
-        CUDA_CHECK(cudaFree(this->prefactors_cuda[device]));
-    }
-
-    CUDA_CHECK(cudaSetDevice(current_device));
+    if (this->device_count > 0) {
+        int current_device;
+
+        CUDA_CHECK(cudaGetDevice(&current_device));
+
+        for (int device = 0; device < this->device_count; device++) {
+            CUDA_CHECK(cudaSetDevice(device));
+            CUDA_CHECK(cudaDeviceSynchronize());
+            if (this->prefactors_cuda != nullptr && this->prefactors_cuda[device] != nullptr) {
+                CUDA_CHECK(cudaFree(this->prefactors_cuda[device]));
+                this->prefactors_cuda[device] = nullptr;
+            }
+        }
+        this->prefactors_cuda = nullptr;
+
+        CUDA_CHECK(cudaSetDevice(current_device));
+    }
 }
 
 template <typename T>
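
For readers less familiar with the multi-GPU pattern the constructor now guards: save the caller's current device, allocate and fill a buffer on every visible device, then restore the original device so the caller's context is untouched. A standalone sketch under the assumption of a simple error-checking macro (sphericart's CUDA_CHECK is analogous but not necessarily identical):

```cpp
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define CHECK(call)                                                            \
    do {                                                                       \
        cudaError_t err = (call);                                              \
        if (err != cudaSuccess) {                                              \
            std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err)); \
            std::exit(1);                                                      \
        }                                                                      \
    } while (0)

int main() {
    const int n = 16;
    float host[n] = {0};
    int device_count = 0;
    CHECK(cudaGetDeviceCount(&device_count)); // errors out when no GPU exists,
                                              // hence the PR's device_count guard

    int current_device = 0;
    CHECK(cudaGetDevice(&current_device));

    // One device buffer per visible GPU, each a copy of the host data.
    float** device_buffers = new float*[device_count];
    for (int device = 0; device < device_count; device++) {
        CHECK(cudaSetDevice(device));
        CHECK(cudaMalloc((void**)&device_buffers[device], n * sizeof(float)));
        CHECK(cudaMemcpy(device_buffers[device], host, n * sizeof(float),
                         cudaMemcpyHostToDevice));
    }

    CHECK(cudaSetDevice(current_device)); // restore the caller's device

    // Teardown mirrors the destructor: free on each device, restore context.
    for (int device = 0; device < device_count; device++) {
        CHECK(cudaSetDevice(device));
        CHECK(cudaFree(device_buffers[device]));
    }
    CHECK(cudaSetDevice(current_device));
    delete[] device_buffers;
    return 0;
}
```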