Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️ methods for applying operations to DDs #674

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions include/mqt-core/dd/Operations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

#include "Definitions.hpp"
#include "dd/DDDefinitions.hpp"
#include "dd/Edge.hpp"
#include "dd/GateMatrixDefinitions.hpp"
#include "dd/Package.hpp"
#include "ir/Permutation.hpp"
#include "ir/operations/ClassicControlledOperation.hpp"
#include "ir/operations/CompoundOperation.hpp"
#include "ir/operations/Control.hpp"
#include "ir/operations/NonUnitaryOperation.hpp"
#include "ir/operations/OpType.hpp"
#include "ir/operations/Operation.hpp"
#include "ir/operations/StandardOperation.hpp"
Expand All @@ -16,6 +18,7 @@
#include <cmath>
#include <cstddef>
#include <ostream>
#include <random>
#include <sstream>
#include <string>
#include <utility>
Expand Down Expand Up @@ -262,6 +265,114 @@ qc::MatrixDD getInverseDD(const qc::Operation* op, Package<Config>& dd,
return getDD(op, dd, permutation, true);
}

template <class Config, class Node>
Edge<Node> applyUnitaryOperation(const qc::Operation* op, Edge<Node> in,
Package<Config>& dd,
qc::Permutation& permutation) {
static_assert(std::is_same_v<Node, dd::vNode> ||
std::is_same_v<Node, dd::mNode>);
assert(op->isUnitary());
auto tmp = dd.multiply(getDD(op, dd, permutation), in);
dd.incRef(tmp);
dd.decRef(in);
dd.garbageCollect();
return tmp;
}

template <class Config>
qc::VectorDD
applyMeasurement(const qc::Operation* op, qc::VectorDD in, Package<Config>& dd,
const qc::Permutation& permutation, std::mt19937_64& rng,
std::vector<bool>& measurements) {
assert(op->getType() == qc::Measure);
assert(op->isNonUnitaryOperation());
const auto* measure = dynamic_cast<const qc::NonUnitaryOperation*>(op);
assert(measure != nullptr);

const auto& qubits = measure->getTargets();
const auto& bits = measure->getClassics();
for (size_t j = 0U; j < qubits.size(); ++j) {
measurements.at(bits.at(j)) =
dd.measureOneCollapsing(
in, static_cast<dd::Qubit>(permutation.at(qubits.at(j))), true,
rng) == '1';
}
return in;
}

template <class Config>
qc::VectorDD applyReset(const qc::Operation* op, qc::VectorDD in,
Package<Config>& dd, qc::Permutation& permutation,
std::mt19937_64& rng) {
assert(op->getType() == qc::Reset);
assert(op->isNonUnitaryOperation());
const auto* reset = dynamic_cast<const qc::NonUnitaryOperation*>(op);
assert(reset != nullptr);

const auto& qubits = reset->getTargets();
for (const auto& qubit : qubits) {
const auto bit = dd.measureOneCollapsing(
in, static_cast<dd::Qubit>(permutation.at(qubit)), true, rng);
// apply an X operation whenever the measured result is one
if (bit == '1') {
const auto x = qc::StandardOperation(qubit, qc::X);
in = applyUnitaryOperation(&x, in, dd, permutation);
}
}
return in;
}

template <class Config>
qc::VectorDD applyClassicControlledOperation(const qc::Operation* op,
qc::VectorDD in,
Package<Config>& dd,
qc::Permutation& permutation,
std::vector<bool>& measurements) {
assert(op->isClassicControlledOperation());
const auto* classic = dynamic_cast<const qc::ClassicControlledOperation*>(op);
assert(classic != nullptr);

const auto& controlRegister = classic->getControlRegister();
const auto& expectedValue = classic->getExpectedValue();
const auto& comparisonKind = classic->getComparisonKind();

auto actualValue = 0ULL;
// determine the actual value from measurements
for (std::size_t j = 0; j < controlRegister.second; ++j) {
if (measurements[controlRegister.first + j]) {
actualValue |= 1ULL << j;
}
}

// check if the actual value matches the expected value according to the
// comparison kind
const auto control = [&]() -> bool {
switch (comparisonKind) {
case qc::ComparisonKind::Eq:
return actualValue == expectedValue;
case qc::ComparisonKind::Neq:
return actualValue != expectedValue;
case qc::ComparisonKind::Lt:
return actualValue < expectedValue;
case qc::ComparisonKind::Leq:
return actualValue <= expectedValue;
case qc::ComparisonKind::Gt:
return actualValue > expectedValue;
case qc::ComparisonKind::Geq:
return actualValue >= expectedValue;
default:
qc::unreachable();
throw qc::QFRException("Unknown comparison kind.");
}
}();

if (!control) {
return in;
}

return applyUnitaryOperation(classic->getOperation(), in, dd, permutation);
}

template <class Config>
void dumpTensor(qc::Operation* op, std::ostream& of,
std::vector<std::size_t>& inds, std::size_t& gateIdx,
Expand Down
16 changes: 13 additions & 3 deletions include/mqt-core/dd/Package.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ template <class Config> class Package {
static_cast<Qubit>(p),
std::array{f, dEdge::zero(), dEdge::zero(), dEdge::zero()});
}
incRef(f);
return f;
}

Expand All @@ -319,6 +320,7 @@ template <class Config> class Package {
for (std::size_t p = start; p < n + start; p++) {
f = makeDDNode(static_cast<Qubit>(p), std::array{f, vEdge::zero()});
}
incRef(f);
return f;
}
// generate computational basis state |i> with n qubits
Expand All @@ -339,6 +341,7 @@ template <class Config> class Package {
f = makeDDNode(static_cast<Qubit>(p), std::array{vEdge::zero(), f});
}
}
incRef(f);
return f;
}
// generate general basis state with n qubits
Expand Down Expand Up @@ -391,7 +394,9 @@ template <class Config> class Package {
break;
}
}
return {f.p, cn.lookup(f.w)};
vEdge e{f.p, cn.lookup(f.w)};
incRef(e);
return e;
}

// generate general GHZ state with n qubits
Expand All @@ -418,11 +423,13 @@ template <class Config> class Package {
std::array{vEdge::zero(), rightSubtree});
}

return makeDDNode(
vEdge e = makeDDNode(
static_cast<Qubit>(n - 1),
std::array<vEdge, RADIX>{
{{leftSubtree.p, {&constants::sqrt2over2, &constants::zero}},
{rightSubtree.p, {&constants::sqrt2over2, &constants::zero}}}});
incRef(e);
return e;
}

// generate general W state with n qubits
Expand Down Expand Up @@ -459,6 +466,7 @@ template <class Config> class Package {
std::array{rightSubtree, vEdge::zero()});
}
}
incRef(leftSubtree);
return leftSubtree;
}

Expand All @@ -480,7 +488,9 @@ template <class Config> class Package {
const auto level = static_cast<Qubit>(std::log2(length) - 1);
const auto state =
makeStateFromVector(stateVector.begin(), stateVector.end(), level);
return {state.p, cn.lookup(state.w)};
vEdge e{state.p, cn.lookup(state.w)};
incRef(e);
return e;
}

/**
Expand Down
8 changes: 1 addition & 7 deletions include/mqt-core/dd/Simulation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,9 @@ VectorDD simulate(const QuantumComputation* qc, const VectorDD& in,
// measurements are currently not supported here
auto permutation = qc->initialLayout;
auto e = in;
dd.incRef(e);

for (const auto& op : *qc) {
auto tmp = dd.multiply(getDD(op.get(), dd, permutation), e);
dd.incRef(tmp);
dd.decRef(e);
e = tmp;

dd.garbageCollect();
e = applyUnitaryOperation(op.get(), e, dd, permutation);
}

// correct permutation if necessary
Expand Down
2 changes: 2 additions & 0 deletions include/mqt-core/ir/operations/ClassicControlledOperation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ class ClassicControlledOperation final : public Operation {

[[nodiscard]] auto getOperation() const { return op.get(); }

[[nodiscard]] auto getComparisonKind() const { return comparisonKind; }

[[nodiscard]] const Targets& getTargets() const override {
return op->getTargets();
}
Expand Down
8 changes: 1 addition & 7 deletions src/dd/FunctionalityConstruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,7 @@ MatrixDD buildFunctionality(const QuantumComputation* qc, Package<Config>& dd) {
auto e = dd.createInitialMatrix(qc->ancillary);

for (const auto& op : *qc) {
auto tmp = dd.multiply(getDD(op.get(), dd, permutation), e);

dd.incRef(tmp);
dd.decRef(e);
e = tmp;

dd.garbageCollect();
e = applyUnitaryOperation(op.get(), e, dd, permutation);
}
// correct permutation if necessary
changePermutation(e, permutation, qc->outputPermutation, dd);
Expand Down
86 changes: 23 additions & 63 deletions src/dd/Simulation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <map>
#include <random>
#include <string>
#include <vector>

namespace dd {
template <class Config>
Expand Down Expand Up @@ -80,20 +81,14 @@ simulate(const QuantumComputation* qc, const VectorDD& in, Package<Config>& dd,
// simulate once and measure all qubits repeatedly
auto permutation = qc->initialLayout;
auto e = in;
dd.incRef(e);

for (const auto& op : *qc) {
// simply skip any non-unitary
if (!op->isUnitary()) {
continue;
}

auto tmp = dd.multiply(getDD(op.get(), dd, permutation), e);
dd.incRef(tmp);
dd.decRef(e);
e = tmp;

dd.garbageCollect();
e = applyUnitaryOperation(op.get(), e, dd, permutation);
}

// correct permutation if necessary
Expand Down Expand Up @@ -139,82 +134,47 @@ simulate(const QuantumComputation* qc, const VectorDD& in, Package<Config>& dd,
std::map<std::string, std::size_t> counts{};

for (std::size_t i = 0U; i < shots; i++) {
std::map<std::size_t, char> measurements{};
// increase reference count of input state so that it is collected
dd.incRef(in);

std::vector<bool> measurements(qc->getNcbits(), false);

auto permutation = qc->initialLayout;
auto e = in;
dd.incRef(e);

for (const auto& op : *qc) {
if (const auto* nonunitary = dynamic_cast<NonUnitaryOperation*>(op.get());
nonunitary != nullptr) {
if (nonunitary->getType() == Measure) {
const auto& qubits = nonunitary->getTargets();
const auto& bits = nonunitary->getClassics();
for (std::size_t j = 0U; j < qubits.size(); ++j) {
measurements[bits.at(j)] = dd.measureOneCollapsing(
e, static_cast<Qubit>(permutation.at(qubits.at(j))), true, mt);
}
continue;
}

if (nonunitary->getType() == Reset) {
const auto& qubits = nonunitary->getTargets();
for (const auto& qubit : qubits) {
auto bit = dd.measureOneCollapsing(
e, static_cast<Qubit>(permutation.at(qubit)), true, mt);
// apply an X operation whenever the measured result is one
if (bit == '1') {
const auto x =
qc::StandardOperation(permutation.at(qubit), qc::X);
auto tmp = dd.multiply(getDD(&x, dd), e);
dd.incRef(tmp);
dd.decRef(e);
e = tmp;
dd.garbageCollect();
}
}
continue;
}
if (op->getType() == Measure) {
e = applyMeasurement(op.get(), e, dd, permutation, mt, measurements);
continue;
}

if (const auto* classicControlled =
dynamic_cast<ClassicControlledOperation*>(op.get());
classicControlled != nullptr) {
const auto& controlRegister = classicControlled->getControlRegister();
const auto& expectedValue = classicControlled->getExpectedValue();
auto actualValue = 0ULL;
// determine the actual value from measurements
for (std::size_t j = 0; j < controlRegister.second; ++j) {
if (measurements[controlRegister.first + j] == '1') {
actualValue |= 1ULL << j;
}
}

// do not apply an operation if the value is not the expected one
if (actualValue != expectedValue) {
continue;
}
if (op->getType() == Reset) {
e = applyReset(op.get(), e, dd, permutation, mt);
continue;
}

auto tmp = dd.multiply(getDD(op.get(), dd, permutation), e);
dd.incRef(tmp);
dd.decRef(e);
e = tmp;
if (op->isClassicControlledOperation()) {
e = applyClassicControlledOperation(op.get(), e, dd, permutation,
measurements);
continue;
}

dd.garbageCollect();
e = applyUnitaryOperation(op.get(), e, dd, permutation);
}

// reduce reference count of measured state
dd.decRef(e);

std::string shot(qc->getNcbits(), '0');
for (const auto& [bit, value] : measurements) {
shot[qc->getNcbits() - bit - 1U] = value;
for (size_t bit = 0U; bit < qc->getNcbits(); ++bit) {
shot[qc->getNcbits() - bit - 1U] = measurements[bit] ? '1' : '0';
}
counts[shot]++;
}

// decrease reference count of input state so that it can be garbage collected
dd.decRef(in);

return counts;
}

Expand Down
8 changes: 4 additions & 4 deletions test/algorithms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
if(TARGET MQT::CoreAlgo)
file(GLOB_RECURSE ALGO_TEST_SOURCES *.cpp)
package_add_test(mqt-core-algo-test MQT::CoreAlgo ${ALGO_TEST_SOURCES})
target_link_libraries(mqt-core-algo-test PRIVATE MQT::CoreDD MQT::CoreCircuitOptimizer)
if(TARGET MQT::CoreAlgorithms)
file(GLOB_RECURSE ALGORITHMS_TEST_SOURCES *.cpp)
package_add_test(mqt-core-algorithms-test MQT::CoreAlgorithms ${ALGORITHMS_TEST_SOURCES})
target_link_libraries(mqt-core-algorithms-test PRIVATE MQT::CoreDD MQT::CoreCircuitOptimizer)
endif()
1 change: 0 additions & 1 deletion test/algorithms/test_qft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,6 @@ TEST_P(QFT, FunctionalityRecursiveEquality) {
TEST_P(QFT, DynamicSimulation) {
// there should be no error constructing the circuit
ASSERT_NO_THROW({ qc = std::make_unique<qc::QFT>(nqubits, true, true); });
auto dd = std::make_unique<dd::Package<>>(nqubits);
qc->printStatistics(std::cout);

// simulate the circuit
Expand Down
Loading
Loading