用法:
a1 = tf.nn.embedding_lookup(a, index)
index是索引,a是输入,通过index来选取a中对应的元素返回给a1,注意index是从0开始算起
例子:
import tensorflow as tfa = tf.constant([5, 6, 7, 8, 9])
index = tf.constant([1, 3])
a1 = tf.nn.embedding_lookup(a, index)with tf.Session() as sess:sess.run(tf.global_variables_initializer())print(sess.run(a1))
输出:[6 8],可知输出了a中索引为1和3的元素
当a是二维数组时,输出第index行元素
import tensorflow as tfa = tf.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
index = tf.constant([1])
a1 = tf.nn.embedding_lookup(a, index)with tf.Session() as sess:sess.run(tf.global_variables_initializer())print(sess.run(a))print(sess.run(a1))
输出:
[[0 1 2]
[3 4 5]
[6 7 8]]
[[3 4 5]]