tensorflow机器学习模型的跨平台上线

    在用PMML实现机器学习模型的跨平台上线中,我们讨论了使用PMML文件来实现跨平台模型上线的方法,这个方法当然也适用于tensorflow生成的模型,但是由于tensorflow模型往往较大,使用无法优化的PMML文件大多数时候很笨拙,因此本文我们专门讨论下tensorflow机器学习模型的跨平台上线的方法。

1. tensorflow模型的跨平台上线的备选方案

    tensorflow模型的跨平台上线的备选方案一般有三种:即PMML方式,tensorflow serving方式,以及跨语言API方式。

    PMML方式的主要思路在上一篇以及讲过。这里唯一的区别是转化生成PMML文件需要用一个Java库jpmml-tensorflow来完成,生成PMML文件后,跨语言加载模型和其他PMML模型文件基本类似。

    tensorflow serving是tensorflow 官方推荐的模型上线预测方式,它需要一个专门的tensorflow服务器,用来提供预测的API服务。如果你的模型和对应的应用是比较大规模的,那么使用tensorflow serving是比较好的使用方式。但是它也有一个缺点,就是比较笨重,如果你要使用tensorflow serving,那么需要自己搭建serving集群并维护这个集群。所以为了一个小的应用去做这个工作,有时候会觉得麻烦。

    跨语言API方式是本文要讨论的方式,它会用tensorflow自己的Python API生成模型文件,然后用tensorflow的客户端库比如Java或C++库来做模型的在线预测。下面我们会给一个生成生成模型文件并用tensorflow Java API来做在线预测的例子。

2. 训练模型并生成模型文件

    我们这里给一个简单的逻辑回归并生成逻辑回归tensorflow模型文件的例子。

    首先,我们生成了一个6特征,3分类输出的4000个样本数据。

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.datasets.samples_generator import make_classification
import tensorflow as tf
X1, y1 = make_classification(n_samples=4000, n_features=6, n_redundant=0,
                             n_clusters_per_class=1, n_classes=3)

    接着我们构建tensorflow的数据流图,这里要注意里面的两个名字,第一个是输入x的名字input,第二个是输出prediction_labels的名字output,这里的这两个名字可以自己取,但是后面会用到,所以要保持一致。

learning_rate = 0.01
training_epochs = 600
batch_size = 100

x = tf.placeholder(tf.float32, [None, 6],name='input') # 6 features
y = tf.placeholder(tf.float32, [None, 3]) # 3 classes

W = tf.Variable(tf.zeros([6, 3]))
b = tf.Variable(tf.zeros([3]))

# softmax回归
pred = tf.nn.softmax(tf.matmul(x, W) + b, name="softmax") 
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

prediction_labels = tf.argmax(pred, axis=1, name="output")

init = tf.global_variables_initializer()

    接着就是训练模型了,代码比较简单,毕竟只是一个演示:

sess = tf.Session()
sess.run(init)
y2 = tf.one_hot(y1, 3)
y2 = sess.run(y2)

for epoch in range(training_epochs):

    _, c = sess.run([optimizer, cost], feed_dict={x: X1, y: y2})
    if (epoch+1) % 10 == 0:
        print ("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(c))
    
print ("优化完毕!")
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y2, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
acc = sess.run(accuracy, feed_dict={x: X1, y: y2})
print (acc)

    打印输出我这里就不写了,大家可以自己去试一试。接着就是关键的一步,存模型文件了,注意要用convert_variables_to_constants这个API来保存模型,否则模型参数不会随着模型图一起存下来。

graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"])
tf.train.write_graph(graph, '.', 'rf.pb', as_text=False)

    至此,我们的模型文件rf.pb已经被保存下来了,下面就是要跨平台上线了。 

3. 模型文件在Java平台上线

    这里我们以Java平台的模型上线为例,C++的API上线我没有用过,这里就不写了。我们需要引入tensorflow的java库到我们工程的maven或者gradle文件。这里给出maven的依赖如下,版本可以根据实际情况选择一个较新的版本。

        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow</artifactId>
            <version>1.7.0</version>
        </dependency>

    接着就是代码了,这个代码会比JPMML的要简单,我给出了4个测试样本的预测例子如下,一定要注意的是里面的input和output要和训练模型的时候对应的节点名字一致。

import org.tensorflow.*;
import org.tensorflow.Graph;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;


/**
 * Created by 刘建平pinard on 2018/7/1.
 */
public class TFjavaDemo {
    public static void main(String args[]){
        byte[] graphDef = loadTensorflowModel("D:/rf.pb");
        float inputs[][] = new float[4][6];
        for(int i = 0; i< 4; i++){
            for(int j =0; j< 6;j++){
                if(i<2) {
                    inputs[i][j] = 2 * i - 5 * j - 6;
                }
                else{
                    inputs[i][j] = 2 * i + 5 * j - 6;
                }
            }
        }
        Tensor<Float> input = covertArrayToTensor(inputs);
        Graph g = new Graph();
        g.importGraphDef(graphDef);
        Session s = new Session(g);
        Tensor result = s.runner().feed("input", input).fetch("output").run().get(0);

        long[] rshape = result.shape();
        int rs = (int) rshape[0];
        long realResult[] = new long[rs];
        result.copyTo(realResult);

        for(long a: realResult ) {
            System.out.println(a);
        }
    }
    static private byte[] loadTensorflowModel(String path){
        try {
            return Files.readAllBytes(Paths.get(path));
        } catch (IOException e) {
            e.printStackTrace();
        }
        return null;
    }

    static private Tensor<Float> covertArrayToTensor(float inputs[][]){
        return Tensors.create(inputs);
    }
}

    我的预测输出是1,1,0,0,供大家参考。

4. 一点小结

    对于tensorflow来说,模型上线一般选择tensorflow serving或者client API库来上线,前者适合于较大的模型和应用场景,后者则适合中小型的模型和应用场景。因此算法工程师使用在产品之前需要做好选择和评估。

(欢迎转载,转载请注明出处。欢迎沟通交流: liujianping-ok@163.com) 

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏AI科技评论

实战 | BERT fine-tune 终极实践教程

AI科技评论按:从 11 月初开始,google-research 就陆续开源了 BERT 的各个版本。google 此次开源的 BERT 是通过 tensor...

73950
来自专栏SIGAI学习与实践平台

【免费线上实践】动手训练模型系列:条件GAN

从无序的输出到按照类别输出,Conditional Generative Neural Networks到底借助了什么样的魔(xin)法(xi)?点击下方小程序...

14850
来自专栏新智元

10 亿图片仅需 17.7微秒:Facebook AI 实验室开源图像搜索工具Faiss

【新智元导读】Facebook的 FAIR 最新开源了一个用于有效的相似性搜索和稠密矢量聚类的库,名为 Faiss,在10亿图像数据集上的一次查询仅需17.7 ...

46650
来自专栏Script Boy (CN-SIMO)

魔法少女【动态规划问题】——NYOJ1204

13220
来自专栏CSDN技术头条

大数据并行计算利器之MPI/OpenMP

1 背景 图像连通域标记算法是从一幅栅格图像(通常为二值图像)中,将互相邻接(4邻接或8邻接)的具有非背景值的像素集合提取出来,为不同的连通域填入数字标记,并且...

29160
来自专栏PaddlePaddle

PaddlePaddle发布新版API,简化深度学习编程

PaddlePaddle是百度于2016年9月开源的一款分布式深度学习平台,为百度内部多项产品提供深度学习算法支持。为了使PaddlePaddle更加易用,我们...

36480
来自专栏AI研习社

Github 项目推荐 | 用 JavaScript 实现的神经网络 —— brain.js

不过,一般的开发者应该都不会用神经网络来实现异或的功能吧,所以这里有一个更加实际的例子:训练一个神经网络来识别颜色对比 https://brain.js.org...

19520
来自专栏机器学习实践二三事

Caffe中均值文件的问题

关于均值文件 (1) 在Caffe中作classification时经常需要使用均值文件,但是caffe自己提供的脚本只能将图像数据转换为 binarypr...

22990
来自专栏新智元

【干货】神经增强:用 Python 实现深度学习超分辨率处理

【新智元导读】神经网络基于样本图像的训练为模糊图像补充细节,从而把模糊图像变高清。它不能把你的照片重建成一模一样的高清版。这只有好莱坞大片才有可能做到——但使用...

75150
来自专栏北京马哥教育

搭建python机器学习环境以及一个机器学习例子

作者 | hzyido 来源 | 简书 糖豆贴心提醒,本文阅读时间6分钟,文末有秘密! 这篇文章介绍了Python机器学习环境的搭建,我用的机器学习开...

619120

扫码关注云+社区

领取腾讯云代金券