ResNet && DenseNet (Implementation)

The previous post covered the theory behind ResNet and DenseNet; this one walks through the concrete implementation.

ResNet

import tensorflow as tf

def basic_block(input, in_features, out_features, stride, is_training, keep_prob):
  """Residual block: two 3x3 convolutions plus a (possibly downsampled) shortcut."""
  if stride == 1:
    shortcut = input
  else:
    # Downsample the shortcut spatially, then zero-pad the extra channels
    shortcut = tf.nn.avg_pool(input, [ 1, stride, stride, 1 ], [ 1, stride, stride, 1 ], 'VALID')
    shortcut = tf.pad(shortcut, [[0, 0], [0, 0], [0, 0],
      [(out_features - in_features) // 2, (out_features - in_features) // 2]])
  # Note: this assumes a conv2d that accepts a stride argument
  # (the conv2d defined in the DenseNet section below is stride-1 only).
  current = conv2d(input, in_features, out_features, 3, stride)
  current = tf.nn.dropout(current, keep_prob)
  current = tf.contrib.layers.batch_norm(current, scale=True, is_training=is_training, updates_collections=None)
  current = tf.nn.relu(current)
  current = conv2d(current, out_features, out_features, 3, 1)
  current = tf.nn.dropout(current, keep_prob)
  current = tf.contrib.layers.batch_norm(current, scale=True, is_training=is_training, updates_collections=None)
  return current + shortcut

def block_stack(input, in_features, out_features, stride, depth, is_training, keep_prob):
  """Stack `depth` residual blocks; only the first may change stride and width."""
  current = basic_block(input, in_features, out_features, stride, is_training, keep_prob)
  for _d in range(depth - 1):
    current = basic_block(current, out_features, out_features, 1, is_training, keep_prob)
  return current
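The snippet above uses conv2d and (indirectly) weight_variable without defining them, and the stride argument means this conv2d cannot be the one from the DenseNet section below. A minimal sketch of what these helpers might look like; the initializer and the default stride are assumptions, not taken from the original post:

def weight_variable(shape):
  # Assumed initializer; the original post does not show this helper
  return tf.Variable(tf.truncated_normal(shape, stddev=0.01))

def conv2d(input, in_features, out_features, kernel_size, stride=1):
  # Strided 'SAME' convolution as assumed by basic_block above; in a single
  # file this would need to be reconciled with the DenseNet conv2d below
  W = weight_variable([ kernel_size, kernel_size, in_features, out_features ])
  return tf.nn.conv2d(input, W, [ 1, stride, stride, 1 ], padding='SAME')

With these helpers, one stage of a CIFAR-style ResNet could be expressed as, for example, block_stack(x, 16, 32, 2, 5, is_training, keep_prob), which halves the spatial resolution while doubling the channel count.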

DenseNet

def conv2d(input, in_features, out_features, kernel_size, with_bias=False):
  """Stride-1 'SAME' convolution used throughout the DenseNet part."""
  W = weight_variable([ kernel_size, kernel_size, in_features, out_features ])
  conv = tf.nn.conv2d(input, W, [ 1, 1, 1, 1 ], padding='SAME')
  if with_bias:
    return conv + bias_variable([ out_features ])
  return conv

def batch_activ_conv(current, in_features, out_features, kernel_size, is_training, keep_prob):
    """BatchNorm+Relu+conv+dropout"""
  current = tf.contrib.layers.batch_norm(current, scale=True, is_training=is_training, updates_collections=None)
  current = tf.nn.relu(current)
  current = conv2d(current, in_features, out_features, kernel_size)
  current = tf.nn.dropout(current, keep_prob)
  return current

def block(input, layers, in_features, growth, is_training, keep_prob):
  """Dense block: each layer's output is concatenated onto its input."""
  current = input
  features = in_features
  for idx in range(layers):
    tmp = batch_activ_conv(current, features, growth, 3, is_training, keep_prob)
    # Concatenate along the channel axis (TF 1.x argument order: values, axis)
    current = tf.concat((current, tmp), 3)
    features += growth
  return current, features
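As a sanity check on the bookkeeping: with 16 input channels and growth rate 12, a dense block with layers = 12 ends with 16 + 12 * 12 = 160 channels. A throwaway graph-mode check, assuming TF 1.x; the placeholder and sizes here are purely illustrative:

x_demo = tf.placeholder(tf.float32, [ None, 32, 32, 16 ])
out, n_features = block(x_demo, layers=12, in_features=16, growth=12,
                        is_training=tf.constant(False), keep_prob=1.0)
# out has static shape [None, 32, 32, 160] and n_features == 160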

def model():
  """DenseNet for 32x32 inputs (CIFAR-style): three dense blocks with transitions."""
  current = tf.reshape(xs, [ -1, 32, 32, 3 ])  # Input
  current = conv2d(current, 3, 16, 3)

  # Dense block 1 + transition (1x1 conv followed by 2x2 average pooling)
  current, features = block(current, layers, 16, 12, is_training, keep_prob)
  current = batch_activ_conv(current, features, features, 1, is_training, keep_prob)
  current = avg_pool(current, 2)
  # Dense block 2 + transition
  current, features = block(current, layers, features, 12, is_training, keep_prob)
  current = batch_activ_conv(current, features, features, 1, is_training, keep_prob)
  current = avg_pool(current, 2)
  # Dense block 3
  current, features = block(current, layers, features, 12, is_training, keep_prob)

  # BN + ReLU, global 8x8 average pooling, then the fully connected classifier
  current = tf.contrib.layers.batch_norm(current, scale=True, is_training=is_training, updates_collections=None)
  current = tf.nn.relu(current)
  current = avg_pool(current, 8)
  final_dim = features
  current = tf.reshape(current, [ -1, final_dim ])
  Wfc = weight_variable([ final_dim, label_count ])
  bfc = bias_variable([ label_count ])
  ys_ = tf.nn.softmax( tf.matmul(current, Wfc) + bfc )
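model() also leans on a few pieces the post never shows: the xs / is_training / keep_prob placeholders, label_count, layers, bias_variable, and avg_pool. One plausible set of definitions; all names and values here are assumptions for illustration (e.g. a CIFAR-10 style label_count of 10):

label_count = 10                                  # assumed: CIFAR-10 classes
layers = 12                                       # assumed: layers per dense block
xs = tf.placeholder(tf.float32, [ None, 32 * 32 * 3 ])
is_training = tf.placeholder(tf.bool)
keep_prob = tf.placeholder(tf.float32)

def bias_variable(shape):
  # Assumed small constant bias initializer
  return tf.Variable(tf.constant(0.01, shape=shape))

def avg_pool(input, s):
  # Non-overlapping s x s average pooling
  return tf.nn.avg_pool(input, [ 1, s, s, 1 ], [ 1, s, s, 1 ], 'VALID')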

The code above is not complete; it is only meant to convey the core idea in its most naive form.
