前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >slim.arg_scope()的使用

slim.arg_scope()的使用

作者头像
狼啸风云
修改2022-09-04 20:52:09
2K0
修改2022-09-04 20:52:09
举报

slim.arg_scope()函数的使用

slim是一种轻量级的tensorflow库,可以使模型的构建,训练,测试都变得更加简单。在slim库中对很多常用的函数进行了定义,slim.arg_scope()是slim库中经常用到的函数之一。函数的定义如下;

代码语言:javascript
复制
@tf_contextlib.contextmanager
def arg_scope(list_ops_or_scope, **kwargs):
  """Stores the default arguments for the given set of list_ops.
  For usage, please see examples at top of the file.
  Args:
    list_ops_or_scope: List or tuple of operations to set argument scope for or
      a dictionary containing the current scope. When list_ops_or_scope is a
      dict, kwargs must be empty. When list_ops_or_scope is a list or tuple,
      then every op in it need to be decorated with @add_arg_scope to work.
    **kwargs: keyword=value that will define the defaults for each op in
              list_ops. All the ops need to accept the given set of arguments.
  Yields:
    the current_scope, which is a dictionary of {op: {arg: value}}
  Raises:
    TypeError: if list_ops is not a list or a tuple.
    ValueError: if any op in list_ops has not be decorated with @add_arg_scope.
  """
  if isinstance(list_ops_or_scope, dict):
    # Assumes that list_ops_or_scope is a scope that is being reused.
    if kwargs:
      raise ValueError('When attempting to re-use a scope by suppling a'
                       'dictionary, kwargs must be empty.')
    current_scope = list_ops_or_scope.copy()
    try:
      _get_arg_stack().append(current_scope)
      yield current_scope
    finally:
      _get_arg_stack().pop()
  else:
    # Assumes that list_ops_or_scope is a list/tuple of ops with kwargs.
    if not isinstance(list_ops_or_scope, (list, tuple)):
      raise TypeError('list_ops_or_scope must either be a list/tuple or reused'
                      'scope (i.e. dict)')
    try:
      current_scope = current_arg_scope().copy()
      for op in list_ops_or_scope:
        key_op = _key_op(op)
        if not has_arg_scope(op):
          raise ValueError('%s is not decorated with @add_arg_scope',
                           _name_op(op))
        if key_op in current_scope:
          current_kwargs = current_scope[key_op].copy()
          current_kwargs.update(kwargs)
          current_scope[key_op] = current_kwargs
        else:
          current_scope[key_op] = kwargs.copy()
      _get_arg_stack().append(current_scope)
      yield current_scope
    finally:
      _get_arg_stack().pop()

如注释中所说,这个函数的作用是给list_ops中的内容设置默认值。但是每个list_ops中的每个成员需要用@add_arg_scope修饰才行。所以使用slim.arg_scope()有两个步骤:

  1. 使用@slim.add_arg_scope修饰目标函数
  2. 用 slim.arg_scope()为目标函数设置默认参数.

例如如下代码;首先用@slim.add_arg_scope修饰目标函数fun1(),然后利用slim.arg_scope()为它设置默认参数。

代码语言:javascript
复制
import tensorflow as tf
slim =tf.contrib.slim
     
@slim.add_arg_scope
def fun1(a=0,b=0):
    return (a+b)
     
with slim.arg_scope([fun1],a=10):
    x=fun1(b=30)
    print(x)

运行结果为:

40

平常所用到的slim.conv2d( ),slim.fully_connected( ),slim.max_pool2d( )等函数在他被定义的时候就已经添加了@add_arg_scope。以slim.conv2d( )为例;

代码语言:javascript
复制
    @add_arg_scope
    def convolution(inputs,
                    num_outputs,
                    kernel_size,
                    stride=1,
                    padding='SAME',
                    data_format=None,
                    rate=1,
                    activation_fn=nn.relu,
                    normalizer_fn=None,
                    normalizer_params=None,
                    weights_initializer=initializers.xavier_initializer(),
                    weights_regularizer=None,
                    biases_initializer=init_ops.zeros_initializer(),
                    biases_regularizer=None,
                    reuse=None,
                    variables_collections=None,
                    outputs_collections=None,
                    trainable=True,
                    scope=None):

所以,在使用过程中可以直接slim.conv2d( )等函数设置默认参数。例如在下面的代码中,不做单独声明的情况下,slim.conv2d, slim.max_pool2d, slim.avg_pool2d三个函数默认的步长都设为1,padding模式都是'VALID'的。但是也可以在调用时进行单独声明。这种参数设置方式在构建网络模型时,尤其是较深的网络时,可以节省时间。

代码语言:javascript
复制
     with slim.arg_scope(
                    [slim.conv2d, slim.max_pool2d, slim.avg_pool2d],stride = 1, padding = 'VALID'):
                net = slim.conv2d(inputs, 32, [3, 3], stride = 2, scope = 'Conv2d_1a_3x3')
                net = slim.conv2d(net, 32, [3, 3], scope = 'Conv2d_2a_3x3')
                net = slim.conv2d(net, 64, [3, 3], padding = 'SAME', scope = 'Conv2d_2b_3x3')

@修饰符

其实这种用法是python中常用到的。在python中@修饰符放在函数定义的上方,它将被修饰的函数作为参数,并返回修饰后的同名函数。形式如下;

代码语言:javascript
复制
    @fun_a     #等价于fun_a(fun_b)
    def fun_b():

这在本质上讲跟直接调用被修饰的函数没什么区别,但是有时候也有用处,例如在调用被修饰函数前需要输出时间信息,我们可以在@后方的函数中添加输出时间信息的语句,这样每次我们只需要调用@后方的函数即可。

代码语言:javascript
复制
    def funs(fun,factor=20):
        x=fun()
        print(factor*x)
        
        
    @funs     #等价funs(add(),fator=20)
    def add(a=10,b=20):
        return(a+b)

转载地址:https://blog.csdn.net/u013921430/article/details/80915696

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2019年06月24日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • slim.arg_scope()函数的使用
  • @修饰符
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档