From 0e276289bec723f096dd5d20316db24436dea957 Mon Sep 17 00:00:00 2001 From: Weh Andreas Date: Fri, 10 Dec 2021 15:48:03 +0100 Subject: [PATCH] Allow to split nodes differently --- tensornetwork/network_operations.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/tensornetwork/network_operations.py b/tensornetwork/network_operations.py index 9817511b4..dc9428891 100644 --- a/tensornetwork/network_operations.py +++ b/tensornetwork/network_operations.py @@ -137,6 +137,7 @@ def split_node( left_name: Optional[Text] = None, right_name: Optional[Text] = None, edge_name: Optional[Text] = None, + split: Text = "symmetric", ) -> Tuple[AbstractNode, AbstractNode, Tensor]: """Split a `node` using Singular Value Decomposition. @@ -144,9 +145,10 @@ def split_node( `right_edges` into 2 axes. Let :math:`U S V^* = M` be the SVD of :math:`M`. This will split the network into 2 nodes. - The left node's tensor will be :math:`U \\sqrt{S}` + The left node's tensor will be :math:`U \\sqrt{S}` and the right node's tensor will be :math:`\\sqrt{S} V^*` where :math:`V^*` is the adjoint of :math:`V`. + The way the network is split can be modified by the `split` argument. The singular value decomposition is truncated if `max_singular_values` or `max_truncation_err` is not `None`. @@ -180,6 +182,10 @@ def split_node( edge_name: The name of the new `Edge` connecting the new left and right node. If `None`, a name will be generated automatically. The new axis will get the same name as the edge. + split: Where to split the network. The default 'symmetric' splits the node + into :math:`U \\sqrt{S}` :math:`\\sqrt{S} V^*`. Alternatively, + ``split='left'`` splits the network into :math:`U` and :math:`S V^*`, + ``split='right'`` splits it into :math:`U S` and :math:`V^*`. Returns: A tuple containing: @@ -198,6 +204,9 @@ def split_node( if not hasattr(node, 'backend'): raise AttributeError('Node {} of type {} has no `backend`'.format( node, type(node))) + if split not in ('symmetric', 'left', 'right'): + raise ValueError("`split` has to be in {'symmetric', 'left', 'right'}" + f" (given: {split})") if node.axis_names and edge_name: left_axis_names = [] @@ -221,9 +230,14 @@ def split_node( max_singular_values, max_truncation_err, relative=relative) - sqrt_s = backend.sqrt(s) - u_s = backend.broadcast_right_multiplication(u, sqrt_s) - vh_s = backend.broadcast_left_multiplication(sqrt_s, vh) + if split == 'symmetric': + sqrt_s = backend.sqrt(s) + u_s = backend.broadcast_right_multiplication(u, sqrt_s) + vh_s = backend.broadcast_left_multiplication(sqrt_s, vh) + elif split == "left": + u_s, vh_s = u, backend.broadcast_left_multiplication(s, vh) + elif split == "right": + u_s, vh_s = backend.broadcast_right_multiplication(u, s), vh left_node = Node(u_s, name=left_name,