tensorflow版的bvlc模型

     研究相关的图片分类,偶然看到bvlc模型,但是没有tensorflow版本的,所以将caffe版本的改成了tensorflow的:

关于模型这个图:

下面贴出通用模板:

  1 from __future__ import print_function
  2 import tensorflow as tf
  3 import numpy as np
  4 from scipy.misc import imread, imresize
  5 
  6 
  7 class BVLG:
  8     def __init__(self, imgs, weights=None, sess=None):
  9         self.imgs = imgs
 10         self.convlayers()
 11         self.fc_layers()
 12 
 13         self.probs = tf.nn.softmax(self.fc3l)
 14         if weights is not None and sess is not None:
 15             self.load_weights(weights,sess)
 16 
 17     def convlayers(self):
 18         self.parameters = []
 19 
 20         # zero-mean input
 21         with tf.name_scope('preprocess') as scope:
 22             mean = tf.constant([123.68, 116.779, 103.939], dtype=tf.float32, shape=[1, 1, 1, 3], name='img_mean')
 23             images = self.imgs - mean
 24 
 25         # conv1
 26         with tf.name_scope('conv1') as scope:
 27             kernel = tf.Variable(tf.truncated_normal([7, 7, 3, 96], dtype=tf.float32,
 28                                                      stddev=1e-1), name='weights')
 29             conv = tf.nn.conv2d(images, kernel, [3, 3, 1, 1], padding='SAME')
 30             biases = tf.Variable(tf.constant(0.0, shape=[96], dtype=tf.float32),
 31                                  trainable=True, name='biases')
 32             out = tf.nn.bias_add(conv, biases)
 33             self.conv1 = tf.nn.relu(out, name=scope)
 34             self.parameters += [kernel, biases]
 35 
 36         # pool1
 37         self.pool1 = tf.nn.max_pool(self.conv1,
 38                                     ksize=[1, 3, 3, 1],
 39                                     strides=[1, 2, 2, 1],
 40                                     padding='SAME',
 41                                     name='pool1')
 42 
 43         # conv2
 44         with tf.name_scope('conv2') as scope:
 45             kernel = tf.Variable(tf.truncated_normal([4, 4, 96, 256], dtype=tf.float32,
 46                                                      stddev=1e-1), name='weights')
 47             conv = tf.nn.conv2d(self.pool1, kernel, [1, 1, 1, 1], padding='SAME')
 48             biases = tf.Variable(tf.constant(0.0, shape=[256], dtype=tf.float32),
 49                                  trainable=True, name='biases')
 50             out = tf.nn.bias_add(conv, biases)
 51             self.conv2_1 = tf.nn.relu(out, name=scope)
 52             self.parameters += [kernel, biases]
 53 
 54 
 55         # pool2
 56         self.pool2 = tf.nn.max_pool(self.conv2,
 57                                     ksize=[1, 3, 3, 1],
 58                                     strides=[1, 2, 2, 1],
 59                                     padding='SAME',
 60                                     name='pool2')
 61 
 62         # conv5
 63         with tf.name_scope('conv5') as scope:
 64             kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 256], dtype=tf.float32,
 65                                                      stddev=1e-1), name='weights')
 66             conv = tf.nn.conv2d(self.pool2, kernel, [1, 1, 1, 1], padding='SAME')
 67             biases = tf.Variable(tf.constant(0.0, shape=[256], dtype=tf.float32),
 68                                  trainable=True, name='biases')
 69             out = tf.nn.bias_add(conv, biases)
 70             self.conv5 = tf.nn.relu(out, name=scope)
 71             self.parameters += [kernel, biases]
 72 
 73         # pool5
 74         self.pool5 = tf.nn.max_pool(self.conv5,
 75                                     ksize=[1, 2, 2, 1],
 76                                     strides=[1, 2, 2, 1],
 77                                     padding='SAME',
 78                                     name='pool4')
 79 
 80     def fc_layers(self):
 81         # fc1
 82         with tf.name_scope('fc1') as scope:
 83             shape = int(np.prod(self.pool5.get_shape()[1:]))
 84             fc1w = tf.Variable(tf.truncated_normal([shape, 4096],
 85                                                    dtype=tf.float32,
 86                                                    stddev=1e-1), name='weights')
 87             fc1b = tf.Variable(tf.constant(1.0, shape=[4096], dtype=tf.float32),
 88                                trainable=True, name='biases')
 89             pool5_flat = tf.reshape(self.pool5, [-1, shape])
 90             fc1l = tf.nn.bias_add(tf.matmul(pool5_flat, fc1w), fc1b)
 91             self.fc1 = tf.nn.relu(fc1l)
 92             self.parameters += [fc1w, fc1b]
 93 
 94         # fc3
 95         with tf.name_scope('fc3') as scope:
 96             fc3w = tf.Variable(tf.truncated_normal([4096, 587],
 97                                                    dtype=tf.float32,
 98                                                    stddev=1e-1), name='weights')
 99             fc3b = tf.Variable(tf.constant(1.0, shape=[587], dtype=tf.float32),
100                                trainable=True, name='biases')
101             self.fc3l = tf.nn.bias_add(tf.matmul(self.fc2, fc3w), fc3b)
102             self.parameters += [fc3w, fc3b]

caffe版本的ImageNet模型地址: https://github.com/BVLC/caffe/tree/master/models/bvlc_reference_caffenet

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏运维

求两数的平均值

某文件中,有如下多行数据 ,需要统计含关键字:real 对应行的数值(第二列),并最后得出总平均值 请给出相关命令 或 实现思路? 样本数据如下: Real  ...

14410
来自专栏用户2442861的专栏

神经网络python实现

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/haluoluo211/article/d...

24930
来自专栏北京马哥教育

实战Google深度学习框架:TensorFlow计算加速

作者:才云科技Caicloud,郑泽宇,顾思宇 要将深度学习应用到实际问题中,一个非常大的问题在于训练深度学习模型需要的计算量太大。比如Inception-v3...

34170
来自专栏Python小屋

Python+sklearn使用支持向量机算法实现数字图片分类

关于支持向量机的理论知识,大家可以查阅机器学习之类的书籍或网上资源,本文主要介绍如何使用Python扩展库sklearn中的支持向量机实现数字图片分类。 1、首...

36150
来自专栏YoungGy

ML基石_12_NonLinearTransformation

retro quadratic hypothesis nonlinear transform price on nonlinear transform stru...

20580
来自专栏ml

mxnet运行时遇到问题及解决方法

1.训练好模型之后,进行预测时出现这种错误: 1 mxnet.base.MXNetError: [15:05:50] src/ndarray/ndarray.c...

68240
来自专栏计算机视觉与深度学习基础

【深度学习】使用tensorflow实现VGG19网络

接上一篇AlexNet,本文讲述使用tensorflow实现VGG19网络。 VGG网络与AlexNet类似,也是一种CNN,VGG在2014年的 ILSV...

58840
来自专栏Petrichor的专栏

深度学习: 卷积核 为什么都是 奇数size

44810
来自专栏marsggbo

Python数据增强(data augmentation)库--Augmentor 使用介绍

Augmentor 使用介绍 原图 ? 1.random_distortion(probability, grid_height, grid_width, ma...

51180
来自专栏WOLFRAM

三维图形中指定绘图的区域,想知道这个区域上最大值是多少?

16640

扫码关注云+社区

领取腾讯云代金券