ResNet && DenseNet (Implementation)

The previous post covered the theory behind ResNet and DenseNet; this one walks through a concrete implementation.

ResNet

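The snippets in both sections lean on a few small helpers the post never shows (weight_variable, bias_variable, avg_pool). Below is a minimal sketch of plausible TensorFlow 1.x definitions; the initializer choices are illustrative assumptions, not necessarily the original author's:

import tensorflow as tf

def weight_variable(shape):
  # Truncated-normal initialization; the stddev is an assumed value.
  return tf.Variable(tf.truncated_normal(shape, stddev=0.01))

def bias_variable(shape):
  return tf.Variable(tf.constant(0.01, shape=shape))

def avg_pool(input, s):
  # s x s average pooling with stride s.
  return tf.nn.avg_pool(input, [1, s, s, 1], [1, s, s, 1], 'VALID')
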
def basic_block(input, in_features, out_features, stride, is_training, keep_prob):
  """Residual block: two 3x3 convs plus a parameter-free shortcut"""
  if stride == 1:
    shortcut = input
  else:
    # Downsample the shortcut spatially with average pooling, then zero-pad
    # the channel dimension (the parameter-free "option A" shortcut).
    shortcut = tf.nn.avg_pool(input, [1, stride, stride, 1], [1, stride, stride, 1], 'VALID')
    shortcut = tf.pad(shortcut, [[0, 0], [0, 0], [0, 0],
      [(out_features - in_features) // 2, (out_features - in_features) // 2]])
  # conv2d is the shared helper defined in the DenseNet section below.
  current = conv2d(input, in_features, out_features, 3, stride)
  current = tf.nn.dropout(current, keep_prob)
  current = tf.contrib.layers.batch_norm(current, scale=True, is_training=is_training, updates_collections=None)
  current = tf.nn.relu(current)
  current = conv2d(current, out_features, out_features, 3, 1)
  current = tf.nn.dropout(current, keep_prob)
  current = tf.contrib.layers.batch_norm(current, scale=True, is_training=is_training, updates_collections=None)
  return current + shortcut

def block_stack(input, in_features, out_features, stride, depth, is_training, keep_prob):
  """Stack `depth` residual blocks; only the first may change stride/width"""
  current = basic_block(input, in_features, out_features, stride, is_training, keep_prob)
  for _ in range(depth - 1):
    current = basic_block(current, out_features, out_features, 1, is_training, keep_prob)
  return current
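
To show how these pieces compose, here is a minimal sketch of a CIFAR-style ResNet built from block_stack. The placeholder names (xs, is_training, keep_prob) and the choice n = 3 (a 6n + 2 = 20-layer network) are illustrative assumptions, not part of the original post:

# Illustrative sketch; assumes the helpers above and the conv2d defined below.
xs = tf.placeholder(tf.float32, [None, 32, 32, 3])
is_training = tf.placeholder(tf.bool)
keep_prob = tf.placeholder(tf.float32)
n = 3  # residual blocks per stage -> 6n + 2 = 20 layers

current = conv2d(xs, 3, 16, 3)                                         # 32x32x16
current = block_stack(current, 16, 16, 1, n, is_training, keep_prob)   # 32x32x16
current = block_stack(current, 16, 32, 2, n, is_training, keep_prob)   # 16x16x32
current = block_stack(current, 32, 64, 2, n, is_training, keep_prob)   # 8x8x64
current = tf.contrib.layers.batch_norm(current, scale=True, is_training=is_training, updates_collections=None)
current = tf.nn.relu(current)
current = avg_pool(current, 8)                                         # global average pooling
logits = tf.matmul(tf.reshape(current, [-1, 64]),
                   weight_variable([64, 10])) + bias_variable([10])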

DenseNet

def conv2d(input, in_features, out_features, kernel_size, stride=1, with_bias=False):
  """kernel_size x kernel_size convolution; stride defaults to 1, and the
  ResNet basic_block above passes an explicit stride for downsampling."""
  W = weight_variable([kernel_size, kernel_size, in_features, out_features])
  conv = tf.nn.conv2d(input, W, [1, stride, stride, 1], padding='SAME')
  if with_bias:
    return conv + bias_variable([out_features])
  return conv

def batch_activ_conv(current, in_features, out_features, kernel_size, is_training, keep_prob):
  """Pre-activation composite: BatchNorm + ReLU + conv + dropout"""
  current = tf.contrib.layers.batch_norm(current, scale=True, is_training=is_training, updates_collections=None)
  current = tf.nn.relu(current)
  current = conv2d(current, in_features, out_features, kernel_size)
  current = tf.nn.dropout(current, keep_prob)
  return current

def block(input, layers, in_features, growth, is_training, keep_prob):
  """Dense block: each layer adds `growth` feature maps, concatenated onto its input"""
  current = input
  features = in_features
  for _ in range(layers):
    tmp = batch_activ_conv(current, features, growth, 3, is_training, keep_prob)
    current = tf.concat([current, tmp], axis=3)  # channel-wise concatenation (TF 1.x argument order)
    features += growth
  return current, features
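
For example, with in_features = 16, growth = 12 and layers = 12 (the values used in model() below), the block returns 16 + 12 × 12 = 160 feature maps, and that 160 is the `features` value handed back to the caller for sizing the transition layer.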

def model():
  """DenseNet on CIFAR (the 32x32x3 input below is CIFAR-sized, not ImageNet)"""
  current = tf.reshape(xs, [-1, 32, 32, 3])  # input
  current = conv2d(current, 3, 16, 3)        # initial 3x3 conv -> 16 feature maps

  # Three dense blocks; the first two are each followed by a transition layer
  # (1x1 conv keeping the channel count, then 2x2 average pooling).
  current, features = block(current, layers, 16, 12, is_training, keep_prob)
  current = batch_activ_conv(current, features, features, 1, is_training, keep_prob)
  current = avg_pool(current, 2)
  current, features = block(current, layers, features, 12, is_training, keep_prob)
  current = batch_activ_conv(current, features, features, 1, is_training, keep_prob)
  current = avg_pool(current, 2)
  current, features = block(current, layers, features, 12, is_training, keep_prob)

  # Final BatchNorm + ReLU, global average pooling, fully connected classifier.
  current = tf.contrib.layers.batch_norm(current, scale=True, is_training=is_training, updates_collections=None)
  current = tf.nn.relu(current)
  current = avg_pool(current, 8)
  final_dim = features
  current = tf.reshape(current, [-1, final_dim])
  Wfc = weight_variable([final_dim, label_count])
  bfc = bias_variable([label_count])
  ys_ = tf.nn.softmax(tf.matmul(current, Wfc) + bfc)
  return ys_
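
For completeness, a hedged sketch of how the remaining globals and a training op could be attached to model(). The placeholder names, layers = 12, and the momentum optimizer are illustrative assumptions, not choices from the original post:

label_count = 10   # CIFAR-10
layers = 12        # layers per dense block (illustrative)
xs = tf.placeholder(tf.float32, [None, 32 * 32 * 3])
ys = tf.placeholder(tf.float32, [None, label_count])  # one-hot labels
is_training = tf.placeholder(tf.bool)
keep_prob = tf.placeholder(tf.float32)

ys_ = model()
# Cross-entropy on the softmax output; the epsilon guards against log(0).
cross_entropy = -tf.reduce_mean(tf.reduce_sum(ys * tf.log(ys_ + 1e-12), axis=1))
train_step = tf.train.MomentumOptimizer(0.1, 0.9, use_nesterov=True).minimize(cross_entropy)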

The code above is not complete; it is only meant to convey the core of the idea in its most naive form.
