我在TensorFlow2.x中实现了一个定制层。我的要求是,程序应该在返回输出之前检查一个条件。
class SimpleRNN_cell(tf.keras.layers.Layer):
def __init__(self, M1, M2, fi=tf.nn.tanh, disp_name=True):
super(SimpleRNN_cell, self).__init__()
pass
def call(self, X, hidden_state, return_state=True):
y = tf.constant(5)
if return_state == True:
return y, self.h
else:
return y我的问题是:我应该继续使用当前的代码(假设tape.gradient(Loss, self.trainable_weights)可以正常工作),还是应该使用tf.cond()。此外,如果可能,请解释在哪里使用tf.cond(),在哪里不使用。我还没有找到关于这个话题的太多内容。
发布于 2021-02-13 20:41:25
仅当基于可微计算图中的数据执行条件评估时,tf.cond才相关。(https://www.tensorflow.org/api_docs/python/tf/cond)这在默认为图形模式的TF 1.0中尤其必要。对于GradientTape模式,python系统还允许使用if ...: (https://www.tensorflow.org/guide/autodiff#control_flow)等python构造执行条件数据流
然而,对于仅仅基于配置参数提供不同的行为,这些行为不依赖于来自计算图的数据,并且在模型运行时期间被修复,使用简单的python if语句是正确的。
https://stackoverflow.com/questions/66184641
复制相似问题