From 779e9e55bc9f4602b1445e18f208d386db67adb5 Mon Sep 17 00:00:00 2001 From: purpleyun Date: Sun, 26 Apr 2020 15:08:58 +0800 Subject: [PATCH 1/2] Update memory_saving_gradients.py Change the word "/read" to "/Read". --- memory_saving_gradients.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/memory_saving_gradients.py b/memory_saving_gradients.py index 0345bb5..107654b 100644 --- a/memory_saving_gradients.py +++ b/memory_saving_gradients.py @@ -87,9 +87,9 @@ def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs): fwd_ops = [op for op in fwd_ops if not op in xs_ops] fwd_ops = [op for op in fwd_ops if not '/assign' in op.name] fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name] - fwd_ops = [op for op in fwd_ops if not '/read' in op.name] + fwd_ops = [op for op in fwd_ops if not '/Read' in op.name] ts_all = ge.filter_ts(fwd_ops, True) # get the tensors - ts_all = [t for t in ts_all if '/read' not in t.name] + ts_all = [t for t in ts_all if '/Read' not in t.name] ts_all = set(ts_all) - set(xs) - set(ys) # construct list of tensors to checkpoint during forward pass, if not From 15217612dbbf7cfac1ef2b8c5b0ab3e25a3eabb3 Mon Sep 17 00:00:00 2001 From: purpleyun Date: Sun, 26 Apr 2020 15:21:47 +0800 Subject: [PATCH 2/2] Update memory_saving_gradients.py --- memory_saving_gradients.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/memory_saving_gradients.py b/memory_saving_gradients.py index 107654b..931aec9 100644 --- a/memory_saving_gradients.py +++ b/memory_saving_gradients.py @@ -87,8 +87,10 @@ def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs): fwd_ops = [op for op in fwd_ops if not op in xs_ops] fwd_ops = [op for op in fwd_ops if not '/assign' in op.name] fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name] + fwd_ops = [op for op in fwd_ops if not '/read' in op.name] fwd_ops = [op for op in fwd_ops if not '/Read' in op.name] ts_all = ge.filter_ts(fwd_ops, True) # get the tensors + 
ts_all = [t for t in ts_all if '/read' not in t.name] ts_all = [t for t in ts_all if '/Read' not in t.name] ts_all = set(ts_all) - set(xs) - set(ys)