以下是tf.case文档中的示例:
def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)
r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2},
default=f3, exclusive=True)我也想这样做,但允许使用feed_dict作为输入,如下所示:
x = tf.placeholder(tf.float32, shape=[None])
y = tf.placeholder(tf.float32, shape=[None])
z = tf.placeholder(tf.float32, shape=[None])
def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)
r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2},
default=f3, exclusive=True)
print(sess.run(r, feed_dict={x: [0, 1, 2, 3], y: [1, 1, 1, 1], z: [2, 2, 2, 2]}))
# result should be [17, -1, -1, 23]因此,基本上,我想给三个长度相同的int-arrays提供数据,并接收一个包含17、23或-1的int-values数组。不幸的是,上面的代码给出了错误:
ValueError: Shape必须为0级,但对于输入形状为:、的“case/cond/Switch”(op:“Switch”),则为1级。
我明白,tf.case需要布尔标量张量输入值,但是有什么方法可以实现我想要的吗?我也尝试过tf.cond,但没有成功。
发布于 2019-06-17 17:01:28
为此使用tf.where,例如( seems to be on its way,但据我所知还没有,所以您必须确保所有参数的大小与一个向量相同,或者tf.fill,tf.tile.)。
import tensorflow as tf
with tf.Graph().as_default(), tf.Session() as sess:
x = tf.placeholder(tf.float32, shape=[None])
y = tf.placeholder(tf.float32, shape=[None])
z = tf.placeholder(tf.float32, shape=[None])
ones = tf.ones_like(x)
r = tf.where(x < y, 17 * ones, tf.where(x > z, 23 * ones, -ones))
print(sess.run(r, feed_dict={x: [0, 1, 2, 3], y: [1, 1, 1, 1], z: [2, 2, 2, 2]}))
# [17. -1. -1. 23.]https://stackoverflow.com/questions/56635027
复制相似问题