From af0bc60220317d48159c2582c97676e00441e8b3 Mon Sep 17 00:00:00 2001 From: Hanrui Wang Date: Wed, 30 Aug 2023 20:46:04 -0400 Subject: [PATCH] [minor] expand_param supports fixed gate --- examples/mnist/mnist.py | 2 +- torchquantum/encoding/encodings.py | 14 ++++++++++++++ torchquantum/plugin/qiskit/qiskit_plugin.py | 9 +++++++-- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/examples/mnist/mnist.py b/examples/mnist/mnist.py index 65276e91..23e5eafe 100644 --- a/examples/mnist/mnist.py +++ b/examples/mnist/mnist.py @@ -79,7 +79,7 @@ def forward(self, qdev: tq.QuantumDevice): def __init__(self): super().__init__() self.n_wires = 4 - self.encoder = tq.GeneralEncoder(tq.encoder_op_list_name_dict["4x4_u3rx"]) + self.encoder = tq.GeneralEncoder(tq.encoder_op_list_name_dict["4x4_u3_h_rx"]) self.q_layer = self.QLayer() self.measure = tq.MeasureAll(tq.PauliZ) diff --git a/torchquantum/encoding/encodings.py b/torchquantum/encoding/encodings.py index e519181b..f8d2056d 100644 --- a/torchquantum/encoding/encodings.py +++ b/torchquantum/encoding/encodings.py @@ -301,6 +301,20 @@ def __init__(self): {"input_idx": [12, 13, 14], "func": "u3", "wires": [3]}, {"input_idx": [15], "func": "rx", "wires": [3]}, ], + "4x4_u3_h_rx": [ + {"input_idx": [0, 1, 2], "func": "u3", "wires": [0]}, + {"input_idx": [3], "func": "rx", "wires": [0]}, + {"func": "h", "wires": [0]}, + {"func": "h", "wires": [1]}, + {"func": "h", "wires": [2]}, + {"func": "h", "wires": [3]}, + {"input_idx": [4, 5, 6], "func": "u3", "wires": [1]}, + {"input_idx": [7], "func": "rx", "wires": [1]}, + {"input_idx": [8, 9, 10], "func": "u3", "wires": [2]}, + {"input_idx": [11], "func": "rx", "wires": [2]}, + {"input_idx": [12, 13, 14], "func": "u3", "wires": [3]}, + {"input_idx": [15], "func": "rx", "wires": [3]}, + ], "4x4_ryzxy": [ {"input_idx": [0], "func": "ry", "wires": [0]}, {"input_idx": [1], "func": "ry", "wires": [1]}, diff --git a/torchquantum/plugin/qiskit/qiskit_plugin.py b/torchquantum/plugin/qiskit/qiskit_plugin.py index 6b533b3f..954c3b8a 100644 --- a/torchquantum/plugin/qiskit/qiskit_plugin.py +++ b/torchquantum/plugin/qiskit/qiskit_plugin.py @@ -661,10 +661,15 @@ def op_history2qiskit_expand_params(n_wires, op_history, bsz): for i in range(bsz): circ = QuantumCircuit(n_wires) for op in op_history: + if "params" in op.keys() and op["params"] is not None: + param = op["params"][i] + else: + param = None + append_fixed_gate( - circ, op["name"], op["params"][i], op["wires"], op["inverse"] + circ, op["name"], param, op["wires"], op["inverse"] ) - + circs_all.append(circ) return circs_all