Skip to content
This repository was archived by the owner on Nov 7, 2024. It is now read-only.

Allow to split nodes differently #956

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions tensornetwork/network_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,16 +137,18 @@ def split_node(
left_name: Optional[Text] = None,
right_name: Optional[Text] = None,
edge_name: Optional[Text] = None,
split: Text = "symmetric",
) -> Tuple[AbstractNode, AbstractNode, Tensor]:
"""Split a `node` using Singular Value Decomposition.

Let :math:`M` be the matrix created by flattening `left_edges` and
`right_edges` into 2 axes.
Let :math:`U S V^* = M` be the SVD of :math:`M`.
This will split the network into 2 nodes.
The left node's tensor will be :math:`U \\sqrt{S}`
The left node's tensor will be :math:`U \\sqrt{S}`
and the right node's tensor will be
:math:`\\sqrt{S} V^*` where :math:`V^*` is the adjoint of :math:`V`.
The way the network is split can be modified by the `split` argument.

The singular value decomposition is truncated if `max_singular_values` or
`max_truncation_err` is not `None`.
Expand Down Expand Up @@ -180,6 +182,10 @@ def split_node(
edge_name: The name of the new `Edge` connecting the new left and
right node. If `None`, a name will be generated automatically.
The new axis will get the same name as the edge.
split: Where to split the network. The default 'symmetric' splits the node
into :math:`U \\sqrt{S}` :math:`\\sqrt{S} V^*`. Alternatively,
``split='left'`` splits the network into :math:`U` and :math:`S V^*`,
``split='right'`` splits it into :math:`U S` and :math:`V^*`.

Returns:
A tuple containing:
Expand All @@ -198,6 +204,9 @@ def split_node(
if not hasattr(node, 'backend'):
raise AttributeError('Node {} of type {} has no `backend`'.format(
node, type(node)))
if split not in ('symmetric', 'left', 'right'):
raise ValueError("`split` has to be in {'symmetric', 'left', 'right'}"
f" (given: {split})")

if node.axis_names and edge_name:
left_axis_names = []
Expand All @@ -221,9 +230,14 @@ def split_node(
max_singular_values,
max_truncation_err,
relative=relative)
sqrt_s = backend.sqrt(s)
u_s = backend.broadcast_right_multiplication(u, sqrt_s)
vh_s = backend.broadcast_left_multiplication(sqrt_s, vh)
if split == 'symmetric':
sqrt_s = backend.sqrt(s)
u_s = backend.broadcast_right_multiplication(u, sqrt_s)
vh_s = backend.broadcast_left_multiplication(sqrt_s, vh)
elif split == "left":
u_s, vh_s = u, backend.broadcast_left_multiplication(s, vh)
elif split == "right":
u_s, vh_s = backend.broadcast_right_multiplication(u, s), vh

left_node = Node(u_s,
name=left_name,
Expand Down