
Question about tf.concat at line 152 of harry_potter_lstm.py #1

Open
lu161513 opened this issue Feb 9, 2019 · 2 comments
@lu161513 commented Feb 9, 2019

In the example we have a 3-D tensor:
[n_seqs, n_sequencd_length, lstm_num_units]
and we want to turn it into a 2-D one:
[n_seqs * n_sequencd_length, lstm_num_units]
Shouldn't the concatenation then be along dimension 0, i.e. axis=0 rather than axis=1?

For example, with the following data:

t1 = [[[0, 1], [2, 3], [3, 4], [4, 5]],
      [[5, 6], [6, 7], [7, 8], [8, 9]],
      [[9, 10], [10, 11], [11, 12], [12, 13]]]
t2 = tf.concat(t1, axis=0)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(t2))
# t2 is:
# [[ 0  1]
#  [ 2  3]
#  [ 3  4]
#  [ 4  5]
#  [ 5  6]
#  [ 6  7]
#  [ 7  8]
#  [ 8  9]
#  [ 9 10]
#  [10 11]
#  [11 12]
#  [12 13]]

If we concatenate with axis=1 instead, i.e. along the second dimension, the result would be:

[[ 0  1  5  6  9 10]
 [ 2  3  6  7 10 11]
 [ 3  4  7  8 11 12]
 [ 4  5  8  9 12 13]]

Am I misunderstanding this, or is there a problem here? It seems the final (column) dimension of the output should be the LSTM units.
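(For reference, the axis behavior described above can be checked with NumPy, whose `np.concatenate` follows the same axis semantics as `tf.concat`; the arrays below mirror the three (4, 2) pieces of the `t1` example.)

```python
import numpy as np

# Three (4, 2) matrices, the same data as t1 above.
t1 = [np.array([[0, 1], [2, 3], [3, 4], [4, 5]]),
      np.array([[5, 6], [6, 7], [7, 8], [8, 9]]),
      np.array([[9, 10], [10, 11], [11, 12], [12, 13]])]

# axis=0 stacks the pieces row-wise: result is (12, 2),
# so the column count stays at lstm_num_units.
a0 = np.concatenate(t1, axis=0)
print(a0.shape)  # (12, 2)

# axis=1 joins them side by side instead: result is (4, 6).
a1 = np.concatenate(t1, axis=1)
print(a1.shape)  # (4, 6)
```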

@Dod-o (Owner) commented Feb 11, 2019


I ran some tests, and my earlier understanding of tf.concat was indeed partly mistaken.

t1 = [[[0, 1], [2, 3], [3, 4], [4, 5]],
      [[5, 6], [6, 7], [7, 8], [8, 9]],
      [[9, 10], [10, 11], [11, 12], [12, 13]]]
t1 = tf.Variable(t1)
print(t1.shape)     # (3, 4, 2)

t2 = tf.concat([t1], axis=1)
print(t2.shape)     # (3, 4, 2)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(t1))
    print('----------------')
    print(sess.run(t2))

In the output, t1 and t2 are exactly the same matrix, in both shape and elements. I believe the reason is that tf.concat is meant to join two or more tensors: passing a single tensor has no effect, and tf.concat only does real work once it is given two or more tensors:

t1 = [[[0, 1], [2, 3], [3, 4], [4, 5]],
      [[5, 6], [6, 7], [7, 8], [8, 9]],
      [[9, 10], [10, 11], [11, 12], [12, 13]]]
t1 = tf.Variable(t1)
print(t1.shape)     # (3, 4, 2)

t2 = tf.concat([t1], axis=1)
print(t2.shape)     # (3, 4, 2)

t3 = tf.concat([t1, t1], axis=1)
print(t3.shape)     # (3, 8, 2)

So the concat on the first line below has no effect and should really be removed; the program works correctly thanks to the reshape on the second line.

lstm_output = tf.concat(lstm_output, axis=1)
lstm_output = tf.reshape(lstm_output, shape=(-1, in_size))
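(A NumPy sketch of the point above, assuming `in_size` equals the last dimension `lstm_num_units`: concatenating a single tensor changes nothing, and the reshape alone already produces the desired 2-D result.)

```python
import numpy as np

in_size = 2  # stands in for lstm_num_units
# Shape (n_seqs, n_sequencd_length, lstm_num_units) = (3, 4, 2).
lstm_output = np.arange(3 * 4 * in_size).reshape(3, 4, in_size)

# Concatenating a single tensor is a no-op: the shape is unchanged.
same = np.concatenate([lstm_output], axis=1)
print(same.shape)  # (3, 4, 2)

# The reshape alone flattens it to (n_seqs * n_sequencd_length, in_size).
flat = lstm_output.reshape(-1, in_size)
print(flat.shape)  # (12, 2)
```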

As for why passing the original t1 (before converting it to a tensor) as a plain list does work: the list is effectively demoted by one dimension. For the t1 above with shape (3, 4, 2), the outermost brackets are interpreted as the list wrapping the tensors to be concatenated, so the argument becomes three (4, 2) matrices. That is why, although t1 is three-dimensional, axis can only be 0 or 1, not 2.
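(The demotion effect can also be checked with NumPy, which reads a nested list the same way: the outer list becomes the sequence of 2-D pieces to join, so asking for axis=2 fails.)

```python
import numpy as np

t1 = [[[0, 1], [2, 3], [3, 4], [4, 5]],
      [[5, 6], [6, 7], [7, 8], [8, 9]],
      [[9, 10], [10, 11], [11, 12], [12, 13]]]

# The outer list is read as three (4, 2) matrices to be joined,
# so only axis 0 and axis 1 are valid here.
print(np.concatenate(t1, axis=0).shape)  # (12, 2)

# Each piece is only 2-D, so axis=2 raises an error.
try:
    np.concatenate(t1, axis=2)
    axis2_ok = True
except Exception:
    axis2_ok = False
print(axis2_ok)  # False
```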

Finally, thank you very much for opening this issue. I have modified the code and am re-running it; if the final result checks out, I will update the code on GitHub.

@lu161513 (Author) commented

That settles it: what actually does the work is the reshape that follows. The output's column dimension is in_size, i.e. the LSTM units, and the number of rows is the product of the other two dimensions.
Thanks!
