diff --git a/torchquantum/graph/graphs.py b/torchquantum/graph/graphs.py index 0dd7baa1..45bb13fc 100644 --- a/torchquantum/graph/graphs.py +++ b/torchquantum/graph/graphs.py @@ -238,11 +238,11 @@ def build_static_matrix(self): # for wire_modules in self.wire_module_list: for module in self.flat_module_list: name = module.name - if name in tq.Operator.fixed_ops: + if name in tq.operator.fixed_ops: if name not in self.static_matrix_dict.keys(): # fixed operator, all share one static matrix self.static_matrix_dict[module.name] = module.matrix.to(self.device) - elif name in tq.Operator.parameterized_ops and name not in [ + elif name in tq.operator.parameterized_ops and name not in [ "QubitUnitary", "QubitUnitaryFast", "TrainableUnitary", @@ -281,9 +281,9 @@ def build_static_matrix(self): # for wire_modules in self.wire_module_list: for module in self.flat_module_list: name = module.name - if name in tq.Operator.fixed_ops: + if name in tq.operator.fixed_ops: module.static_matrix = self.static_matrix_dict[name] - elif name in tq.Operator.parameterized_ops and name not in [ + elif name in tq.operator.parameterized_ops and name not in [ "QubitUnitary", "QubitUnitaryFast", "TrainableUnitary", diff --git a/torchquantum/layer/layers.py b/torchquantum/layer/layers.py index bd88fb13..9f129acc 100644 --- a/torchquantum/layer/layers.py +++ b/torchquantum/layer/layers.py @@ -391,7 +391,7 @@ def build_random_layer(self): ) else: operation = op(n_wires=n_op_wires, wires=op_wires) - elif op().name in tq.Operator.parameterized_ops: + elif op().name in tq.operator.parameterized_ops: operation = op(has_params=True, trainable=True, wires=op_wires) else: operation = op(wires=op_wires) diff --git a/torchquantum/operator/op_types.py b/torchquantum/operator/op_types.py index 786214a9..bdf35337 100644 --- a/torchquantum/operator/op_types.py +++ b/torchquantum/operator/op_types.py @@ -53,79 +53,6 @@ class NParamsEnum(IntEnum): class Operator(tq.QuantumModule): """The class for quantum operators.""" - fixed_ops = [ - "Hadamard", - "SHadamard", - "PauliX", - "PauliY", - "PauliZ", - "I", - "S", - "T", - "SX", - "CNOT", - "CZ", - "CY", - "SWAP", - "SSWAP", - "CSWAP", - "Toffoli", - "MultiCNOT", - "MultiXCNOT", - "Reset", - "EchoedCrossResonance", - "QFT", - "SDG", - "TDG", - "SXDG", - "CH", - "CCZ", - "ISWAP", - "CS", - "CSDG", - "CSX", - "CHadamard", - "DCX", - "C3X", - "C3SX", - "RCCX", - "RC3X", - "C4X", - ] - - parameterized_ops = [ - "RX", - "RY", - "RZ", - "RXX", - "RYY", - "RZZ", - "RZX", - "PhaseShift", - "Rot", - "MultiRZ", - "CRX", - "CRY", - "CRZ", - "CRot", - "U1", - "U2", - "U3", - "CU", - "CU1", - "CU2", - "CU3", - "QubitUnitary", - "QubitUnitaryFast", - "TrainableUnitary", - "TrainableUnitaryStrict", - "SingleExcitation", - "XXMINYY", - "XXPLUSYY", - "R", - "GlobalPhase", - ] - @property def name(self): """String for the name of the operator.""" diff --git a/torchquantum/operator/standard_gates/__init__.py b/torchquantum/operator/standard_gates/__init__.py index b6bcaf12..98f55997 100644 --- a/torchquantum/operator/standard_gates/__init__.py +++ b/torchquantum/operator/standard_gates/__init__.py @@ -133,7 +133,7 @@ __all__.extend(["U", "CH", "QubitUnitary", "QubitUnitaryFast"]) # add the dictionary -__all__.append("op_name_dict") +__all__.extend(["op_name_dict", "fixed_ops", "parameterized_ops"]) # create the operations dictionary op_name_dict = {x.op_name: x for x in all_variables} @@ -160,3 +160,6 @@ "cr": CU1, } ) + +fixed_ops = [a().__class__.__name__ for a in all_variables if a.num_params == 0] +parameterized_ops = [a().__class__.__name__ for a in all_variables if a.num_params > 0]