专栏首页专知【干货】使用TensorFlow官方Java API调用TensorFlow模型(附代码)

【干货】使用TensorFlow官方Java API调用TensorFlow模型(附代码)

【导读】随着TensorFlow的普及,越来越多的行业希望将Github中大量已有的TensorFlow代码和模型集成到自己的业务系统中,如何在常见的编程语言(Java、NodeJS等)中使用TensorFlow成为了一个比较常见的问题。专知成员Hujun给大家详细介绍了在Java中使用TensorFlow的两种方法,并着重介绍如何用TensorFlow官方Java API调用已有TensorFlow模型的方法。

专知成员Hujun在以前就写过TensorFlow 1.4 Eager Execution系列教程,欢迎查看。

完整代码可以参见专知Github链接:

https://github.com/ZhuanZhiCode

1.Java调用TensorFlow的两种方法



使用Java调用TensorFlow大致有两种方法:

  • 直接使用TensorFlow官方API调用训练好的pb模型: https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/package-summary
  • (推荐) 使用KerasServer托管TensorFlow/Keras代码及模型: https://github.com/CrawlScript/KerasServer

虽然使用TensorFlow官方Java API可以直接对接训练好的pb模型,但在实际使用中,依然存在着与跨语种对接相关的繁琐代码。例如虽然已有使用Python编写好的基于TensorFlow的文本分类代码,但TensorFlow Java API的输入需要是量化的文本,这样我们又需要用Java重新实现在Python代码中已经实现的分词、从字符串到索引的转换等预处理操作(这些操作同时依赖于Python代码依赖的单词表等数据)。另外,由于Java没有numpy支持,在构建多维数组作为输入时,使用的依然是类似循环的操作,非常繁琐。

KerasServer支持restful交互,因此可以支持用任何程序语言调用TensorFlow/ Keras。由于KerasServer的服务端提供Python API, 因此可以直接将已有的TensorFlow/Keras Python代码和模型转换为KerasServer API,供Java/c/c++/C#/ Python/ NodeJS/Browser Javascript等调用,而不需要再其他语种中进行繁琐的数据预处理操作。

例如,Java可直接将需要分类的文本数据提交给KerasServer,KerasServer可利用已有的Python代码对字符串进行分词、预处理等操作。

本教程介绍如何用TensorFlow官方Java API调用TensorFlow(Python)训练好的模型。教程的代码可在专知的Github项目中找到:

https://github.com/ZhuanZhiCode/TensorFlow-Java-Examples

2.依赖库



(1)Python依赖

TensorFlow

pip install tf-nightly

(2)Java依赖

本教程使用的是TensorFlow官方提供了Java接口,因此我们需要导入下面的Maven依赖:

<dependency>
   <groupId>org.tensorflow</groupId>
   <artifactId>tensorflow</artifactId>
   <version>1.5.0</version>
</dependency>

此外,还有一些工具类依赖:

<dependency>
   <groupId>commons-io</groupId>
   <artifactId>commons-io</artifactId>
   <version>2.6</version>
</dependency>

3.保存pb模型



下面的代码中,x是图的输入,z是图的输出。在代码的最后,调用tf.graph_util.convert_variables_to_constants 将图进行转换,最后将图保存为模型文件(pb)。

#coding=utf-8
import tensorflow as tf


# 定义图
x = tf.placeholder(tf.float32, name="x")
y = tf.get_variable("y", initializer=10.0)
z = tf.log(x + y, name="z")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # 进行一些训练代码,此处省略
    # xxxxxxxxxxxx

    # 显示图中的节点
    print([n.name for n in sess.graph.as_graph_def().node])
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names=["z"])

    # 保存图为pb文件
    with open('model.pb', 'wb') as f:
      f.write(frozen_graph_def.SerializeToString())

4.在Java中调用TensorFlow的图(pb模型)



模型的执行与Python类似,依然是导入图,建立Session,指定输入(feed)和输出(fetch)。

import org.apache.commons.io.IOUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.io.FileInputStream;
import java.io.IOException;

public class DemoImportGraph {

    public static void main(String[] args) throws IOException {
        try (Graph graph = new Graph()) {
            //导入图
            byte[] graphBytes = IOUtils.toByteArray(new 
            FileInputStream("model.pb"));
            graph.importGraphDef(graphBytes);

            //根据图建立Session
            try(Session session = new Session(graph)){
                //相当于TensorFlow Python中的sess.run(z, 
feed_dict = {'x': 10.0})
                float z = session.runner()
                        .feed("x", Tensor.create(10.0f))
                        .fetch("z").run().get(0).floatValue();
                System.out.println("z = " + z);
            }
        }

    }
}

运行结果:

z = 2.9957323

完整代码链接:

https://github.com/ZhuanZhiCode/TensorFlow-Java-Examples

-END-

本文分享自微信公众号 - 专知(Quan_Zhuanzhi),作者:Hujun

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2018-04-22

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 【分享】Java 9正式发布,9个新特性解读

    转自:开源中国, www.oschina.net/translate/java-9-new-features Java 8 发布三年多之后,即将快到2017年7...

    WZEARW
  • 【下载】面向机器智能的TensorFlow实践书籍和代码

    【导读】自2015年11月TensorFlow第一个开源版本发布以来,它便迅速跻身于最激动人心的机器学习库的行列,并在科研、产品和教育等领域正在得到日益广泛的应...

    WZEARW
  • 2018年Google TensorFlow开发者峰会亮点总结

    本文由TensorFlow的产品经理Sandeep Gupta代表TensorFlow团队发布。 原文:https://medium.com/tensorflo...

    WZEARW
  • 数据科学中应该学习哪些语言?来看看哪些应该掌握的?

     作者:Aceyclee   简评:原始的数据科学是劳动密集型活动,但当你会用适合的语言进行工作时,数据科学应该是非常智能有趣的工作,会让你得到一些不容易看到...

    机器人网
  • 数据科学中应该学习哪些语言?

    ? 简评:原始的数据科学是劳动密集型活动,但当你会用适合的语言进行工作时,数据科学应该是非常智能有趣的工作,会让你得到一些不容易看到的结论。 一般来说,数据科...

    小莹莹
  • 转行大数据,编程学Java还是Python?

    Python和Java,是大数据行业最常见的两种编程语言,对于想转行大数据的人来说,学习哪个语言是比较好的选择呢?

    加米谷大数据
  • 动态 | TensorFlow 三周岁!2.0 版本将于 2019 年发布

    2015 年 11 月,谷歌宣布开源 TensorFlow 深度学习框架,这一框架基于谷歌 DistBelief 框架。

    AI科技评论
  • Python一键转Jar包,Java调用Python新姿势!

    粉丝朋友们,不知道大家看故事看腻了没(要是没腻可一定留言告诉我^_^),今天这篇文章换换口味,正经的来写写技术文。言归正传,咱们开始吧!

    轩辕之风
  • 轻松理解计算机的内存模型及Java内存模型

    本文转载自:https://www.hollischuang.com/archives/2550

    aoho求索
  • 你真的知道Java内存模型是什么吗

    前几天,发了一篇文章,介绍了一下JVM内存结构、Java内存模型以及Java对象模型之间的区别。有很多小伙伴反馈希望可以深入的讲解下每个知识点。Java内存模型...

    格姗知识圈

扫码关注云+社区

领取腾讯云代金券