diff --git a/_codeql_detected_source_root b/_codeql_detected_source_root
new file mode 120000
index 0000000..945c9b4
--- /dev/null
+++ b/_codeql_detected_source_root
@@ -0,0 +1 @@
+.
\ No newline at end of file
diff --git a/qmp/_hamiltonian.cpp b/qmp/_hamiltonian.cpp
index 1968002..30cd57c 100644
--- a/qmp/_hamiltonian.cpp
+++ b/qmp/_hamiltonian.cpp
@@ -82,6 +82,7 @@ PYBIND11_MODULE(qmb_hamiltonian, m) {
 #define QMB_LIBRARY(x, y) QMB_LIBRARY_HELPER(x, y)
 TORCH_LIBRARY_FRAGMENT(QMB_LIBRARY(N_QUBYTES, PARTICLE_CUT), m) {
     m.def("apply_within(Tensor configs_i, Tensor psi_i, Tensor configs_j, Tensor site, Tensor kind, Tensor coef) -> Tensor");
+    m.def("apply_within_conjugate(Tensor configs_j, Tensor psi_j, Tensor configs_i, Tensor site, Tensor kind, Tensor coef) -> Tensor");
     m.def("find_relative(Tensor configs_i, Tensor psi_i, int count_selected, Tensor site, Tensor kind, Tensor coef, Tensor configs_exclude) -> Tensor"
     );
     m.def("diagonal_term(Tensor configs, Tensor site, Tensor kind, Tensor coef) -> Tensor");
diff --git a/qmp/_hamiltonian_cpu.cpp b/qmp/_hamiltonian_cpu.cpp
index c8cc2df..42d0d4b 100644
--- a/qmp/_hamiltonian_cpu.cpp
+++ b/qmp/_hamiltonian_cpu.cpp
@@ -77,6 +77,38 @@ std::pair<bool, bool> hamiltonian_apply_kernel(
     return std::make_pair(success, parity);
 }
 
+template<std::int64_t n_qubytes, std::int64_t particle_cut, std::int64_t max_op_number>
+std::pair<bool, bool> hamiltonian_apply_conjugate_kernel(
+    std::array<std::uint8_t, n_qubytes>& current_configs,
+    std::int64_t term_index,
+    std::int64_t batch_index,
+    const std::array<std::int16_t, max_op_number>* site, // term_number
+    const std::array<std::uint8_t, max_op_number>* kind // term_number
+) {
+    static_assert(particle_cut == 1 || particle_cut == 2, "particle_cut != 1 or 2 not implemented");
+    bool success = true;
+    bool parity = false;
+    for (std::int64_t op_index = 0; op_index < max_op_number; ++op_index) {
+        std::int16_t site_single = site[term_index][op_index];
+        std::uint8_t kind_single = kind[term_index][op_index];
+        if (kind_single == 2) {
+            continue;
+        }
+        std::uint8_t to_what = 1 - kind_single;
+        if (get_bit(&current_configs[site_single / 8], site_single % 8) == to_what) {
+            success = false;
+            break;
+        }
+        set_bit(&current_configs[site_single / 8], site_single % 8, to_what);
+        if constexpr (particle_cut == 1) {
+            for (std::int16_t s = 0; s < site_single; ++s) {
+                parity ^= get_bit(&current_configs[s / 8], s % 8);
+            }
+        }
+    }
+    return std::make_pair(success, parity);
+}
+
 template<std::int64_t n_qubytes, std::int64_t particle_cut, std::int64_t max_op_number>
 void apply_within_kernel(
     std::int64_t term_index,
@@ -128,6 +160,57 @@ void apply_within_kernel(
     result_psi[mid][1] += sign * (coef[term_index][0] * psi[batch_index][1] + coef[term_index][1] * psi[batch_index][0]);
 }
 
+template<std::int64_t n_qubytes, std::int64_t particle_cut, std::int64_t max_op_number>
+void apply_within_conjugate_kernel(
+    std::int64_t term_index,
+    std::int64_t result_batch_index,
+    std::int64_t term_number,
+    std::int64_t batch_size,
+    std::int64_t result_batch_size,
+    const std::array<std::int16_t, max_op_number>* site, // term_number
+    const std::array<std::uint8_t, max_op_number>* kind, // term_number
+    const std::array<double, 2>* coef, // term_number
+    const std::array<std::uint8_t, n_qubytes>* sorted_configs, // batch_size
+    const std::array<double, 2>* sorted_psi, // batch_size
+    const std::array<std::uint8_t, n_qubytes>* result_configs, // result_batch_size
+    std::array<double, 2>* result_psi
+) {
+    std::array<std::uint8_t, n_qubytes> current_configs = result_configs[result_batch_index];
+    auto [success, parity] = hamiltonian_apply_conjugate_kernel<n_qubytes, particle_cut, max_op_number>(
+        /*current_configs=*/current_configs,
+        /*term_index=*/term_index,
+        /*batch_index=*/result_batch_index,
+        /*site=*/site,
+        /*kind=*/kind
+    );
+
+    if (!success) {
+        return;
+    }
+    success = false;
+    std::int64_t low = 0;
+    std::int64_t high = batch_size - 1;
+    std::int64_t mid = 0;
+    auto less = array_less();
+    while (low <= high) {
+        mid = (low + high) / 2;
+        if (less(current_configs, sorted_configs[mid])) {
+            high = mid - 1;
+        } else if (less(sorted_configs[mid], current_configs)) {
+            low = mid + 1;
+        } else {
+            success = true;
+            break;
+        }
+    }
+    if (!success) {
+        return;
+    }
+    std::int8_t sign = parity ? -1 : +1;
+    result_psi[result_batch_index][0] += sign * (coef[term_index][0] * sorted_psi[mid][0] + coef[term_index][1] * sorted_psi[mid][1]);
+    result_psi[result_batch_index][1] += sign * (coef[term_index][0] * sorted_psi[mid][1] - coef[term_index][1] * sorted_psi[mid][0]);
+}
+
 template<std::int64_t n_qubytes, std::int64_t particle_cut, std::int64_t max_op_number>
 void apply_within_kernel_interface(
     std::int64_t term_number,
@@ -161,6 +244,39 @@ void apply_within_kernel_interface(
     }
 }
 
+template<std::int64_t n_qubytes, std::int64_t particle_cut, std::int64_t max_op_number>
+void apply_within_conjugate_kernel_interface(
+    std::int64_t term_number,
+    std::int64_t batch_size,
+    std::int64_t result_batch_size,
+    const std::array<std::int16_t, max_op_number>* site, // term_number
+    const std::array<std::uint8_t, max_op_number>* kind, // term_number
+    const std::array<double, 2>* coef, // term_number
+    const std::array<std::uint8_t, n_qubytes>* sorted_configs, // batch_size
+    const std::array<double, 2>* sorted_psi, // batch_size
+    const std::array<std::uint8_t, n_qubytes>* result_configs, // result_batch_size
+    std::array<double, 2>* result_psi
+) {
+    for (std::int64_t term_index = 0; term_index < term_number; ++term_index) {
+        for (std::int64_t result_batch_index = 0; result_batch_index < result_batch_size; ++result_batch_index) {
+            apply_within_conjugate_kernel<n_qubytes, particle_cut, max_op_number>(
+                /*term_index=*/term_index,
+                /*result_batch_index=*/result_batch_index,
+                /*term_number=*/term_number,
+                /*batch_size=*/batch_size,
+                /*result_batch_size=*/result_batch_size,
+                /*site=*/site,
+                /*kind=*/kind,
+                /*coef=*/coef,
+                /*sorted_configs=*/sorted_configs,
+                /*sorted_psi=*/sorted_psi,
+                /*result_configs=*/result_configs,
+                /*result_psi=*/result_psi
+            );
+        }
+    }
+}
+
 template<std::int64_t n_qubytes, std::int64_t particle_cut, std::int64_t max_op_number>
 auto apply_within_interface(
     const torch::Tensor& configs,
@@ -256,6 +372,100 @@ auto apply_within_interface(
     return result_psi;
 }
 
+template<std::int64_t n_qubytes, std::int64_t particle_cut, std::int64_t max_op_number>
+auto apply_within_conjugate_interface(
+    const torch::Tensor& configs,
+    const torch::Tensor& psi,
+    const torch::Tensor& result_configs,
+    const torch::Tensor& site,
+    const torch::Tensor& kind,
+    const torch::Tensor& coef
+) -> torch::Tensor {
+    std::int64_t device_id = configs.device().index();
+    std::int64_t batch_size = configs.size(0);
+    std::int64_t result_batch_size = result_configs.size(0);
+    std::int64_t term_number = site.size(0);
+
+    TORCH_CHECK(configs.device().type() == torch::kCPU, "configs must be on CPU.")
+    TORCH_CHECK(configs.device().index() == device_id, "configs must be on the same device as others.");
+    TORCH_CHECK(configs.is_contiguous(), "configs must be contiguous.")
+    TORCH_CHECK(configs.dtype() == torch::kUInt8, "configs must be uint8.")
+    TORCH_CHECK(configs.dim() == 2, "configs must be 2D.")
+    TORCH_CHECK(configs.size(0) == batch_size, "configs batch size must match the provided batch_size.");
+    TORCH_CHECK(configs.size(1) == n_qubytes, "configs must have the same number of qubits as the provided n_qubytes.");
+
+    TORCH_CHECK(psi.device().type() == torch::kCPU, "psi must be on CPU.")
+    TORCH_CHECK(psi.device().index() == device_id, "psi must be on the same device as others.");
+    TORCH_CHECK(psi.is_contiguous(), "psi must be contiguous.")
+    TORCH_CHECK(psi.dtype() == torch::kFloat64, "psi must be float64.")
+    TORCH_CHECK(psi.dim() == 2, "psi must be 2D.")
+    TORCH_CHECK(psi.size(0) == batch_size, "psi batch size must match the provided batch_size.");
+    TORCH_CHECK(psi.size(1) == 2, "psi must contain 2 elements for each batch.");
+
+    TORCH_CHECK(result_configs.device().type() == torch::kCPU, "result_configs must be on CPU.")
+    TORCH_CHECK(result_configs.device().index() == device_id, "result_configs must be on the same device as others.");
+    TORCH_CHECK(result_configs.is_contiguous(), "result_configs must be contiguous.")
+    TORCH_CHECK(result_configs.dtype() == torch::kUInt8, "result_configs must be uint8.")
+    TORCH_CHECK(result_configs.dim() == 2, "result_configs must be 2D.")
+    TORCH_CHECK(result_configs.size(0) == result_batch_size, "result_configs batch size must match the provided result_batch_size.")
+    TORCH_CHECK(result_configs.size(1) == n_qubytes, "result_configs must have the same number of qubits as the provided n_qubytes.");
+
+    TORCH_CHECK(site.device().type() == torch::kCPU, "site must be on CPU.")
+    TORCH_CHECK(site.device().index() == device_id, "site must be on the same device as others.");
+    TORCH_CHECK(site.is_contiguous(), "site must be contiguous.")
+    TORCH_CHECK(site.dtype() == torch::kInt16, "site must be int16.")
+    TORCH_CHECK(site.dim() == 2, "site must be 2D.")
+    TORCH_CHECK(site.size(0) == term_number, "site size must match the provided term_number.");
+    TORCH_CHECK(site.size(1) == max_op_number, "site must match the provided max_op_number.");
+
+    TORCH_CHECK(kind.device().type() == torch::kCPU, "kind must be on CPU.")
+    TORCH_CHECK(kind.device().index() == device_id, "kind must be on the same device as others.");
+    TORCH_CHECK(kind.is_contiguous(), "kind must be contiguous.")
+    TORCH_CHECK(kind.dtype() == torch::kUInt8, "kind must be uint8.")
+    TORCH_CHECK(kind.dim() == 2, "kind must be 2D.")
+    TORCH_CHECK(kind.size(0) == term_number, "kind size must match the provided term_number.");
+    TORCH_CHECK(kind.size(1) == max_op_number, "kind must match the provided max_op_number.");
+
+    TORCH_CHECK(coef.device().type() == torch::kCPU, "coef must be on CPU.")
+    TORCH_CHECK(coef.device().index() == device_id, "coef must be on the same device as others.");
+    TORCH_CHECK(coef.is_contiguous(), "coef must be contiguous.")
+    TORCH_CHECK(coef.dtype() == torch::kFloat64, "coef must be float64.")
+    TORCH_CHECK(coef.dim() == 2, "coef must be 2D.")
+    TORCH_CHECK(coef.size(0) == term_number, "coef size must match the provided term_number.");
+    TORCH_CHECK(coef.size(1) == 2, "coef must contain 2 elements for each term.");
+
+    auto configs_sort_index = torch::arange(batch_size, torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU, device_id));
+
+    std::sort(
+        reinterpret_cast<std::int64_t*>(configs_sort_index.data_ptr()),
+        reinterpret_cast<std::int64_t*>(configs_sort_index.data_ptr()) + batch_size,
+        [&configs](std::int64_t i1, std::int64_t i2) {
+            return array_less()(
+                reinterpret_cast<const std::array<std::uint8_t, n_qubytes>*>(configs.data_ptr())[i1],
+                reinterpret_cast<const std::array<std::uint8_t, n_qubytes>*>(configs.data_ptr())[i2]
+            );
+        }
+    );
+    auto sorted_configs = configs.index({configs_sort_index});
+    auto sorted_psi = psi.index({configs_sort_index});
+    auto result_psi = torch::zeros({result_batch_size, 2}, torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU, device_id));
+
+    apply_within_conjugate_kernel_interface<n_qubytes, particle_cut, max_op_number>(
+        /*term_number=*/term_number,
+        /*batch_size=*/batch_size,
+        /*result_batch_size=*/result_batch_size,
+        /*site=*/reinterpret_cast<const std::array<std::int16_t, max_op_number>*>(site.data_ptr()),
+        /*kind=*/reinterpret_cast<const std::array<std::uint8_t, max_op_number>*>(kind.data_ptr()),
+        /*coef=*/reinterpret_cast<const std::array<double, 2>*>(coef.data_ptr()),
+        /*sorted_configs=*/reinterpret_cast<const std::array<std::uint8_t, n_qubytes>*>(sorted_configs.data_ptr()),
+        /*sorted_psi=*/reinterpret_cast<const std::array<double, 2>*>(sorted_psi.data_ptr()),
+        /*result_configs=*/reinterpret_cast<const std::array<std::uint8_t, n_qubytes>*>(result_configs.data_ptr()),
+        /*result_psi=*/reinterpret_cast<std::array<double, 2>*>(result_psi.data_ptr())
+    );
+
+    return result_psi;
+}
+
 template<typename T, typename Less = std::less<T>>
 void add_into_heap(T* heap, std::int64_t heap_size, const T& value) {
     auto less = Less();
@@ -557,6 +767,7 @@ auto find_relative_interface(
 #define QMB_LIBRARY(x, y) QMB_LIBRARY_HELPER(x, y)
 TORCH_LIBRARY_IMPL(QMB_LIBRARY(N_QUBYTES, PARTICLE_CUT), CPU, m) {
     m.impl("apply_within", apply_within_interface);
+    m.impl("apply_within_conjugate", apply_within_conjugate_interface);
     m.impl("find_relative", find_relative_interface);
 }
 #undef QMB_LIBRARY
diff --git a/qmp/_hamiltonian_cuda.cu b/qmp/_hamiltonian_cuda.cu
index 74c7e55..72e68d8 100644
--- a/qmp/_hamiltonian_cuda.cu
+++ b/qmp/_hamiltonian_cuda.cu
@@ -83,6 +83,38 @@ __device__ std::pair<bool, bool> hamiltonian_apply_kernel(
     return std::make_pair(success, parity);
 }
 
+template<std::int64_t n_qubytes, std::int64_t particle_cut, std::int64_t max_op_number>
+__device__ std::pair<bool, bool> hamiltonian_apply_conjugate_kernel(
+    std::array<std::uint8_t, n_qubytes>& current_configs,
+    std::int64_t term_index,
+    std::int64_t batch_index,
+    const std::array<std::int16_t, max_op_number>* site, // term_number
+    const std::array<std::uint8_t, max_op_number>* kind // term_number
+) {
+    static_assert(particle_cut == 1 || particle_cut == 2, "particle_cut != 1 or 2 not implemented");
+    bool success = true;
+    bool parity = false;
+    for (std::int64_t op_index = 0; op_index < max_op_number; ++op_index) {
+        std::int16_t site_single = site[term_index][op_index];
+        std::uint8_t kind_single = kind[term_index][op_index];
+        if (kind_single == 2) {
+            continue;
+        }
+        std::uint8_t to_what = 1 - kind_single;
+        if (get_bit(&current_configs[site_single / 8], site_single % 8) == to_what) {
+            success = false;
+            break;
+        }
+        set_bit(&current_configs[site_single / 8], site_single % 8, to_what);
+        if constexpr (particle_cut == 1) {
+            for (std::int16_t s = 0; s < site_single; ++s) {
+                parity ^= get_bit(&current_configs[s / 8], s % 8);
+            }
+        }
+    }
+    return std::make_pair(success, parity);
+}
+
 template<std::int64_t n_qubytes, std::int64_t particle_cut, std::int64_t max_op_number>
 __device__ void apply_within_kernel(
     std::int64_t term_index,
@@ -134,6 +166,57 @@ __device__ void apply_within_kernel(
     atomicAdd(&result_psi[mid][1], sign * (coef[term_index][0] * psi[batch_index][1] + coef[term_index][1] * psi[batch_index][0]));
 }
 
+template<std::int64_t n_qubytes, std::int64_t particle_cut, std::int64_t max_op_number>
+__device__ void apply_within_conjugate_kernel(
+    std::int64_t term_index,
+    std::int64_t result_batch_index,
+    std::int64_t term_number,
+    std::int64_t batch_size,
+    std::int64_t result_batch_size,
+    const std::array<std::int16_t, max_op_number>* site, // term_number
+    const std::array<std::uint8_t, max_op_number>* kind, // term_number
+    const std::array<double, 2>* coef, // term_number
+    const std::array<std::uint8_t, n_qubytes>* sorted_configs, // batch_size
+    const std::array<double, 2>* sorted_psi, // batch_size
+    const std::array<std::uint8_t, n_qubytes>* result_configs, // result_batch_size
+    std::array<double, 2>* result_psi
+) {
+    std::array<std::uint8_t, n_qubytes> current_configs = result_configs[result_batch_index];
+    auto [success, parity] = hamiltonian_apply_conjugate_kernel<n_qubytes, particle_cut, max_op_number>(
+        /*current_configs=*/current_configs,
+        /*term_index=*/term_index,
+        /*batch_index=*/result_batch_index,
+        /*site=*/site,
+        /*kind=*/kind
+    );
+
+    if (!success) {
+        return;
+    }
+    success = false;
+    std::int64_t low = 0;
+    std::int64_t high = batch_size - 1;
+    std::int64_t mid = 0;
+    auto less = array_less();
+    while (low <= high) {
+        mid = (low + high) / 2;
+        if (less(current_configs, sorted_configs[mid])) {
+            high = mid - 1;
+        } else if (less(sorted_configs[mid], current_configs)) {
+            low = mid + 1;
+        } else {
+            success = true;
+            break;
+        }
+    }
+    if (!success) {
+        return;
+    }
+    std::int8_t sign = parity ? -1 : +1;
+    atomicAdd(&result_psi[result_batch_index][0], sign * (coef[term_index][0] * sorted_psi[mid][0] + coef[term_index][1] * sorted_psi[mid][1]));
+    atomicAdd(&result_psi[result_batch_index][1], sign * (coef[term_index][0] * sorted_psi[mid][1] - coef[term_index][1] * sorted_psi[mid][0]));
+}
+
 template<std::int64_t n_qubytes, std::int64_t particle_cut, std::int64_t max_op_number>
 __global__ void apply_within_kernel_interface(
     std::int64_t term_number,
@@ -168,6 +251,40 @@ __global__ void apply_within_kernel_interface(
     }
 }
 
+template<std::int64_t n_qubytes, std::int64_t particle_cut, std::int64_t max_op_number>
+__global__ void apply_within_conjugate_kernel_interface(
+    std::int64_t term_number,
+    std::int64_t batch_size,
+    std::int64_t result_batch_size,
+    const std::array<std::int16_t, max_op_number>* site, // term_number
+    const std::array<std::uint8_t, max_op_number>* kind, // term_number
+    const std::array<double, 2>* coef, // term_number
+    const std::array<std::uint8_t, n_qubytes>* sorted_configs, // batch_size
+    const std::array<double, 2>* sorted_psi, // batch_size
+    const std::array<std::uint8_t, n_qubytes>* result_configs, // result_batch_size
+    std::array<double, 2>* result_psi
+) {
+    std::int64_t term_index = blockIdx.x * blockDim.x + threadIdx.x;
+    std::int64_t result_batch_index = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (term_index < term_number && result_batch_index < result_batch_size) {
+        apply_within_conjugate_kernel<n_qubytes, particle_cut, max_op_number>(
+            /*term_index=*/term_index,
+            /*result_batch_index=*/result_batch_index,
+            /*term_number=*/term_number,
+            /*batch_size=*/batch_size,
+            /*result_batch_size=*/result_batch_size,
+            /*site=*/site,
+            /*kind=*/kind,
+            /*coef=*/coef,
+            /*sorted_configs=*/sorted_configs,
+            /*sorted_psi=*/sorted_psi,
+            /*result_configs=*/result_configs,
+            /*result_psi=*/result_psi
+        );
+    }
+}
+
 template<std::int64_t n_qubytes, std::int64_t particle_cut, std::int64_t max_op_number>
 auto apply_within_interface(
     const torch::Tensor& configs,
@@ -273,6 +390,110 @@ auto apply_within_interface(
     return result_psi;
 }
 
+template<std::int64_t n_qubytes, std::int64_t particle_cut, std::int64_t max_op_number>
+auto apply_within_conjugate_interface(
+    const torch::Tensor& configs,
+    const torch::Tensor& psi,
+    const torch::Tensor& result_configs,
+    const torch::Tensor& site,
+    const torch::Tensor& kind,
+    const torch::Tensor& coef
+) -> torch::Tensor {
+    std::int64_t device_id = configs.device().index();
+    std::int64_t batch_size = configs.size(0);
+    std::int64_t result_batch_size = result_configs.size(0);
+    std::int64_t term_number = site.size(0);
+    at::cuda::CUDAGuard cuda_device_guard(device_id);
+
+    TORCH_CHECK(configs.device().type() == torch::kCUDA, "configs must be on CUDA.")
+    TORCH_CHECK(configs.device().index() == device_id, "configs must be on the same device as others.");
+    TORCH_CHECK(configs.is_contiguous(), "configs must be contiguous.")
+    TORCH_CHECK(configs.dtype() == torch::kUInt8, "configs must be uint8.")
+    TORCH_CHECK(configs.dim() == 2, "configs must be 2D.")
+    TORCH_CHECK(configs.size(0) == batch_size, "configs batch size must match the provided batch_size.");
+    TORCH_CHECK(configs.size(1) == n_qubytes, "configs must have the same number of qubits as the provided n_qubytes.");
+
+    TORCH_CHECK(psi.device().type() == torch::kCUDA, "psi must be on CUDA.")
+    TORCH_CHECK(psi.device().index() == device_id, "psi must be on the same device as others.");
+    TORCH_CHECK(psi.is_contiguous(), "psi must be contiguous.")
+    TORCH_CHECK(psi.dtype() == torch::kFloat64, "psi must be float64.")
+    TORCH_CHECK(psi.dim() == 2, "psi must be 2D.")
+    TORCH_CHECK(psi.size(0) == batch_size, "psi batch size must match the provided batch_size.");
+    TORCH_CHECK(psi.size(1) == 2, "psi must contain 2 elements for each batch.");
+
+    TORCH_CHECK(result_configs.device().type() == torch::kCUDA, "result_configs must be on CUDA.")
+    TORCH_CHECK(result_configs.device().index() == device_id, "result_configs must be on the same device as others.");
+    TORCH_CHECK(result_configs.is_contiguous(), "result_configs must be contiguous.")
+    TORCH_CHECK(result_configs.dtype() == torch::kUInt8, "result_configs must be uint8.")
+    TORCH_CHECK(result_configs.dim() == 2, "result_configs must be 2D.")
+    TORCH_CHECK(result_configs.size(0) == result_batch_size, "result_configs batch size must match the provided result_batch_size.")
+    TORCH_CHECK(result_configs.size(1) == n_qubytes, "result_configs must have the same number of qubits as the provided n_qubytes.");
+
+    TORCH_CHECK(site.device().type() == torch::kCUDA, "site must be on CUDA.")
+    TORCH_CHECK(site.device().index() == device_id, "site must be on the same device as others.");
+    TORCH_CHECK(site.is_contiguous(), "site must be contiguous.")
+    TORCH_CHECK(site.dtype() == torch::kInt16, "site must be int16.")
+    TORCH_CHECK(site.dim() == 2, "site must be 2D.")
+    TORCH_CHECK(site.size(0) == term_number, "site size must match the provided term_number.");
+    TORCH_CHECK(site.size(1) == max_op_number, "site must match the provided max_op_number.");
+
+    TORCH_CHECK(kind.device().type() == torch::kCUDA, "kind must be on CUDA.")
+    TORCH_CHECK(kind.device().index() == device_id, "kind must be on the same device as others.");
+    TORCH_CHECK(kind.is_contiguous(), "kind must be contiguous.")
+    TORCH_CHECK(kind.dtype() == torch::kUInt8, "kind must be uint8.")
+    TORCH_CHECK(kind.dim() == 2, "kind must be 2D.")
+    TORCH_CHECK(kind.size(0) == term_number, "kind size must match the provided term_number.");
+    TORCH_CHECK(kind.size(1) == max_op_number, "kind must match the provided max_op_number.");
+
+    TORCH_CHECK(coef.device().type() == torch::kCUDA, "coef must be on CUDA.")
+    TORCH_CHECK(coef.device().index() == device_id, "coef must be on the same device as others.");
+    TORCH_CHECK(coef.is_contiguous(), "coef must be contiguous.")
+    TORCH_CHECK(coef.dtype() == torch::kFloat64, "coef must be float64.")
+    TORCH_CHECK(coef.dim() == 2, "coef must be 2D.")
+    TORCH_CHECK(coef.size(0) == term_number, "coef size must match the provided term_number.");
+    TORCH_CHECK(coef.size(1) == 2, "coef must contain 2 elements for each term.");
+
+    auto stream = at::cuda::getCurrentCUDAStream(device_id);
+    auto policy = thrust::device.on(stream);
+
+    cudaDeviceProp prop;
+    AT_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_id));
+    std::int64_t max_threads_per_block = prop.maxThreadsPerBlock;
+
+    auto sorted_configs = configs.clone(torch::MemoryFormat::Contiguous);
+    auto configs_sort_index = torch::arange(batch_size, torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA, device_id));
+
+    thrust::sort_by_key(
+        policy,
+        reinterpret_cast<std::array<std::uint8_t, n_qubytes>*>(sorted_configs.data_ptr()),
+        reinterpret_cast<std::array<std::uint8_t, n_qubytes>*>(sorted_configs.data_ptr()) + batch_size,
+        reinterpret_cast<std::int64_t*>(configs_sort_index.data_ptr()),
+        array_less()
+    );
+    auto sorted_psi = psi.index({configs_sort_index});
+    auto result_psi = torch::zeros({result_batch_size, 2}, torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCUDA, device_id));
+
+    auto threads_per_block = dim3{1, max_threads_per_block >> 1}; // I don't know why, but need to divide by 2 to avoid errors
+    auto num_blocks =
+        dim3{(term_number + threads_per_block.x - 1) / threads_per_block.x, (result_batch_size + threads_per_block.y - 1) / threads_per_block.y};
+
+    apply_within_conjugate_kernel_interface<n_qubytes, particle_cut, max_op_number><<<num_blocks, threads_per_block, 0, stream>>>(
+        /*term_number=*/term_number,
+        /*batch_size=*/batch_size,
+        /*result_batch_size=*/result_batch_size,
+        /*site=*/reinterpret_cast<const std::array<std::int16_t, max_op_number>*>(site.data_ptr()),
+        /*kind=*/reinterpret_cast<const std::array<std::uint8_t, max_op_number>*>(kind.data_ptr()),
+        /*coef=*/reinterpret_cast<const std::array<double, 2>*>(coef.data_ptr()),
+        /*sorted_configs=*/reinterpret_cast<const std::array<std::uint8_t, n_qubytes>*>(sorted_configs.data_ptr()),
+        /*sorted_psi=*/reinterpret_cast<const std::array<double, 2>*>(sorted_psi.data_ptr()),
+        /*result_configs=*/reinterpret_cast<const std::array<std::uint8_t, n_qubytes>*>(result_configs.data_ptr()),
+        /*result_psi=*/reinterpret_cast<std::array<double, 2>*>(result_psi.data_ptr())
+    );
+    AT_CUDA_CHECK(cudaStreamSynchronize(stream));
+
+    return result_psi;
+}
+
 __device__ void _mutex_lock(int* mutex) {
     // I don't know why we need to wait for these periods of time, but the examples in the CUDA documentation are written this way.
     // https://docs.nvidia.com/cuda/cuda-c-programming-guide/#nanosleep-example
@@ -1007,6 +1228,7 @@ auto single_relative_interface(const torch::Tensor& configs, const torch::Tensor
 #define QMB_LIBRARY(x, y) QMB_LIBRARY_HELPER(x, y)
 TORCH_LIBRARY_IMPL(QMB_LIBRARY(N_QUBYTES, PARTICLE_CUT), CUDA, m) {
     m.impl("apply_within", apply_within_interface);
+    m.impl("apply_within_conjugate", apply_within_conjugate_interface);
     m.impl("find_relative", find_relative_interface);
     m.impl("diagonal_term", diagonal_term_interface);
     m.impl("single_relative", single_relative_interface);
diff --git a/qmp/hamiltonian.py b/qmp/hamiltonian.py
index 2ee1201..e747d86 100644
--- a/qmp/hamiltonian.py
+++ b/qmp/hamiltonian.py
@@ -135,7 +135,7 @@ def apply_within(
         configs_i : torch.Tensor
             A uint8 tensor of shape [batch_size_i, n_qubytes] representing the input configurations.
         psi_i : torch.Tensor
-            A complex64 tensor of shape [batch_size_i] representing the input amplitudes on the girven configurations.
+            A complex64 tensor of shape [batch_size_i] representing the input amplitudes on the given configurations.
         configs_j : torch.Tensor
             A uint8 tensor of shape [batch_size_j, n_qubytes] representing the output configurations.
 
@@ -156,6 +156,41 @@ def apply_within(
         )
         return psi_j
 
+    def apply_within_conjugate(
+        self,
+        configs_j: torch.Tensor,
+        psi_j: torch.Tensor,
+        configs_i: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Applies the conjugate transpose of the Hamiltonian to the given vector.
+
+        Parameters
+        ----------
+        configs_j : torch.Tensor
+            A uint8 tensor of shape [batch_size_j, n_qubytes] representing the input configurations.
+        psi_j : torch.Tensor
+            A complex64 tensor of shape [batch_size_j] representing the input amplitudes on the given configurations.
+        configs_i : torch.Tensor
+            A uint8 tensor of shape [batch_size_i, n_qubytes] representing the output configurations.
+
+        Returns
+        -------
+        torch.Tensor
+            A tensor of shape [batch_size_i] representing the output amplitudes on the given configurations.
+        """
+        self._prepare_data(configs_j.device)
+        _apply_within_conjugate: typing.Callable[
+            [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor
+        ]
+        _apply_within_conjugate = getattr(
+            self._load_module(configs_j.device.type, configs_j.size(1), self.particle_cut), "apply_within_conjugate"
+        )
+        psi_i = torch.view_as_complex(
+            _apply_within_conjugate(configs_j, torch.view_as_real(psi_j), configs_i, self.site, self.kind, self.coef)
+        )
+        return psi_i
+
     def find_relative(
         self,
         configs_i: torch.Tensor,
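
Reviewer note, not part of the diff: a minimal usage sketch of how the new apply_within_conjugate method is intended to pair with the existing apply_within, assuming an already constructed Hamiltonian object h from qmp/hamiltonian.py together with uint8 configuration tensors and complex amplitude tensors; every variable name below is illustrative rather than taken from the repository.

    import torch

    # Forward application: amplitudes on configs_i pushed through the restricted block <configs_j|H|configs_i>.
    h_psi_j = h.apply_within(configs_i, psi_i, configs_j)
    # Adjoint application: the conjugate transpose of the same block, giving amplitudes back on configs_i.
    hdag_psi_i = h.apply_within_conjugate(configs_j, psi_j, configs_i)

    # Both calls address the same restricted matrix block, so the adjoint identity
    # <psi_j, H psi_i> == <H^dagger psi_j, psi_i> should hold up to floating-point rounding.
    assert torch.allclose(torch.vdot(psi_j, h_psi_j), torch.vdot(hdag_psi_i, psi_i))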