Tensorflow常见模型及工程化方法

Tensorflow在深度学习模型研究中起到了很大的促进作用,灵活的框架免去了研究人员、开发者大量的自动求导代码工作。本文总结一下常用的模型代码和工程化需要的代码。有需求的同学收藏一下,以便日后查阅。

Tensorflow常见模型

A. LSTM模型结构

import tensorflow as tf

import tensorflow.contrib as contrib

from tensorflow.python.ops import array_ops

class lstm(object):

def __init__(self, in_data, hidden_dim, batch_seqlen=None, flag='concat'):

self.in_data = in_data

self.hidden_dim = hidden_dim

self.batch_seqlen = batch_seqlen

self.flag = flag

lstm_cell = contrib.rnn.LSTMCell(self.hidden_dim)

out, _ = tf.nn.dynamic_rnn(cell=lstm_cell, inputs=self.in_data, sequence_length=self.batch_seqlen,dtype=tf.float32)

if flag=='all_ht':

self.out = out

if flag = 'first_ht':

self.out = out[:,0,:]

if flag = 'last_ht':

self.out = out[:,-1,:]

if flag = 'concat':

self.out = tf.concat([out[:,0,:], out[:,-1,:]],1)

B. Bi-LSTM模型结构

import tensorflow as tf

import tensorflow.contrib as contrib

from tensorflow.python.ops import array_ops

from tensorflow.python.framework import dtypes

class bilstm(object):

def __init__(self, in_data, hidden_dim, batch_seqlen=None, flag='concat'):

self.in_data = in_data

self.hidden_dim = hidden_dim

self.batch_seqlen = batch_seqlen

self.flag = flag

lstm_cell_fw = contrib.rnn.LSTMCell(self.hidden_dim)

lstm_cell_bw = contrib.rnn.LSTMCell(self.hidden_dim)

out, state = tf.nn.bidirectional_dynamic_rnn(cell_fw=lstm_cell_fw,cell_bw=lstm_cell_bw,inputs=self.in_data, sequence_lenth=self.batch_seqlen,dtype=tf.float32)

bi_out = tf.concat(out, 2)

if flag=='all_ht':

self.out = bi_out

if flag=='first_ht':

self.out = bi_out[:,0,:]

if flag=='last_ht':

self.out = tf.concat([state[0].h,state[1].h], 1)

if flag=='concat':

self.out = tf.concat([bi_out[:,0,:],tf.concat([state[0].h,state[1].h], 1)],1)

C multi-channel CNN

import tensorflow as tf

import tensorflow.contrib as contrib

from tensorflow.python.ops import array_ops

class lstm(object):

def __init__(self, in_data, hidden_dim, batch_seqlen=None, flag='concat'):

self.in_data = in_data

self.hidden_dim = hidden_dim

self.batch_seqlen = batch_seqlen

self.flag = flag

lstm_cell = contrib.rnn.LSTMCell(self.hidden_dim)

out, _ = tf.nn.dynamic_rnn(cell=lstm_cell, inputs=self.in_data, sequence_length=self.batch_seqlen,dtype=tf.float32)

if flag=='all_ht':

self.out = out

if flag = 'first_ht':

self.out = out[:,0,:]

if flag = 'last_ht':

self.out = out[:,-1,:]

if flag = 'concat':

self.out = tf.concat([out[:,0,:], out[:,-1,:]],1)

D depth-wise cnn

import tensorflow as tf

def depth_wise_conv(in_data, scope, kernel_size, dim):

with tf.variable_scope(scope):

shapes = in_data.shape.as_list()

depthwise_filter = tf.get_varibale("depthwise_conv.weight",

(kernel_size[0], kernel_size[1], shapes[-1]

dtype=tf.float32, )

pointwise_filter = tf.get_variable("pointwise_conv.weight",

(1,1, shapes[-1], dim),

dtype=tf.float32, )

outputs = tf.nn.separable_conv2d(in_data,

depthwise_filter,

pointwise_filter,

strides=(1,1,1,1),

padding="SAME"

)

return outputs

D multi-layer depth-wise cnn

def multi_convs(input_x, dim, conv_number=2, k=5):

# input_x: 输入数据,为batch * seq * dim

# dim:对应的输入的维度

# conv_number: 对应的卷积的层数,一般2,

# k对应的是卷积核的窗口大小

res = input_x

for index in range(conv_number):

out = norm(res) # layer norm

out = tf.expand_dims(out, 2) # bach * seq * 1 * dim

out = depth_wise_conv(out, kernel_size=(k, 1), dim=dim, scope="convs.%d" % index)

out = tf.squeeze(out, 2) # batch * seq * dim

out = tf.nn.relu(out)

out = out + res

res = out

out = norm(out) # 输出为 batch * seq * 1 * dim

out = tf.squeeze(out, squeeze_dims=2) # 输出为 batch * seq * dim

return out

模型参数查看

已知模型文件的ckpt文件,通过pywrap_tensorflow获取模型的各参数名。

import tensoflow as tf

from tensorflow.python import pywrap_tensorflow

model_dir = "./ckpt/"

ckpt = tf.train.get_checkpoint_state(model_dir)

ckpt_path = ckpt.model_checkpoint_path

reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)

param_dict = reader.get_variable_to_shape_map()

for key, val in param_dict.items():

try:

print key, val

except:

pass

工程化方法

A. tennsorflow模型文件打包成PB文件

import tensorflow as tf

from tensorflow.python.tools import freeze_graph

with tf.Graph().as_default():

with tf.device("/cpu:0"):

config = tf.ConfigProto(allow_soft_placement=True)

with tf.Session(config=config).as_default() as sess:

model = Your_Model_Name()

model.build_graph()

sess.run(tf.initialize_all_variables())

saver = tf.train.Saver()

ckpt_path = "/your/model/path"

saver.restore(sess, ckpt_path)

graphdef = tf.get_default_graph().as_graph_def()

tf.train.write_graph(sess.graph_def,"/your/save/path/","save_name.pb",as_text=False)

frozen_graph = tf.graph_util.convert_variables_to_constants(sess,graphdef,['output/node/name'])

frozen_graph_trim = tf.graph_util.remove_training_nodes(frozen_graph)

freeze_graph.freeze_graph('/your/save/path/save_name.pb','',True, ckpt_path,'output/node/name','save/restore_all','save/Const:0','frozen_name.pb',True,"")

B.PB文件读取使用

output_graph_def = tf.GraphDef()

with open("your_name.pb","rb") as f:

output_graph_def.ParseFromString(f.read())

_ = tf.import_graph_def(output_graph_def, name="")

node_in = sess.graph.get_tensor_by_name("input_node_name")

model_out = sess.graph.get_tensor_by_name("out_node_name")

feed_dict = {node_in:in_data}

pred = sess.run(model_out, feed_dict)

注:本文代码均为笔者手敲留存,如代码有误可以咨询探讨。

原文发布于微信公众号 - CodeInHand(CodeInHand)

原文发表时间:2018-11-02

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏人工智能头条

详解TensorBoard如何调参

993
来自专栏人工智能

TensorFlow实战——图像分类神经网络模型

Learn how to classify images with TensorFlow 使用TensorFlow创建一个简单而强大的图像分类神经网络模型 by...

4096
来自专栏Python中文社区

基于RNN自动生成古诗

專 欄 ❈ 作者:yonggege,Python中文社区专栏作者 GitHub地址:https://github.com/wzyonggege ❈ 0. ch...

2765
来自专栏机器学习算法与Python学习

Pytorch | BERT模型实现,提供转换脚本【横扫NLP】

《谷歌终于开源BERT代码:3 亿参数量,机器之心全面解读》,上周推送的这篇文章,全面解读基于TensorFlow实现的BERT代码。现在,PyTorch用户的...

1771
来自专栏ATYUN订阅号

Deep Photo Styletransfer的一种纯Tensorflow实现,教你如何转换图片风格

通过深度学习,一秒钟让你的照片高大上,这是康奈尔大学和 Adobe 的工程师合作的一个新项目,通过卷积神经网络把图片进行风格迁移。项目已开源,名字叫「Deep ...

5825
来自专栏ATYUN订阅号

自相关与偏自相关的简单介绍

自相关和偏自相关图在时间序列分析和预测中经常使用。这些图生动的总结了一个时间序列的观察值与他之前的时间步的观察值之间的关系强度。初学者要理解时间序列预测中自相关...

6234
来自专栏iOSDevLog

Core ML 2有什么新功能

Core ML是Apple的机器学习框架。仅在一年前发布,Core ML为开发人员提供了一种方法,只需几行代码即可将强大的智能机器学习功能集成到他们的应用程序中...

902
来自专栏Pytorch实践

Tensorflow常见模型及工程化方法

Tensorflow在深度学习模型研究中起到了很大的促进作用,灵活的框架免去了研究人员、开发者大量的自动求导代码工作。本文总结一下常用的模型代码和工程化需要的代...

3666
来自专栏王嘉的专栏

安全 AI 的智能对抗系统之架构实现篇

在AI的浪潮下,在现有的安全系统的基础上,SNG业务安全中心将机器学习应用到业务安全对抗中,自研建设并搭建了 – 安全AI的智能对抗系统。智能对抗系统现已应用在...

1.1K0
来自专栏机器之心

资源 | 微软发布可变形卷积网络代码:可用于多种复杂视觉任务

选自Github 机器之心编译 编辑:吴攀 上个月,微软代季峰等研究者发布的一篇论文提出了一种可变形卷积网络,该研究「引入了两种新的模块来提高卷积神经网络(CN...

3606

扫码关注云+社区

领取腾讯云代金券