Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions _codeql_detected_source_root
1 change: 1 addition & 0 deletions qmp/_hamiltonian.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ PYBIND11_MODULE(qmb_hamiltonian, m) {
#define QMB_LIBRARY(x, y) QMB_LIBRARY_HELPER(x, y)
TORCH_LIBRARY_FRAGMENT(QMB_LIBRARY(N_QUBYTES, PARTICLE_CUT), m) {
m.def("apply_within(Tensor configs_i, Tensor psi_i, Tensor configs_j, Tensor site, Tensor kind, Tensor coef) -> Tensor");
m.def("apply_within_conjugate(Tensor configs_j, Tensor psi_j, Tensor configs_i, Tensor site, Tensor kind, Tensor coef) -> Tensor");
m.def("find_relative(Tensor configs_i, Tensor psi_i, int count_selected, Tensor site, Tensor kind, Tensor coef, Tensor configs_exclude) -> Tensor"
);
m.def("diagonal_term(Tensor configs, Tensor site, Tensor kind, Tensor coef) -> Tensor");
Expand Down
211 changes: 211 additions & 0 deletions qmp/_hamiltonian_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,38 @@ std::pair<bool, bool> hamiltonian_apply_kernel(
return std::make_pair(success, parity);
}

// Applies the operators of one Hamiltonian term to `current_configs` in place.
//
// Operator slots are visited in increasing `op_index` order and every touched site is driven
// to the bit value `1 - kind`.
// NOTE(review): judging by the name and the flipped target bit, this is the conjugate
// counterpart of `hamiltonian_apply_kernel` - confirm against that kernel.
//
// Parameters:
//   current_configs - bit-packed configuration (site s lives in byte s / 8, bit s % 8); mutated in place.
//   term_index      - row of `site` / `kind` that describes the term to apply.
//   batch_index     - not used by this function.
//   site            - per-term site index of every operator slot.
//   kind            - per-term operator kind of every slot; the value 2 marks a slot that is skipped.
//
// Returns {success, parity}:
//   success - false as soon as a targeted bit already holds the value it should be set to.
//             In that case `current_configs` keeps the changes made by the earlier slots.
//   parity  - only updated when particle_cut == 1: XOR of all bits below each touched site,
//             sampled right after that site has been written (presumably the fermionic sign).
template<std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
std::pair<bool, bool> hamiltonian_apply_conjugate_kernel(
    std::array<std::uint8_t, n_qubytes>& current_configs,
    std::int64_t term_index,
    std::int64_t batch_index,
    const std::array<std::int16_t, max_op_number>* site, // term_number
    const std::array<std::uint8_t, max_op_number>* kind // term_number
) {
    static_assert(particle_cut == 1 || particle_cut == 2, "particle_cut != 1 or 2 not implemented");

    const auto& term_sites = site[term_index];
    const auto& term_kinds = kind[term_index];

    bool parity = false;
    for (std::int64_t slot = 0; slot < max_op_number; ++slot) {
        const std::int16_t op_site = term_sites[slot];
        const std::uint8_t op_kind = term_kinds[slot];
        if (op_kind == 2) {
            // Empty operator slot: nothing to apply.
            continue;
        }

        std::uint8_t target_bit = 1 - op_kind;
        const auto byte_index = op_site / 8;
        const auto bit_index = op_site % 8;

        // The operator gives zero when the site already carries the target value.
        if (get_bit(&current_configs[byte_index], bit_index) == target_bit) {
            return std::make_pair(false, parity);
        }
        set_bit(&current_configs[byte_index], bit_index, target_bit);

        if constexpr (particle_cut == 1) {
            // Fold the occupation of every site below the touched one into the parity.
            for (std::int16_t lower_site = 0; lower_site < op_site; ++lower_site) {
                parity ^= get_bit(&current_configs[lower_site / 8], lower_site % 8);
            }
        }
    }
    return std::make_pair(true, parity);
}

template<std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
void apply_within_kernel(
std::int64_t term_index,
Expand Down Expand Up @@ -128,6 +160,57 @@ void apply_within_kernel(
result_psi[mid][1] += sign * (coef[term_index][0] * psi[batch_index][1] + coef[term_index][1] * psi[batch_index][0]);
}

// Accumulates the contribution of one (term, result configuration) pair into `result_psi`.
//
// The term is applied (through `hamiltonian_apply_conjugate_kernel`) to a copy of
// `result_configs[result_batch_index]`. The resulting configuration is then looked up in
// `sorted_configs` by binary search, so `sorted_configs` must be ordered according to
// `array_less`. If it is found, `sign * conj(coef[term_index]) * sorted_psi[match]` is added
// to `result_psi[result_batch_index]`, where complex numbers are stored as {real, imag} and
// `sign` is -1 when the returned parity is set, +1 otherwise.
//
// Nothing is written when the term cannot be applied or when the resulting configuration is
// absent from `sorted_configs`.
//
// `term_number` and `result_batch_size` are not used by this function.
template<std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
void apply_within_conjugate_kernel(
    std::int64_t term_index,
    std::int64_t result_batch_index,
    std::int64_t term_number,
    std::int64_t batch_size,
    std::int64_t result_batch_size,
    const std::array<std::int16_t, max_op_number>* site, // term_number
    const std::array<std::uint8_t, max_op_number>* kind, // term_number
    const std::array<double, 2>* coef, // term_number
    const std::array<std::uint8_t, n_qubytes>* sorted_configs, // batch_size
    const std::array<double, 2>* sorted_psi, // batch_size
    const std::array<std::uint8_t, n_qubytes>* result_configs, // result_batch_size
    std::array<double, 2>* result_psi
) {
    // Work on a private copy so that the caller's configuration stays untouched.
    std::array<std::uint8_t, n_qubytes> source_configs = result_configs[result_batch_index];
    const auto applied = hamiltonian_apply_conjugate_kernel<max_op_number, n_qubytes, particle_cut>(
        /*current_configs=*/source_configs,
        /*term_index=*/term_index,
        /*batch_index=*/result_batch_index,
        /*site=*/site,
        /*kind=*/kind
    );
    if (!applied.first) {
        return;
    }

    // Closed-interval binary search for `source_configs` inside `sorted_configs`.
    auto less = array_less<std::uint8_t, n_qubytes>();
    std::int64_t found = -1;
    std::int64_t low = 0;
    std::int64_t high = batch_size - 1;
    while (low <= high) {
        const std::int64_t mid = (low + high) / 2;
        if (less(source_configs, sorted_configs[mid])) {
            high = mid - 1;
        } else if (less(sorted_configs[mid], source_configs)) {
            low = mid + 1;
        } else {
            found = mid;
            break;
        }
    }
    if (found < 0) {
        return;
    }

    // (a - i b) * (x + i y) = (a x + b y) + i (a y - b x)
    const double sign = applied.second ? -1.0 : +1.0;
    const std::array<double, 2>& term_coef = coef[term_index];
    const std::array<double, 2>& amplitude = sorted_psi[found];
    std::array<double, 2>& target = result_psi[result_batch_index];
    target[0] += sign * (term_coef[0] * amplitude[0] + term_coef[1] * amplitude[1]);
    target[1] += sign * (term_coef[0] * amplitude[1] - term_coef[1] * amplitude[0]);
}

template<std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
void apply_within_kernel_interface(
std::int64_t term_number,
Expand Down Expand Up @@ -161,6 +244,39 @@ void apply_within_kernel_interface(
}
}

// Serial CPU driver: invokes `apply_within_conjugate_kernel` once for every
// (term, result configuration) pair, terms in the outer loop and result rows in the inner loop.
//
// All pointer arguments are forwarded unchanged; `result_psi` is the only buffer the kernel
// writes to, and it is accumulated into (the caller is expected to have initialised it).
template<std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
void apply_within_conjugate_kernel_interface(
    std::int64_t term_number,
    std::int64_t batch_size,
    std::int64_t result_batch_size,
    const std::array<std::int16_t, max_op_number>* site, // term_number
    const std::array<std::uint8_t, max_op_number>* kind, // term_number
    const std::array<double, 2>* coef, // term_number
    const std::array<std::uint8_t, n_qubytes>* sorted_configs, // batch_size
    const std::array<double, 2>* sorted_psi, // batch_size
    const std::array<std::uint8_t, n_qubytes>* result_configs, // result_batch_size
    std::array<double, 2>* result_psi
) {
    for (std::int64_t term = 0; term < term_number; ++term) {
        for (std::int64_t row = 0; row < result_batch_size; ++row) {
            apply_within_conjugate_kernel<max_op_number, n_qubytes, particle_cut>(
                /*term_index=*/term,
                /*result_batch_index=*/row,
                /*term_number=*/term_number,
                /*batch_size=*/batch_size,
                /*result_batch_size=*/result_batch_size,
                /*site=*/site,
                /*kind=*/kind,
                /*coef=*/coef,
                /*sorted_configs=*/sorted_configs,
                /*sorted_psi=*/sorted_psi,
                /*result_configs=*/result_configs,
                /*result_psi=*/result_psi
            );
        }
    }
}

template<std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
auto apply_within_interface(
const torch::Tensor& configs,
Expand Down Expand Up @@ -256,6 +372,100 @@ auto apply_within_interface(
return result_psi;
}

// Torch-facing CPU entry point for the "apply_within_conjugate" operator.
//
// Steps:
//   1. Validate every input tensor (CPU device, contiguous, dtype, rank and extents).
//   2. Sort the rows of `configs` with `array_less` and gather `configs` / `psi` in that order,
//      so the kernel can binary-search them.
//   3. Run `apply_within_conjugate_kernel_interface` over all terms and all rows of
//      `result_configs`, accumulating into a zero-initialised output.
//
// Parameters:
//   configs        - uint8   [batch_size, n_qubytes]        bit-packed source configurations.
//   psi            - float64 [batch_size, 2]                {real, imag} amplitude per source row.
//   result_configs - uint8   [result_batch_size, n_qubytes] configurations to evaluate on.
//   site           - int16   [term_number, max_op_number]   site of every operator slot.
//   kind           - uint8   [term_number, max_op_number]   kind of every operator slot.
//   coef           - float64 [term_number, 2]               {real, imag} coefficient per term.
//
// Returns a float64 tensor of shape [result_batch_size, 2], row-aligned with `result_configs`.
// Any failed validation raises through TORCH_CHECK.
template<std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
auto apply_within_conjugate_interface(
    const torch::Tensor& configs,
    const torch::Tensor& psi,
    const torch::Tensor& result_configs,
    const torch::Tensor& site,
    const torch::Tensor& kind,
    const torch::Tensor& coef
) -> torch::Tensor {
    using config_row_t = std::array<std::uint8_t, n_qubytes>;

    std::int64_t device_id = configs.device().index();
    const std::int64_t batch_size = configs.size(0);
    const std::int64_t result_batch_size = result_configs.size(0);
    const std::int64_t term_number = site.size(0);

    // --- configs ---
    TORCH_CHECK(configs.device().type() == torch::kCPU, "configs must be on CPU.");
    TORCH_CHECK(configs.device().index() == device_id, "configs must be on the same device as others.");
    TORCH_CHECK(configs.is_contiguous(), "configs must be contiguous.");
    TORCH_CHECK(configs.dtype() == torch::kUInt8, "configs must be uint8.");
    TORCH_CHECK(configs.dim() == 2, "configs must be 2D.");
    TORCH_CHECK(configs.size(0) == batch_size, "configs batch size must match the provided batch_size.");
    TORCH_CHECK(configs.size(1) == n_qubytes, "configs must have the same number of qubits as the provided n_qubytes.");

    // --- psi ---
    TORCH_CHECK(psi.device().type() == torch::kCPU, "psi must be on CPU.");
    TORCH_CHECK(psi.device().index() == device_id, "psi must be on the same device as others.");
    TORCH_CHECK(psi.is_contiguous(), "psi must be contiguous.");
    TORCH_CHECK(psi.dtype() == torch::kFloat64, "psi must be float64.");
    TORCH_CHECK(psi.dim() == 2, "psi must be 2D.");
    TORCH_CHECK(psi.size(0) == batch_size, "psi batch size must match the provided batch_size.");
    TORCH_CHECK(psi.size(1) == 2, "psi must contain 2 elements for each batch.");

    // --- result_configs ---
    TORCH_CHECK(result_configs.device().type() == torch::kCPU, "result_configs must be on CPU.");
    TORCH_CHECK(result_configs.device().index() == device_id, "result_configs must be on the same device as others.");
    TORCH_CHECK(result_configs.is_contiguous(), "result_configs must be contiguous.");
    TORCH_CHECK(result_configs.dtype() == torch::kUInt8, "result_configs must be uint8.");
    TORCH_CHECK(result_configs.dim() == 2, "result_configs must be 2D.");
    TORCH_CHECK(result_configs.size(0) == result_batch_size, "result_configs batch size must match the provided result_batch_size.");
    TORCH_CHECK(result_configs.size(1) == n_qubytes, "result_configs must have the same number of qubits as the provided n_qubytes.");

    // --- site ---
    TORCH_CHECK(site.device().type() == torch::kCPU, "site must be on CPU.");
    TORCH_CHECK(site.device().index() == device_id, "site must be on the same device as others.");
    TORCH_CHECK(site.is_contiguous(), "site must be contiguous.");
    TORCH_CHECK(site.dtype() == torch::kInt16, "site must be int16.");
    TORCH_CHECK(site.dim() == 2, "site must be 2D.");
    TORCH_CHECK(site.size(0) == term_number, "site size must match the provided term_number.");
    TORCH_CHECK(site.size(1) == max_op_number, "site must match the provided max_op_number.");

    // --- kind ---
    TORCH_CHECK(kind.device().type() == torch::kCPU, "kind must be on CPU.");
    TORCH_CHECK(kind.device().index() == device_id, "kind must be on the same device as others.");
    TORCH_CHECK(kind.is_contiguous(), "kind must be contiguous.");
    TORCH_CHECK(kind.dtype() == torch::kUInt8, "kind must be uint8.");
    TORCH_CHECK(kind.dim() == 2, "kind must be 2D.");
    TORCH_CHECK(kind.size(0) == term_number, "kind size must match the provided term_number.");
    TORCH_CHECK(kind.size(1) == max_op_number, "kind must match the provided max_op_number.");

    // --- coef ---
    TORCH_CHECK(coef.device().type() == torch::kCPU, "coef must be on CPU.");
    TORCH_CHECK(coef.device().index() == device_id, "coef must be on the same device as others.");
    TORCH_CHECK(coef.is_contiguous(), "coef must be contiguous.");
    TORCH_CHECK(coef.dtype() == torch::kFloat64, "coef must be float64.");
    TORCH_CHECK(coef.dim() == 2, "coef must be 2D.");
    TORCH_CHECK(coef.size(0) == term_number, "coef size must match the provided term_number.");
    TORCH_CHECK(coef.size(1) == 2, "coef must contain 2 elements for each term.");

    // Build the permutation that orders the rows of `configs` according to `array_less`.
    auto row_order = torch::arange(batch_size, torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU, device_id));
    auto* const order_begin = reinterpret_cast<std::int64_t*>(row_order.data_ptr());
    const auto* const config_rows = reinterpret_cast<const config_row_t*>(configs.data_ptr());
    std::sort(
        order_begin,
        order_begin + batch_size,
        [config_rows](std::int64_t lhs, std::int64_t rhs) {
            return array_less<std::uint8_t, n_qubytes>()(config_rows[lhs], config_rows[rhs]);
        }
    );

    // Gather the inputs in sorted order and allocate the zero-initialised accumulator.
    auto sorted_configs = configs.index({row_order});
    auto sorted_psi = psi.index({row_order});
    auto result_psi = torch::zeros({result_batch_size, 2}, torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU, device_id));

    apply_within_conjugate_kernel_interface<max_op_number, n_qubytes, particle_cut>(
        /*term_number=*/term_number,
        /*batch_size=*/batch_size,
        /*result_batch_size=*/result_batch_size,
        /*site=*/reinterpret_cast<const std::array<std::int16_t, max_op_number>*>(site.data_ptr()),
        /*kind=*/reinterpret_cast<const std::array<std::uint8_t, max_op_number>*>(kind.data_ptr()),
        /*coef=*/reinterpret_cast<const std::array<double, 2>*>(coef.data_ptr()),
        /*sorted_configs=*/reinterpret_cast<const config_row_t*>(sorted_configs.data_ptr()),
        /*sorted_psi=*/reinterpret_cast<const std::array<double, 2>*>(sorted_psi.data_ptr()),
        /*result_configs=*/reinterpret_cast<const config_row_t*>(result_configs.data_ptr()),
        /*result_psi=*/reinterpret_cast<std::array<double, 2>*>(result_psi.data_ptr())
    );

    return result_psi;
}

template<typename T, typename Less = std::less<T>>
void add_into_heap(T* heap, std::int64_t heap_size, const T& value) {
auto less = Less();
Expand Down Expand Up @@ -557,6 +767,7 @@ auto find_relative_interface(
// QMB_LIBRARY(x, y) forwards to QMB_LIBRARY_HELPER (defined elsewhere in this file) to build the
// library name from N_QUBYTES and PARTICLE_CUT; presumably the extra level exists so that both
// macros are expanded before they are pasted into the name - confirm against the helper.
#define QMB_LIBRARY(x, y) QMB_LIBRARY_HELPER(x, y)
// Register the CPU implementations of the operators of this library.
// Every kernel is instantiated with max_op_number = 4 and the build-time N_QUBYTES / PARTICLE_CUT.
TORCH_LIBRARY_IMPL(QMB_LIBRARY(N_QUBYTES, PARTICLE_CUT), CPU, m) {
m.impl("apply_within", apply_within_interface</*max_op_number=*/4, /*n_qubytes=*/N_QUBYTES, /*particle_cut=*/PARTICLE_CUT>);
m.impl("apply_within_conjugate", apply_within_conjugate_interface</*max_op_number=*/4, /*n_qubytes=*/N_QUBYTES, /*particle_cut=*/PARTICLE_CUT>);
m.impl("find_relative", find_relative_interface</*max_op_number=*/4, /*n_qubytes=*/N_QUBYTES, /*particle_cut=*/PARTICLE_CUT>);
}
// The alias is only needed for the registration above.
#undef QMB_LIBRARY
Expand Down
Loading