首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >TensorFlow tf.case馈电阵列(1级形状)

TensorFlow tf.case馈电阵列(1级形状)
EN

Stack Overflow用户
提问于 2019-06-17 16:02:23
回答 1查看 165关注 0票数 2

以下是tf.case文档中的示例:

代码语言:javascript
复制
def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)
r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2},
            default=f3, exclusive=True)

我也想这样做,但允许使用feed_dict作为输入,如下所示:

代码语言:javascript
复制
x = tf.placeholder(tf.float32, shape=[None])
y = tf.placeholder(tf.float32, shape=[None])
z = tf.placeholder(tf.float32, shape=[None])
def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)
r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2},
            default=f3, exclusive=True)
print(sess.run(r, feed_dict={x: [0, 1, 2, 3], y: [1, 1, 1, 1], z: [2, 2, 2, 2]}))
# result should be [17, -1, -1, 23]

因此,基本上,我想给三个长度相同的int-arrays提供数据,并接收一个包含17、23或-1的int-values数组。不幸的是,上面的代码给出了错误:

ValueError: Shape必须为0级,但对于输入形状为:、的“case/cond/Switch”(op:“Switch”),则为1级。

我明白,tf.case需要布尔标量张量输入值,但是有什么方法可以实现我想要的吗?我也尝试过tf.cond,但没有成功。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-06-17 17:01:28

为此使用tf.where,例如( seems to be on its way,但据我所知还没有,所以您必须确保所有参数的大小与一个向量相同,或者tf.filltf.tile.)。

代码语言:javascript
复制
import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    x = tf.placeholder(tf.float32, shape=[None])
    y = tf.placeholder(tf.float32, shape=[None])
    z = tf.placeholder(tf.float32, shape=[None])
    ones = tf.ones_like(x)
    r = tf.where(x < y, 17 * ones, tf.where(x > z, 23 * ones, -ones))
    print(sess.run(r, feed_dict={x: [0, 1, 2, 3], y: [1, 1, 1, 1], z: [2, 2, 2, 2]}))
    # [17. -1. -1. 23.]
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56635027

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档