专栏首页AI研习社将 TensorFlow 训练好的模型迁移到 Android APP上(TensorFlowLite)

将 TensorFlow 训练好的模型迁移到 Android APP上(TensorFlowLite)

本文原载于天泽28的 CSDN 博客,AI 研习社获其授权转载。

1.写在前面

最近在做一个数字手势识别的APP(关于这个项目,我会再写一篇博客仔细介绍,博客地址:一步步做一个数字手势识别APP,源代码已经开源在github上,地址:Chinese-number-gestures-recognition),要把在PC端训练好的模型放到Android APP上,调研了下,谷歌发布了TensorFlow Lite可以把TensorFlow训练好的模型迁移到Android APP上,百度也发布了移动端深度学习框架mobile-deep-learning(MDL),这个框架应该是paddlepaddle的手机版,具体的细节没有了解过。因为对TensorFlow稍微熟悉些,因此就决定用TensorFlow来做。

关于在PC端如何处理数据及训练模型,请参见博客:一步步做一个数字手势识别APP,代码已经开源在github上,上面有代码的说明和APP演示。这篇博客只介绍如何把TensorFlow训练好的模型迁移到Android Studio上进行APP的开发。

2.模型训练注意事项

第一步,首先在pc端训练模型的时候要模型保存为.pb模型,在保存的时候有一点非常非常重要,就是你待会再Android studio是使用这个模型用到哪个参数,那么你在保存pb模型的时候就把给哪个参数一个名字,再保存。

否则,你在Android studio中很难拿出这个参数,因为TensorFlow Lite的fetch()函数是根据保存在pb模型中的名字去寻找这个参数的。(如果你已经训练好了模型,并且没有给参数名字,且你不想再训练模型了,那么你可以尝试下面的方法去找到你需要使用的变量的默认名字,见下面的代码):

#输出保存的模型中参数名字及对应的值with tf.gfile.GFile('model_50_200_c3//./digital_gesture.pb', "rb") as f:  #读取模型数据
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) #得到模型中的计算图和数据with tf.Graph().as_default() as graph:  # 这里的Graph()要有括号,不然会报TypeError
    tf.import_graph_def(graph_def, name="")  #导入模型中的图到现在这个新的计算图中,不指定名字的话默认是 import
    for op in graph.get_operations():  # 打印出图中的节点信息
        print(op.name, op.values())

这段代码打出的变量的名字以及对应的值。

言归正传,通常情况该你应该保存参数的时候都给参数一个指定的名字,如下面这样(通过name参数给变量指定名字),关于训练CNN的完整代码请参见下一篇博客或者github:

X = tf.placeholder(tf.float32, [None, 64, 64, 3], name="input_x")
y = tf.placeholder(tf.float32, [None, 11], name="input_y")
kp = tf.placeholder_with_default(1.0, shape=(), name="keep_prob")
lam = tf.placeholder(tf.float32, name="lamda")#中间略过若干代码z_fc2 = tf.add(tf.matmul(z_fc1_drop, W_fc2),b_fc2, name="outlayer")
prob = tf.nn.softmax(z_fc2, name="probability")
pred = tf.argmax(prob, 1, output_type="int32", name="predict")
1

3.在Android Studio中配置

第二步,开始把pb模型移植到Android Studio上,网上绝大部分资料都是说用bazel重新编译模型生成依赖,这种方法难度太大。其实没必须这样做,TensorFlow Lite官方的例子中已经给我们展示了,我们其实只需要两个文件:

libandroid_tensorflow_inference_java.jar 和 libtensorflow_inference.so。

这两个文件我已经放到github上了,大家可以自行下载使用,下载地址:libandroid_tensorflow_inference_java.jar、libtensorflow_inference.so。

注:检神说,直接用aar依赖也可以,这个我没试过。。有兴趣的可以试一下。

准备工作已经完毕,下面正式开始Android Studio中的配置。

首先把训练好的pb模型放到Android项目中app/src/main/assets下,若不存在assets目录,则自己新建一个。如图所示:

其次,把刚刚下载的 libandroid_tensorflow_inference_java.jar 文件放到 app/libs 目下,把libtensorflow_inference.so 放到 app/libs/armeabi-v7a 目录下,如下图所示:

然后在app/build.gradle里进行如下配置:

在defaultConfig里添加

multiDexEnabled true
        ndk {
            abiFilters "armeabi-v7a"
        }

在android里添加

sourceSets {
        main {
            jni.srcDirs = []
            jniLibs.srcDirs = ['libs']
        }
    }

如图所示:

在dependencies中添加libandroid_tensorflow_inference_java.jar,即:

implementation files('libs/libandroid_tensorflow_inference_java.jar')

如图所示:

至此,所有配置已经完成,下面是模型调用。

4.在Android Studio中调用模型

在要用到模型的地方,首先要加载libtensorflow_inference.so库和初始化TensorFlowInferenceInterface对象,代码为:

TensorFlowInferenceInterface inferenceInterface;    static {        //加载libtensorflow_inference.so库文件
        System.loadLibrary("tensorflow_inference");
        Log.e("tensorflow","libtensorflow_inference.so库加载成功");
    }
    Classifier(AssetManager assetManager, String modePath) {        //初始化TensorFlowInferenceInterface对象
        inferenceInterface = new TensorFlowInferenceInterface(assetManager,modePath);
        Log.e("tf","TensoFlow模型文件加载成功");
    }

如图所示:

下面来多看一点东西,看看TensorFlow Lite里提供了哪几个接口,官网地址:Here’s what a typical Inference Library sequence looks like on Android.

// Load the model from disk.
TensorFlowInferenceInterface inferenceInterface =
new TensorFlowInferenceInterface(assetManager, modelFilename);

// Copy the input data into TensorFlow.
inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);

// Run the inference call.
inferenceInterface.run(outputNames, logStats);

// Copy the output Tensor back into the output array.
inferenceInterface.fetch(outputName, outputs);

下面就可以愉快地使用模型了。放一段我调用模型的代码,以供大家参考:

public ArrayList predict(Bitmap bitmap)
    {
        ArrayList<String> list = new ArrayList<>();        float[] inputdata = getPixels(bitmap);        for(int i = 0; i <30; ++i)
        {
            Log.d("matrix",inputdata[i] + "");
        }
        inferenceInterface.feed(inputName, inputdata, 1, IMAGE_SIZE, IMAGE_SIZE, 3);        //运行模型,run的参数必须是String[]类型
        String[] outputNames = new String[]{outputName,probabilityName,outlayerName};
        inferenceInterface.run(outputNames);        //获取结果
        int[] labels = new int[1];
        inferenceInterface.fetch(outputName,labels);        int label = labels[0];        float[] prob = new float[11];
        inferenceInterface.fetch(probabilityName, prob);//        float[] outlayer = new float[11];//        inferenceInterface.fetch(outlayerName, outlayer);//        for(int i = 0; i <11; ++i)//        {//            Log.d("matrix",outlayer[i] + "");//        }
        for(int i = 0; i <11; ++i)
        {
            Log.d("matrix",prob[i] + "");
        }
        DecimalFormat df = new DecimalFormat("0.000000");        float label_prob = prob[label];        //返回值

最后放一张做的数字手势识别APP的效果,全部代码,将会开源在github上,欢迎star。

再放一张碰运气的识别结果:

Github 链接:

https://github.com/tz28/Chinese-number-gestures-recognition

本文分享自微信公众号 - AI研习社(okweiwu)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2018-07-29

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 基于 Keras 对深度学习模型进行微调的全面指南 Part 1

    我将借鉴自己的经验,列出微调背后的基本原理,所涉及的技术,及最后也是最重要的,在本文第二部分中将分步详尽阐述如何在 Keras 中对卷积神经网络模型进行微调。

    AI研习社
  • 微软协作 AI 挑战赛开始报名,沉迷 Minecraft 无法自拔的你不去试试?

    对 Minecraft 游戏感兴趣的 AI 开发者可能都知道 Project Malmo:一个微软发起的基于 Minecraft 的 AI 技术研究和测试平台。...

    AI研习社
  • 使用迁移学习/数据增强方法来实现Kaggle分类&amp;识别名人脸部

    在这个项目中,我将使用keras、迁移学习和微调过的VGG16网络来对kaggle竞赛中的名人面部图像进行分类。

    AI研习社
  • Jürgen Schmidhuber眼中的深度学习十年,以及下一个十年展望

    2020年是充满科幻的一年,曾经我们畅想飞行汽车、智能洗碗机器人以及能自动写代码的程序,然而这一切都没有发生。

    大数据文摘
  • 大案!大案!大案! 网传A站、摩拜数据库泄露

    13号凌晨,黑客聚集的暗网突现一条售卖信息,一名黑客号称出售两个权重超高的shell+内网权限,A站acfun.cn与摩拜单车,信息中称两个网站日流量均超百万,...

    安恒信息
  • 在你的CVM上安装SteamCMD服务器

    Steam命令行版客户端(SteamCMD)是一个命令行版本的Steam客户端。它的主要用途是在一个命令行界面的Steam客户端上安装和更新各种可用的专用服务端...

    尘埃
  • 《轮到你了》的菜奈AI是如何克隆声音的?

    最近在追日剧《轮到你了》,最新的15集里,二阶堂给翔太制作了一个菜奈的AI,是个手机app,界面非常简单,采用的是聊天机器人的界面,只不过是语音聊天的方式,此A...

    mixlab
  • Socket 通信原理

    什么是Socket? Socket的中文翻译过来就是“套接字”。套接字是什么,我们先来看看它的英文含义:插座。 Socket就像一个电话插座,负责连通两端的电话...

    wangxl
  • 2018年的新通用伪随机数算法(xoshiro / xoroshiro)的C++(head only)实现

    前段时间看到说Lua 5.4用了一种新的通用随机数算法,替换掉本来内部使用的CRT的随机数引擎。我看了一下大致的实现,CPU和空间复杂度任然保持了一个较低的水平...

    owent
  • MIT长篇论文:我们热捧的AI翻译和自动驾驶,需要用技术性价比来重估

    只是在过去十年里面,这种计算限制被「淡化」了。人们专注于「算法」优化和「硬件」性能的提升,以及愿意投入更高的「成本」来获得更好的性能。

    量子位

扫码关注云+社区

领取腾讯云代金券