我有2级-1布尔张量,我需要应用tf.case,根据每个输入向量元素的值,逐个元素给输出张量赋值。我的代码如下:
f1 = lambda: tf.constant(1)
f2 = lambda: tf.constant(2)
f3 = lambda: tf.constant(0)
result = tf.case({c1 : f1, c2 : f2}, default=f3)其中c1和c2是秩-1张量,输出是相同形状的张量.
发布于 2020-10-04 16:33:26
解决这个问题的答案是:
f1 = tf.math.add(tf.zero_like(c1),1)
f2 = tf.math.add(tf.zero_like(c1),2)
f3 = tf.zero_like(c1)
result = tf.where(c1, f1, f3)
result = tf.where(c2, f2, result)发布于 2020-10-02 00:25:05
tf.case()只接受对标量布尔值(秩-0张量)和可调用项(函数)。
在您的例子中,您正在传递排名-1的张量,因此您得到了异常:
形状必须为0级,但对于“case/cond/Switch”则为1级。
为了在tf.case中使用非零秩张量,您必须先将张量降为标量值。
您可以使用类似于tf.reduce_all或tf.reduce_any的东西(相当于numpy.all和numpy.any)。
x = tf.constant([True, False]) #your original rank-1 tensor
c1 = tf.reduce_all(x) #False: Computes the "logical AND" of all elements
c2 = tf.reduce_any(x) #True : Computes the "logical OR" of all elements
result = tf.case({c1 : f1, c2 : f2}, default=f3) #returns f2https://stackoverflow.com/questions/64138959
复制相似问题