Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nested op_name scope in reduce functions #367

Open
wants to merge 10 commits into
base: master
Choose a base branch
from

Conversation

gustafsson
Copy link
Contributor

This commit makes nested op_name scopes work for the reduce_-functions.

@oxinabox
Copy link
Collaborator

This has broken some tests.
https://travis-ci.org/malmaud/TensorFlow.jl/jobs/353723466#L1184

I've not looked too closely as to why,
Might the the tests are too fragile, might not.

The name for reduce_-functions names the scope for which the operations needed to perform a reduce.
@codecov
Copy link

codecov bot commented Mar 16, 2018

Codecov Report

Merging #367 into master will increase coverage by 0.09%.
The diff coverage is 100%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #367      +/-   ##
==========================================
+ Coverage   64.73%   64.82%   +0.09%     
==========================================
  Files          50       50              
  Lines        2895     2900       +5     
==========================================
+ Hits         1874     1880       +6     
+ Misses       1021     1020       -1
Impacted Files Coverage Δ
src/ops/math.jl 79.36% <100%> (+1.39%) ⬆️
src/ops/rnn_cell.jl 77.06% <0%> (+0.21%) ⬆️
src/shape_inference.jl 85.31% <0%> (+0.28%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 69ee69e...5f640ca. Read the comment docs.

@gustafsson
Copy link
Contributor Author

Thanks, I updated the failing test.

Since reduce without axis is implemented with multiple nodes I think that @tf Ysum1 = reduce_sum(Y) should create a scope that contains these nodes. Since the last of those nodes is a sum the output node is then named Ysum1/sum.

If an axis is given explicitly Ysum5 = reduce_sum(Y, axis=2) there's just one node, and it can then take the name Ysum5.

@oxinabox
Copy link
Collaborator

It bothers me that we have a function that takes a name argument and returns a Node.
and even if we provide the name argument directly, the returned node does not have that name.

I'm not sure what the resolution is.

This is definitely a breaking change -- it will break some of my code I am certain.

@gustafsson
Copy link
Contributor Author

gustafsson commented Mar 16, 2018

Ah, just because there is a node named reduce_sum/rank doesn't prevent us from creating a node named reduce_sum. I thought it did.

However before this PR the node was named reduce, is that inherited from python? I think reduce_sum (i.e reduce_$reduction)would be a better name.

@oxinabox
Copy link
Collaborator

Hmm,
Can you pop out some python graphs and compare?

@gustafsson
Copy link
Contributor Author

Here are some graphs. tensorflow.py v1.0 and v1.8 were near identical.

TensorFlow.jl known rank:
reduce-known_rank-tensorflow jl

tensorflow.py v1.8 known rank:
reduce-known_rank-tensorflow-v1 8

TensorFlow.jl unknown rank:
reduce-unknown_rank-tensorflow jl

tensorflow.py v1.8 unknown rank:
reduce-unknown_rank-tensorflow-v1 8

@gustafsson
Copy link
Contributor Author

Code for the above:

julia

using TensorFlow; tf = TensorFlow

v = tf.placeholder(Float32, shape=[2,3,5])

myop = tf.with_op_name("myop") do
    tf.reduce_sum(v) + tf.reduce_sum(v) + tf.reduce_sum(v, name="s")
end

myop_dim = tf.with_op_name("myop3") do
    tf.reduce_sum(v, axis=3) + tf.reduce_sum(v, axis=3) + tf.reduce_sum(v, axis=3, name="dims")
end

myop + myop_dim

sess = tf.Session()
fw = tf.summary.FileWriter("models/tensorflow.jl")

python

import tensorflow as tf

v = tf.placeholder(tf.float32, shape=[2,3,4])

with(tf.name_scope("myop")):
    myop = tf.reduce_sum(v) + tf.reduce_sum(v) + tf.reduce_sum(v, name='s')

with(tf.name_scope("myop2")):
    myop2 = tf.reduce_sum(v, 2) + tf.reduce_sum(v, 2) + tf.reduce_sum(v, 2, name="dims")

myop + myop2

sess = tf.Session()
fw = tf.summary.FileWriter('models/tensorflow.py')
fw.add_graph(sess.graph)

@gustafsson
Copy link
Contributor Author

This is what the graph from TensorFlow.jl looks like before this PR (both the case with known rank and the case with unknown rank produces the same graph):

reduce-known-or-unknown-before-367

@oxinabox
Copy link
Collaborator

This looks good to me.
@malmaud can you give this a once over?

@malmaud
Copy link
Owner

malmaud commented Apr 27, 2018 via email

TensorFlow.jl typically doesn't allow the same name to be used twice whereas tensorflow.py adds a unique postfix.
@malmaud
Copy link
Owner

malmaud commented Jun 20, 2018

LGTM. Seems to be a a conflict now though.

@oxinabox oxinabox closed this Jun 21, 2018
@oxinabox oxinabox reopened this Jun 21, 2018
@malmaud
Copy link
Owner

malmaud commented May 17, 2019

@gustafsson Are you interested in rebasing this on the current TensorFlow.jl master?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants