前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >TensorFlow首次快速体验

TensorFlow首次快速体验

作者头像
十毛
发布2019-06-19 18:55:20
4930
发布2019-06-19 18:55:20
举报
文章被收录于专栏:用户1337634的专栏

添加依赖

代码语言:javascript
复制
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
    <version>1.13.1</version>
</dependency>

定义图模型

示例完成一个简单的函数:

代码语言:javascript
复制
f(x, y) = z = a*x + b*y

其中a, b是常量,x, y是变量

  • 定义Graph
代码语言:javascript
复制
Graph graph = new Graph()
  • 定义常量
代码语言:javascript
复制
Operation a = graph.opBuilder("Const", "a")
        .setAttr("dtype", DataType.fromClass(Double.class))
        .setAttr("value", Tensor.<Double>create(3.0, Double.class))
        .build();
Operation b = graph.opBuilder("Const", "b")
        .setAttr("dtype", DataType.fromClass(Double.class))
        .setAttr("value", Tensor.<Double>create(2.0, Double.class))
        .build()
  • 定义变量
代码语言:javascript
复制
Operation x = graph.opBuilder("Placeholder", "x")
        .setAttr("dtype", DataType.fromClass(Double.class))
        .build();
Operation y = graph.opBuilder("Placeholder", "y")
        .setAttr("dtype", DataType.fromClass(Double.class))
        .build();
  • 定义函数
代码语言:javascript
复制
Operation ax = graph.opBuilder("Mul", "ax")
        .addInput(a.output(0))
        .addInput(x.output(0))
        .build();
Operation by = graph.opBuilder("Mul", "by")
        .addInput(b.output(0))
        .addInput(y.output(0))
        .build();
Operation z = graph.opBuilder("Add", "z")
        .addInput(ax.output(0))
        .addInput(by.output(0))
        .build();

可以看出来,用Java定义图模型比较麻烦,但是使用Python会简单很多

执行

代码语言:javascript
复制
Session session = new Session(graph);
Tensor<Double> tensor = session.runner().fetch("z")
        .feed("x", Tensor.create(3.0, Double.class))
        .feed("y", Tensor.create(6.0, Double.class))
        .run().get(0).expect(Double.class);
System.out.println(tensor.doubleValue());

图模型保存及加载

  • 保存模型
代码语言:javascript
复制
Path path = Paths.get("tensor.model");
byte[] bytes = graph.toGraphDef();
Files.write(path, bytes);
  • 加载模型
代码语言:javascript
复制
Graph graph = new Graph();
byte[] bytes = Files.readAllBytes(path);
graph.importGraphDef(bytes);

ps: 模型可以在不同语言通用,所以可以使用python训练模型,然后提供给其他语言使用,比如Java

结果

最后输出结果:21.0

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2019.06.18 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 定义图模型
  • 执行
  • 图模型保存及加载
  • 结果
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档