前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >TensorFlow-Slim 简介【转载】

TensorFlow-Slim 简介【转载】

作者头像
用户6021899
发布2019-12-10 16:40:10
5420
发布2019-12-10 16:40:10
举报
文章被收录于专栏:Python编程 pyqt matplotlib

TF-Slim 是 TensorFlow 中一个用来构建、训练、评估复杂模型的轻量化库。TF-Slim 模块可以和 TensorFlow 中其它API混合使用。

Slim 模块可以使模型的构建、训练、评估变得简单:允许用户用紧凑的代码定义模型。这主要由 arg_scope、大量的高级 layers 和 variables 来实现。这些工具增加了代码的可读性和维护性,减少了复制、粘贴超参数值出错的可能性,并且简化了超参数的调整。通过提供常用的 regularizers 来简化模型的开发。很多常用的计算机视觉模型(例如 VGG、AlexNet)在 Slim 里面已经有了实现。这些模型开箱可用,并且能够以多种方式进行扩展(例如,给内部的不同层添加 multiple heads)。

  • Slim 层(Layers)

虽然 TensorFlow 的操作集合相当广泛,但神经网络的开发人员通常会在更高的层次上考虑模型,比如:“layers”、“losses”、“metrics” 和 “networks”。layer(例如conv层、fc层、bn层)比 TensorFlow op 更加抽象,并且 layer 通常涉及多个 op。更进一步,layer 通常(但不总是)有很多与之相关的 variable(可调参数(tunable parameters)),这一点与大多数的基本操作区别很大。例如,神经网络中的一个 conv 层由很多低级的 op 组成: 1. 创建权重和偏差 viriable 2. 对权重和输入进行卷积(输入来自前一层) 3. 卷积结果加上偏差 4. 应用一个激活函数 仅使用基础的 TensorFlow 代码,这可能相当费力:

代码语言:javascript
复制
input = ...
with tf.name_scope('conv1_1') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 64, 128], dtype=tf.float32,stddev=1e-1), name='weights')
  conv = tf.nn.conv2d(input, kernel, [1, 1, 1, 1], padding='SAME')
  biases = tf.Variable(tf.constant(0.0, shape=[128], dtype=tf.float32),trainable=True, name='biases')
  bias = tf.nn.bias_add(conv, biases)
  conv1 = tf.nn.relu(bias, name=scope)

为了避免代码的重复。Slim 提供了很多方便的神经网络 layers 的高层 op。例如:与上面的代码对应的 Slim 版的代码:

代码语言:javascript
复制
input = ...
net = slim.conv2d(input, 128, [3, 3], scope='conv1_1')
  • slim.repeat

slim 还提供了两个 meta-operations:repeat 和 stack。tf.contrib.layers.repeat 和 stack,普通函数可以用这两个函数。它们允许用户去重复的进行(perform)相同的操作(operation)。

例如,考虑下面的代码段(来自 VGG 网络,它的 layers 在两个 pooling 层之间进行了很多 conv):

代码语言:javascript
复制
net = ...
net = slim.conv2d(net, 256, [3, 3], scope='conv3_1')
net = slim.conv2d(net, 256, [3, 3], scope='conv3_2')
net = slim.conv2d(net, 256, [3, 3], scope='conv3_3')
net = slim.max_pool2d(net, [2, 2], scope='pool2')

使用 slim.repeat 可以使上面的代码变得更清晰明了:

代码语言:javascript
复制
net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
net = slim.max_pool2d(net, [2, 2], scope='pool2')

注意:slim.repeat 不仅对 repeated 单元采用相同的参数,而且它对 repeated 单元的 scope 采用更好的命名方式(加下划线,再加迭代序号)。具体来说,上面例子中的 scopes 将会命名为 ‘conv3/conv3_1’,‘conv3/conv3_2’,‘conv3/conv3_3’

  • slim.stack

更进一步,slim 的 slim.stack 允许使用不同的参数去重复多个操作,从而创建一个多层的堆叠结构。slim.stack 也为每一个创建的 op 创造了一个新的 tf.variable_scope。例如,创建一个多层感知器的基本写法如下:

代码语言:javascript
复制
net= slim.fully_connected(net, 32, scope='fc/fc_1') #全连接层,32个节点
net = slim.fully_connected(net, 64, scope='fc/fc_2')#64个节点
net = slim.fully_connected(net, 128, scope='fc/fc_3')#128个节点

使用slim可以简写为:

代码语言:javascript
复制
net = slim.stack(net, slim.fully_connected, [32, 64, 128], scope='fc')

在这个例子中,slim.stack 调用 slim.fully_connected 三次,并将函数上一次调用的输出传递给下一次调用。但是,在每个调用中,隐形单元(hidden units)的数量分别为 32,64,128。

相似地,我们可以使用 stack 去简化多层卷积的堆叠。下面是用基本代码写的4层卷积层,

代码语言:javascript
复制
net = slim.conv2d(net, 32, [3, 3], scope='core/core_1')
net = slim.conv2d(net, 32, [1, 1], scope='core/core_2')
net = slim.conv2d(net, 64, [3, 3], scope='core/core_3')
net = slim.conv2d(net, 64, [1, 1], scope='core/core_4')

可以用slim.stack 简写做:

代码语言:javascript
复制
net = slim.stack(net, slim.conv2d, [(32, [3, 3]), (32, [1, 1]), (64, [3, 3]), (64, [1, 1])], scope='core')

slim 作用域(Scopes)

slim 增加了一个名为 arg_scope 的新 scope 机制。这个新 scope 允许用户去给一个或多个 op 指定一套默认参数,这些默认参数将被传给 arg_scope 里使用的的每一个 op。通过使用一个 arg_scope,我们能够在保证每一层使用相同参数值的同时,简化代码:

代码语言:javascript
复制
  with slim.arg_scope([slim.conv2d], padding='SAME',
                      weights_initializer=tf.truncated_normal_initializer(stddev=0.01)
                      weights_regularizer=slim.l2_regularizer(0.0005)):
    net = slim.conv2d(inputs, 64, [11, 11], scope='conv1')
    net = slim.conv2d(net, 128, [11, 11], padding='VALID', scope='conv2')
    net = slim.conv2d(net, 256, [11, 11], scope='conv3')

使用 arg_scope 使代码更清晰、简单并且容易去维护。

注意,在 arg_scope 内部指定op的参数值时,指定的参数将取代默认参数。具体来讲,当 padding 参数的默认值被设置为 ‘SAME’ 时,第二个卷积的 padding 参数被指定为 ‘VALID’。我们也可以嵌套地使用 arg_scope,并且在同一个 scope 中可以使用多个 op。例如:

代码语言:javascript
复制
with slim.arg_scope([slim.conv2d, slim.fully_connected],
                      activation_fn=tf.nn.relu,
                      weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
                      weights_regularizer=slim.l2_regularizer(0.0005)):
  with slim.arg_scope([slim.conv2d], stride=1, padding='SAME'):
    net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', scope='conv1')
    net = slim.conv2d(net, 256, [5, 5],
                      weights_initializer=tf.truncated_normal_initializer(stddev=0.03),
                      scope='conv2')
    net = slim.fully_connected(net, 1000, activation_fn=None, scope='fc')

在这个例子中,第一个 arg_scope 中对 conv2d、fully_connected 层使用相同的 weights_initializer。在第二 arg_scope 中,给 conv2d 的其它默认参数进行了指定。

利用slim我们能够用很少行的代码实现非常复杂的网络。例如,整个 VGG 架构可以使用下面的代码段实现:

代码语言:javascript
复制
def vgg16(inputs):
  with slim.arg_scope([slim.conv2d, slim.fully_connected],
                      activation_fn=tf.nn.relu,
                      weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
                      weights_regularizer=slim.l2_regularizer(0.0005)):
    net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
    net = slim.max_pool2d(net, [2, 2], scope='pool1')
    net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
    net = slim.max_pool2d(net, [2, 2], scope='pool2')
    net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
    net = slim.max_pool2d(net, [2, 2], scope='pool3')
    net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')
    net = slim.max_pool2d(net, [2, 2], scope='pool4')
    net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')
    net = slim.max_pool2d(net, [2, 2], scope='pool5')
    net = slim.fully_connected(net, 4096, scope='fc6')
    net = slim.dropout(net, 0.5, scope='dropout6')
    net = slim.fully_connected(net, 4096, scope='fc7')
    net = slim.dropout(net, 0.5, scope='dropout7')
    net = slim.fully_connected(net, 1000, activation_fn=None, scope='fc8')
  return net
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-12-02,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Python可视化编程机器学习OpenCV 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档