首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Keras自定义softmax层:是否可以在softmax层的输出中将基于零的输出神经元设置为0作为输入层中的数据?

Keras自定义softmax层:是否可以在softmax层的输出中将基于零的输出神经元设置为0作为输入层中的数据?
EN

Stack Overflow用户
提问于 2018-12-19 20:20:58
回答 2查看 1.5K关注 0票数 1

我有一个神经网络,最后一层有10个输出神经元,使用softmax激活。我还确切地知道,基于输入值,输出层中的某些神经元应该具有0值。所以我有一个特殊的输入层,有10个神经元,每个神经元要么是0,要么是1。

如果3号输入神经元也是0,有没有可能强制3号输出神经元的值= 0?

代码语言:javascript
复制
action_input = Input(shape=(10,), name='action_input')
...

x = Dense(10,  kernel_initializer = RandomNormal(),bias_initializer = RandomNormal() )(x)
x = Activation('softmax')(x)

我知道有一种方法可以屏蔽神经网络外部输出层的结果,并对所有非零相关的输出进行整形(以便总和为1)。但我想在网络中解决这个问题,并在网络的训练中使用它。我应该为此使用自定义层吗?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2018-12-23 14:47:18

最后,我想出了这个代码:

代码语言:javascript
复制
from keras import backend as K
import tensorflow as tf
def mask_output2(x):
    inp, soft_out = x
    # add a very small value in order to avoid having 0 everywhere
    c = K.constant(0.0000001, dtype='float32', shape=(32, 13))
    y = soft_out + c

    y = Lambda(lambda x: K.switch(K.equal(x[0],0), x[1], K.zeros_like(x[1])))([inp, soft_out])
    y_sum =  K.sum(y, axis=-1)

    y_sum_corrected = Lambda(lambda x: K.switch(K.equal(x[0],0), K.ones_like(x[0]), x[0] ))([y_sum])

    y_sum_corrected = tf.divide(1,y_sum_corrected)

    y = tf.einsum('ij,i->ij', y, y_sum_corrected)
    return y
票数 1
EN

Stack Overflow用户

发布于 2018-12-19 20:34:57

您可以使用Lambda层和K.switch检查输入中的零值,并在输出中对其进行遮罩:

代码语言:javascript
复制
from keras import backend as K

inp = Input((5,))
soft_out = Dense(5, activation='softmax')(inp)
out = Lambda(lambda x: K.switch(x[0], x[1], K.zeros_like(x[1])))([inp, soft_out])

model = Model(inp, out)

model.predict(np.array([[0, 3, 0, 2, 0]]))
# array([[0., 0.35963967, 0., 0.47805876, 0.]], dtype=float32)

但是,正如您所看到的,输出总和不再是1。如果希望总和为1,则可以重新缩放这些值:

代码语言:javascript
复制
def mask_output(x):
    inp, soft_out = x
    y = K.switch(inp, soft_out, K.zeros_like(inp))
    y /= K.sum(y, axis=-1)
    return y

# ...
out = Lambda(mask_output)([inp, soft_out])
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/53851175

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档