Real-time object classification with a vgg.h5 model + Keras (TensorFlow backend on GPU) works fine.
Then I tried a pure TensorFlow graph with the vgg.h5 weights:
So, does anyone have experience building VGG16 from scratch in TensorFlow who could help? Why does TensorFlow as a Keras backend work fine, while pure TensorFlow (with the same weights) cannot compute the fully connected output? Is there some additional optimization in Keras's implementation of the fully connected (Dense) layer?
Posted on 2018-01-21 20:38:02
Here is a test variant of the code, instrumented in several places to print tensor shapes:
import tensorflow as tf
import numpy as np

with tf.Session() as sess:
    # mock the previous layer's output with a placeholder
    pool5_input = tf.placeholder(dtype=tf.float32, shape=(None, 7, 7, 512))
    # insert a print operation to print the shape
    pool5 = tf.Print(pool5_input, [tf.shape(pool5_input)], "pool5 shape is ", summarize=4)
    layer_name = 'fc1'
    wd = tf.Variable(np.ones((25088, 4096), dtype='float32'), trainable=False, name=layer_name + '_wd')
    bd = tf.Variable(np.ones((4096,), dtype='float32'), trainable=False, name=layer_name + '_bd')
    layer_shape = [-1, wd.get_shape().as_list()[0]]
    print('layer_shape:', layer_shape)
    fc1_flat = tf.reshape(pool5, shape=layer_shape)
    fc1_flat = tf.Print(fc1_flat, [tf.shape(fc1_flat)], "fc1_flat shape is ")
    fc1 = tf.nn.relu(tf.nn.bias_add(tf.matmul(fc1_flat, wd, name=layer_name), bd))
    fc1 = tf.Print(fc1, [tf.shape(fc1)], "fc1 shape is ")

    import time
    sess.run(tf.global_variables_initializer())
    # evaluate the network for an input of (minibatch_size, 7, 7, 512)
    minibatch_size = 32
    start = time.time()
    output = sess.run(fc1, feed_dict={pool5_input: np.ones((minibatch_size, 7, 7, 512), dtype='float32')})
    elapsed = time.time() - start
    print("time to evaluate fully connected layer for minibatch size %d: %.3f seconds" % (minibatch_size, elapsed))
    print("output shape is", output.shape)

I get the following output:
layer_shape: [-1, 25088]
...: I tensorflow/core/kernels/logging_ops.cc:79] pool5 shape is [32 7 7 512]
...: I tensorflow/core/kernels/logging_ops.cc:79] fc1_flat shape is [32 25088]
...: I tensorflow/core/kernels/logging_ops.cc:79] fc1 shape is [32 4096]
time to evaluate fully connected layer for minibatch size 32: 0.329 seconds
output shape is (32, 4096)

So for me, a minibatch of size 32 takes well under one second (on GPU).
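Numerically, the fully connected step above is just a flatten plus a matrix multiply, bias add, and ReLU. A minimal NumPy sketch of the same computation (with all-ones stand-in weights, as in the test code above, not the real vgg.h5 values):

import numpy as np

# Stand-in weights with the real VGG16 fc1 shapes.
W = np.ones((25088, 4096), dtype=np.float32)
b = np.ones((4096,), dtype=np.float32)

pool5 = np.ones((32, 7, 7, 512), dtype=np.float32)  # channels_last, as in Keras/TF

# Flatten (N, 7, 7, 512) -> (N, 25088); this matches tf.reshape above and
# Keras's Flatten layer when image_data_format is 'channels_last'.
flat = pool5.reshape(-1, 25088)
fc1 = np.maximum(flat @ W + b, 0.0)  # matmul + bias add + ReLU
print(fc1.shape)  # (32, 4096)

If your pure-TensorFlow port hangs or produces wrong values here, one thing worth checking is that the flatten ordering matches the data format the weights were trained with.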
You can insert similar tf.Print() statements into your code and verify that you get the same (or similar) dimensions. By multiplying out the dimension sizes you can see how much memory each stage uses.
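As a quick back-of-the-envelope check on that last point, multiplying out the shapes printed above (float32 = 4 bytes) gives a rough memory footprint per tensor:

def tensor_bytes(shape, dtype_bytes=4):
    """Rough size of a dense float32 tensor with the given shape."""
    n = 1
    for d in shape:
        n *= d
    return n * dtype_bytes

# Shapes taken from the tf.Print output above (batch size 32).
print("fc1 weights: %.0f MiB" % (tensor_bytes((25088, 4096)) / 2**20))  # 392 MiB
print("pool5 / fc1_flat activations: %.1f MiB" % (tensor_bytes((32, 25088)) / 2**20))
print("fc1 activations: %.1f MiB" % (tensor_bytes((32, 4096)) / 2**20))

The fc1 weight matrix alone is ~392 MiB, which is why the first fully connected layer dominates VGG16's memory use; the activations for a minibatch of 32 are comparatively tiny.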
https://stackoverflow.com/questions/48368197