Using Batch Normalization in TensorFlow

1. Principle
The formula is as follows:

$$y = \gamma \frac{x - \mu}{\sigma} + \beta$$

where $x$ is the input, $y$ is the output, $\mu$ is the mean, $\sigma$ is the standard deviation, and $\gamma$ and $\beta$ are the scale and offset (shift) coefficients.
In general these parameters are per-channel: for example, if the input $x$ is a 16x32x32x128 feature map (NHWC format), then each of the parameters above is a 128-dimensional vector. $\gamma$ and $\beta$ are optional; if present, they are learnable parameters (participating in the forward and backward passes), otherwise the formula simplifies to $y=(x-\mu)/\sigma$. As for $\mu$ and $\sigma$: during training, the statistics of the current batch are used; at test/inference time, the moving averages accumulated during training are used.
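A rough, self-contained sketch of the per-channel computation above (the tensor shape, epsilon value, and use of tf.nn.moments are illustrative choices, not taken from the original text):

```python
import tensorflow as tf

# Hypothetical NHWC feature map: batch 16, height 32, width 32, 128 channels.
x = tf.random_normal([16, 32, 32, 128])

# Batch statistics over the N, H, W axes -> one mean/variance per channel (128-dim vectors).
mu, var = tf.nn.moments(x, axes=[0, 1, 2])

# Learnable per-channel scale (gamma) and offset (beta).
gamma = tf.Variable(tf.ones([128]), name='gamma')
beta = tf.Variable(tf.zeros([128]), name='beta')

# y = gamma * (x - mu) / sigma + beta; a small epsilon keeps the division stable.
eps = 1e-3
y = gamma * (x - mu) / tf.sqrt(var + eps) + beta
# Equivalent low-level op: tf.nn.batch_normalization(x, mu, var, beta, gamma, eps)
```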

Why use batch normalization?
Training a neural network is essentially a matter of learning the distribution of the data. If the test data follows a different distribution than the training data, the network's ability to generalize drops sharply. Moreover, if each batch of training data follows a different distribution, the network has to adapt to a new distribution at every iteration, which greatly slows down training.

Once training starts, the parameters keep being updated. Except for the input layer (whose data has already been normalized per sample by hand), the input distribution of every subsequent layer keeps changing during training, because updates to the parameters of earlier layers change the distribution of the inputs fed to later layers.

This change in the data distribution of the intermediate layers during training is called Internal Covariate Shift. Batch Normalization is designed to address exactly this situation, where the distribution of intermediate-layer data shifts as training proceeds.

2. Usage in TensorFlow
TensorFlow's batch normalization implementations mainly come in three flavors:
tf.nn.batch_normalization
tf.layers.batch_normalization
tf.contrib.layers.batch_norm
The level of wrapping increases from one to the next. It is recommended to use tf.layers.batch_normalization or tf.contrib.layers.batch_norm, so the steps below are based on these.
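For reference, a minimal sketch of the recommended higher-level call (the input tensor is made up; momentum and epsilon are shown at their default values):

```python
import tensorflow as tf

inputs = tf.random_normal([16, 32, 32, 128])   # hypothetical NHWC feature map

# tf.layers.batch_normalization creates gamma/beta and the moving mean/variance
# internally; `training` switches between batch statistics and the moving averages.
out = tf.layers.batch_normalization(inputs,
                                    axis=-1,        # normalize per channel
                                    momentum=0.99,  # decay used for the moving averages
                                    epsilon=1e-3,
                                    training=True,  # True during training, False at test/inference time
                                    name='bn')
```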
Related material:
- TensorFlow Guide: Batch Normalization: http://ruishu.io/2016/12/27/batchnorm/
- Batch Normalization导读: https://blog.csdn.net/malefactor/article/details/51476961
- How to correctly use the tf.layers.batch_normalization() in tensorflow?: https://stackoverflow.com/questions/46573345/how-to-correctly-use-the-tf-layers-batch-normalization-in-tensorflow
3. Training
Two things need attention during training: (1) pass training=True; (2) when building the training step, add the following code (i.e., make the final train_op depend on the ops in update_ops). Only then will the moving averages of $\mu$ and $\sigma$ be updated (they are needed at test time).

```python
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
```

4. Testing
Only one thing needs attention at test time: pass training=False.
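One common pattern (not from the original text) is to drive the training argument with a boolean tensor whose default is False, so a single graph can serve both training and evaluation without being rebuilt:

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 28, 28, 1])                      # hypothetical input
is_training = tf.placeholder_with_default(False, shape=(), name='is_training')

conv = tf.layers.conv2d(x, filters=32, kernel_size=[3, 3], padding='same', activation=tf.nn.relu)
bn = tf.layers.batch_normalization(conv, training=is_training, name='bn')

# Training step: feed is_training=True so batch statistics are used and the moving
# averages get updated (still together with the update_ops dependency shown earlier).
#   sess.run(train_op, feed_dict={x: batch_images, is_training: True})
# Evaluation step: omit the feed; the default False makes BN use the moving averages.
#   sess.run(accuracy, feed_dict={x: test_images})
```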

5. Inference
Inference is a little different, because at this stage the model parameters are usually read back from a checkpoint file before making predictions. Typically, not every model parameter is written to the checkpoint, since irrelevant data would inflate the model size; a common approach is to save only the parameters that are updated during training (the trainable variables), like this:

```python
var_list = tf.trainable_variables()
saver = tf.train.Saver(var_list=var_list, max_to_keep=5)
```

However, with batch_normalization, although $\gamma$ and $\beta$ are indeed trainable parameters, $\mu$ and $\sigma$ are not; they are merely computed as moving averages. If the model is saved as above, restoring it for inference will fail with an error saying that $\mu$ and $\sigma$ cannot be found. tf.moving_average_variables() does not return the $\mu$ and $\sigma$ of the BN layers either. Fortunately, all of them live in tf.global_variables(), so you can write:

```python
var_list = tf.trainable_variables()
g_list = tf.global_variables()
bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
var_list += bn_moving_vars
saver = tf.train.Saver(var_list=var_list, max_to_keep=5)
```

With this approach, $\mu$ and $\sigma$ are saved along with the trainable variables, and restoring the model for inference no longer raises an error; of course, training=False still has to be passed.
Note that the code above is slightly loose: it filters only on the strings moving_mean and moving_variance, which works here because the BN layers are the only layers in this network that contain them. If other layers in your network also have variables with these names but you do not want to save them, it is better to filter on something like bn/moving_mean, as sketched below.
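A sketch of such a stricter filter (the 'bn' prefix matches the layer names bn1/bn2 used in the example below; adjust it to your own BN layer names):

```python
import tensorflow as tf

g_list = tf.global_variables()
# Keep moving statistics only from variables whose scope starts with 'bn'
# (e.g. 'bn1/moving_mean:0'), instead of matching any layer that happens to
# contain moving_mean / moving_variance in its name.
bn_moving_vars = [g for g in g_list
                  if g.name.startswith('bn') and
                  ('moving_mean' in g.name or 'moving_variance' in g.name)]

var_list = tf.trainable_variables() + bn_moving_vars
saver = tf.train.Saver(var_list=var_list, max_to_keep=5)
```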

6. An MNIST-based example
The example consists of two files, one for training and one for testing. Note the block in bn_train.py under the comment "# only save trainable and bn variables": only the trainable variables of the network and the mean/var statistics of the BN layers are saved. The example downloads the MNIST dataset, so the machine needs internet access.

bn_train.py

```python
import tensorflow as tf
import os
from tensorflow.examples.tutorials.mnist import input_data

tf.logging.set_verbosity(tf.logging.INFO)

if __name__ == '__main__':
    mnist = input_data.read_data_sets('mnist', one_hot=True)
    x = tf.placeholder(tf.float32, [None, 784])
    y_ = tf.placeholder(tf.float32, [None, 10])
    image = tf.reshape(x, [-1, 28, 28, 1])
    conv1 = tf.layers.conv2d(image, filters=32, kernel_size=[3, 3], strides=[1, 1], padding='same',
                             activation=tf.nn.relu,
                             kernel_initializer=tf.truncated_normal_initializer(stddev=0.1),
                             name='conv1')
    bn1 = tf.layers.batch_normalization(conv1, training=True, name='bn1')
    pool1 = tf.layers.max_pooling2d(bn1, pool_size=[2, 2], strides=[2, 2], padding='same', name='pool1')
    conv2 = tf.layers.conv2d(pool1, filters=64, kernel_size=[3, 3], strides=[1, 1], padding='same',
                             activation=tf.nn.relu,
                             kernel_initializer=tf.truncated_normal_initializer(stddev=0.1),
                             name='conv2')
    bn2 = tf.layers.batch_normalization(conv2, training=True, name='bn2')
    pool2 = tf.layers.max_pooling2d(bn2, pool_size=[2, 2], strides=[2, 2], padding='same', name='pool2')

    flatten_layer = tf.contrib.layers.flatten(pool2, 'flatten_layer')
    weights = tf.get_variable(shape=[flatten_layer.shape[-1], 10], dtype=tf.float32,
                              initializer=tf.truncated_normal_initializer(stddev=0.1), name='fc_weights')
    biases = tf.get_variable(shape=[10], dtype=tf.float32,
                             initializer=tf.constant_initializer(0.0), name='fc_biases')
    logit_output = tf.nn.bias_add(tf.matmul(flatten_layer, weights), biases, name='logit_output')
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logit_output))
    pred_label = tf.argmax(logit_output, 1)
    label = tf.argmax(y_, 1)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(pred_label, label), tf.float32))
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    global_step = tf.get_variable('global_step', [], dtype=tf.int32,
                                  initializer=tf.constant_initializer(0), trainable=False)
    learning_rate = tf.train.exponential_decay(learning_rate=0.1, global_step=global_step, decay_steps=5000,
                                               decay_rate=0.1, staircase=True)
    opt = tf.train.AdadeltaOptimizer(learning_rate=learning_rate, name='optimizer')
    with tf.control_dependencies(update_ops):
        grads = opt.compute_gradients(cross_entropy)
        train_op = opt.apply_gradients(grads, global_step=global_step)

    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    tf_config.allow_soft_placement = True
    sess = tf.InteractiveSession(config=tf_config)
    sess.run(tf.global_variables_initializer())

    # only save trainable and bn variables
    var_list = tf.trainable_variables()
    if global_step is not None:
        var_list.append(global_step)
    g_list = tf.global_variables()
    bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
    bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
    var_list += bn_moving_vars
    saver = tf.train.Saver(var_list=var_list, max_to_keep=5)
    # save all variables
    # saver = tf.train.Saver(max_to_keep=5)

    if not os.path.exists('ckpts'):
        os.makedirs('ckpts')  # Saver.save requires the checkpoint directory to exist
    if tf.train.latest_checkpoint('ckpts') is not None:
        saver.restore(sess, tf.train.latest_checkpoint('ckpts'))
    train_loops = 10000
    for i in range(train_loops):
        batch_xs, batch_ys = mnist.train.next_batch(32)
        _, step, loss, acc = sess.run([train_op, global_step, cross_entropy, accuracy],
                                      feed_dict={x: batch_xs, y_: batch_ys})
        if step % 100 == 0:  # print training info
            log_str = 'step:%d \t loss:%.6f \t acc:%.6f' % (step, loss, acc)
            tf.logging.info(log_str)
        if step % 1000 == 0:  # save current model
            save_path = os.path.join('ckpts', 'mnist-model.ckpt')
            saver.save(sess, save_path, global_step=step)

    sess.close()
```

bn_test.py

```python
import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

tf.logging.set_verbosity(tf.logging.INFO)

if __name__ == '__main__':
    mnist = input_data.read_data_sets('mnist', one_hot=True)
    x = tf.placeholder(tf.float32, [None, 784])
    y_ = tf.placeholder(tf.float32, [None, 10])
    image = tf.reshape(x, [-1, 28, 28, 1])
    conv1 = tf.layers.conv2d(image, filters=32, kernel_size=[3, 3], strides=[1, 1], padding='same',
                             activation=tf.nn.relu,
                             kernel_initializer=tf.truncated_normal_initializer(stddev=0.1),
                             name='conv1')
    bn1 = tf.layers.batch_normalization(conv1, training=False, name='bn1')
    pool1 = tf.layers.max_pooling2d(bn1, pool_size=[2, 2], strides=[2, 2], padding='same', name='pool1')
    conv2 = tf.layers.conv2d(pool1, filters=64, kernel_size=[3, 3], strides=[1, 1], padding='same',
                             activation=tf.nn.relu,
                             kernel_initializer=tf.truncated_normal_initializer(stddev=0.1),
                             name='conv2')
    bn2 = tf.layers.batch_normalization(conv2, training=False, name='bn2')
    pool2 = tf.layers.max_pooling2d(bn2, pool_size=[2, 2], strides=[2, 2], padding='same', name='pool2')

    flatten_layer = tf.contrib.layers.flatten(pool2, 'flatten_layer')
    weights = tf.get_variable(shape=[flatten_layer.shape[-1], 10], dtype=tf.float32,
                              initializer=tf.truncated_normal_initializer(stddev=0.1), name='fc_weights')
    biases = tf.get_variable(shape=[10], dtype=tf.float32,
                             initializer=tf.constant_initializer(0.0), name='fc_biases')
    logit_output = tf.nn.bias_add(tf.matmul(flatten_layer, weights), biases, name='logit_output')
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logit_output))
    pred_label = tf.argmax(logit_output, 1)
    label = tf.argmax(y_, 1)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(pred_label, label), tf.float32))

    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    tf_config.allow_soft_placement = True
    sess = tf.InteractiveSession(config=tf_config)
    saver = tf.train.Saver()
    if tf.train.latest_checkpoint('ckpts') is not None:
        saver.restore(sess, tf.train.latest_checkpoint('ckpts'))
    else:
        raise FileNotFoundError('can not find checkpoint folder path!')

    loss, acc = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
    log_str = 'loss:%.6f \t acc:%.6f' % (loss, acc)
    tf.logging.info(log_str)
    sess.close()
```
