首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用tf.cond()检查自定义图层的调用方法中的条件

使用tf.cond()检查自定义图层的调用方法中的条件
EN

Stack Overflow用户
提问于 2021-02-13 19:34:54
回答 1查看 191关注 0票数 0

我在TensorFlow2.x中实现了一个定制层。我的要求是,程序应该在返回输出之前检查一个条件。

代码语言:javascript
复制
class SimpleRNN_cell(tf.keras.layers.Layer):
    def __init__(self, M1, M2, fi=tf.nn.tanh, disp_name=True):
        super(SimpleRNN_cell, self).__init__()
        pass        
    def call(self, X, hidden_state, return_state=True):
        y = tf.constant(5)
        if return_state == True:
            return y, self.h
        else:
            return y

我的问题是:我应该继续使用当前的代码(假设tape.gradient(Loss, self.trainable_weights)可以正常工作),还是应该使用tf.cond()。此外,如果可能,请解释在哪里使用tf.cond(),在哪里不使用。我还没有找到关于这个话题的太多内容。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-02-13 20:41:25

仅当基于可微计算图中的数据执行条件评估时,tf.cond才相关。(https://www.tensorflow.org/api_docs/python/tf/cond)这在默认为图形模式的TF 1.0中尤其必要。对于GradientTape模式,python系统还允许使用if ...: (https://www.tensorflow.org/guide/autodiff#control_flow)等python构造执行条件数据流

然而,对于仅仅基于配置参数提供不同的行为,这些行为不依赖于来自计算图的数据,并且在模型运行时期间被修复,使用简单的python if语句是正确的。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66184641

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档