我正在做我的第一步学习TF和有一些困难的训练RNN。
我的玩具问题是这样的:一个两层的LSTM +密集层网络被输入原始音频数据,并且应该测试声音中是否存在一定的频率。
因此,网络应该从1到1映射浮动(音频数据序列)到浮动(预选的频率体积)。
我已经在Keras上完成了这个任务,并看到了类似的TFLearn解决方案,但我希望以一种相对高效的方式在裸Tensorflow上实现这一点。
我所做的:
lstm = rnn_cell.BasicLSTMCell(LSTM_SIZE,state_is_tuple=True,forget_bias=1.0)
lstm = rnn_cell.DropoutWrapper(lstm)
stacked_lstm = rnn_cell.MultiRNNCell([lstm] * 2,state_is_tuple=True)
outputs, states = rnn.dynamic_rnn(stacked_lstm, in, dtype=tf.float32)
outputs = tf.transpose(outputs, [1, 0, 2])
last = tf.gather(outputs, int(outputs.get_shape()[0]) - 1)
network= tf.matmul(last, W) + b
# cost function, optimizer etc...在训练期间,我用(BATCH_SIZE,SEQUENCE_LEN,1)批给它,看起来损失是正确的,但是我不知道如何用经过训练的网络来预测。
我的(很多)问题:如何使这个网络直接从Tensorflow返回一个序列,而不对每个示例返回python (输入一个序列并预测相同大小的序列)?
如果我想一次预测一个样本并在python中迭代,那么正确的方法是什么?
在测试期间是否需要dynamic_rnn,还是只用于在培训期间展开BPTT?为什么dynamic_rnn要返回所有反向传播步骤张量?这些是展开网络的每一层的输出,对吗?
发布于 2016-10-11 07:49:49
经过一些研究后:
如何使这个网络直接从Tensorflow返回一个序列,而不对每个示例返回python (输入一个序列并预测相同大小的序列)?
您可以使用state_saving_rnn
class Saver():
def __init__(self):
self.d = {}
def state(self, name):
if not name in self.d:
return tf.zeros([1,LSTM_SIZE],tf.float32)
return self.d[name]
def save_state(self, name, val):
self.d[name] = val
return tf.identity('save_state_name') #<-important for control_dependencies
outputs, states = rnn.state_saving_rnn(stacked_lstm, inx, Saver(),
('lstmstate', 'lstmstate2', 'lstmstate3', 'lstmstate4'),sequence_length=[EVAL_SEQ_LEN])
#4 states are for two layers of lstm each has hidden and CEC variables to restore
network = [tf.matmul(outputs[-1], W) for i in xrange(EVAL_SEQ_LEN)]一个问题是,state_saving_rnn使用的是rnn(),而不是dynamic_rnn(),因此,如果要输入长序列,则在编译时展开EVAL_SEQ_LEN步骤,您可能希望用dynamic_rnn重新实现state_saving_rnn。
如果我想一次预测一个样本并在python中迭代,那么正确的方法是什么?
您可以使用dynamic_rnn并提供initial_state。这可能和state_saving_rnn一样高效。查看state_saving_rnn实现以获得参考
在测试期间是否需要dynamic_rnn,还是只用于在培训期间展开BPTT?为什么dynamic_rnn要返回所有反向传播步骤张量?这些是展开网络的每一层的输出,对吗?
dynamic_rnn确实在运行时进行展开,类似于编译时rnn()。我想它会返回所有的步骤,让您在其他地方将图分支--在较少的时间步骤之后。在使用一个时间步长输入*当前状态->一个输出的网络中,像上面描述的新状态在测试中不需要,但可以用于训练截断的时间回传
https://stackoverflow.com/questions/39934043
复制相似问题