前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >基于TNN在Android手机上实现图像分类

基于TNN在Android手机上实现图像分类

作者头像
夜雨飘零
修改2023-06-04 16:23:52
1.6K0
修改2023-06-04 16:23:52
举报
文章被收录于专栏:CSDN博客CSDN博客

前言

TNN:由腾讯优图实验室打造,移动端高性能、轻量级推理框架,同时拥有跨平台、高性能、模型压缩、代码裁剪等众多突出优势。TNN框架在原有Rapidnet、ncnn框架的基础上进一步加强了移动端设备的支持以及性能优化,同时也借鉴了业界主流开源框架高性能和良好拓展性的优点。

教程源码地址:https://github.com/yeyupiaoling/ClassificationForAndroid/tree/master/TNNClassification

编译Android库

  1. 安装cmake 3.12
代码语言:javascript
复制
# 卸载旧的cmake
sudo apt-get autoremove cmake

# 下载cmake3.12
wget https://cmake.org/files/v3.12/cmake-3.12.2-Linux-x86_64.tar.gz
tar zxvf cmake-3.12.2-Linux-x86_64.tar.gz

# 移动目录并添加软连接
sudo mv cmake-3.12.2-Linux-x86_64 /opt/cmake-3.12.2
sudo ln -sf /opt/cmake-3.12.2/bin/*  /usr/bin/
  1. 添加Android NDK
代码语言:javascript
复制
wget https://dl.google.com/android/repository/android-ndk-r21b-linux-x86_64.zip
unzip android-ndk-r21b-linux-x86_64.zip
# 添加环境变量,留意你实际下载地址
export ANDROID_NDK=/mnt/d/android-ndk-r21b
  1. 安装编译环境
代码语言:javascript
复制
sudo apt-get install attr
  1. 开始编译
代码语言:javascript
复制
git clone https://github.com/Tencent/TNN.git
cd TNN/scripts

vim build_android.sh
代码语言:javascript
复制
 ABIA32="armeabi-v7a"
 ABIA64="arm64-v8a"
 STL="c++_static"
 SHARED_LIB="ON"                # ON表示编译动态库,OFF表示编译静态库
 ARM="ON"                       # ON表示编译带有Arm CPU版本的库
 OPENMP="ON"                    # ON表示打开OpenMP
 OPENCL="ON"                    # ON表示编译带有Arm GPU版本的库
 SHARING_MEM_WITH_OPENGL=0      # 1表示OpenGL的Texture可以与OpenCL共享

执行编译

代码语言:javascript
复制
./build_android.sh

编译完成后,会在当前目录的release目录下生成对应的armeabi-v7a库,arm64-v8a库和include头文件,这些文件在下一步的Android开发都需要使用到。

模型转换

接下来我们需要把Tensorflow,onnx等其他的模型转换为TNN的模型。目前 TNN 支持业界主流的模型文件格式,包括ONNX、PyTorch、TensorFlow 以及 Caffe 等。TNN 将 ONNX 作为中间层,借助于ONNX 开源社区的力量,来支持多种模型文件格式。如果要将PyTorch、TensorFlow 以及 Caffe 等模型文件格式转换为 TNN,首先需要使用对应的模型转换工具,统一将各种模型格式转换成为 ONNX 模型格式,然后将 ONNX 模型转换成 TNN 模型。

代码语言:javascript
复制
sudo docker pull turandotkay/tnn-convert
sudo docker tag turandotkay/tnn-convert:latest tnn-convert:latest
sudo docker rmi turandotkay/tnn-convert:latest

针对不同的模型转换,有不同的命令,如onnx2tnn,caffe2tnn,tf2tnn。

代码语言:javascript
复制
docker run --volume=$(pwd):/workspace -it tnn-convert:latest  python3 ./converter.py tf2tnn \
    -tp /workspace/mobilenet_v1.pb \
    -in "input[1,224,224,3]" \
    -on MobilenetV1/Predictions/Reshape_1 \
    -v v1.0 \
    -optimize

通过上面的输出,可以发现针对 TF 模型的转换,convert2tnn 工具提供了很多参数,我们一次对下面的参数进行解释:

  • tp 参数(必须) 通过 “-tp” 参数指定需要转换的模型的路径。目前只支持单个 TF模型的转换,不支持多个 TF 模型的一起转换。
  • in 参数(必须) 通过 “-in” 参数指定模型输入的名称,输入的名称需要放到“”中,例如,-in “name”。如果模型有多个输入,请使用 “;”进行分割。有的 TensorFlow 模型没有指定 batch 导致无法成功转换为 ONNX 模型,进而无法成功转换为 TNN 模型。你可以通过在名称后添加输入 shape 进行指定。shape 信息需要放在 [] 中。例如:-in “name1,28,28,3”。
  • on 参数(必须) 通过 “-on” 参数指定模型输入的名称,如果模型有多个输出,请使用 “;”进行分割
  • output_dir 参数: 可以通过 “-o ” 参数指定输出路径,但是在 docker 中我们一般不使用这个参数,默认会将生成的 TNN 模型放在当前和 TF 模型相同的路径下。
  • optimize 参数(可选) 可以通过 “-optimize” 参数来对模型进行优化,我们强烈建议你开启这个选项,只有在开启这个选项模型转换失败时,我们才建议你去掉 “-optimize” 参数进行重新尝试
  • v 参数(可选) 可以通过 -v 来指定模型的版本号,以便于后期对模型进行追踪和区分。
  • half 参数(可选) 可以通过 -half 参数指定,模型数据通过 FP16 进行存储,减少模型的大小,默认是通过 FP32 的方式进行存储模型数据的。
  • align 参数(可选) 可以通过 -align 参数指定,将 转换得到的 TNN 模型和原模型进行对齐,确定 TNN 模型是否转换成功。当前仅支持单输入单输出模型和单输入多输出模型。 align 只支持 FP32 模型的校验,所以使用 align 的时候不能使用 half
  • input_file 参数(可选) 可以通过 -input_file 参数指定模型对齐所需要的输入文件的名称,输入需要遵循如下格式
  • ref_file 参数(可选) 可以通过 -ref_file 参数指定待对齐的输出文件的名称,输出需遵循如下格式。生成输出的代码可以参考

成功转换会输出以下的日志。

代码语言:javascript
复制
----------  convert model, please wait a moment ----------

Converter Tensorflow to TNN model

Convert TensorFlow to ONNX model succeed!

Converter ONNX to TNN Model

Converter ONNX to TNN model succeed!

最终会得到这两个模型文件,mobilenet_v1.opt.tnnmodel mobilenet_v1.opt.tnnproto

开发Android项目

  1. 将转换的模型放在assets目录下。
  2. 把上一步编译得到的include目录复制到Android项目的app目录下。
  3. 把上一步编译得到的armeabi-v7aarm64-v8a目录复制到main/jniLibs下。
  4. app/src/main/cpp/目录下编写JNI的C++代码。

TNN工具

编写一个ImageClassifyUtil.java工具类,关于TNN的操作都在这里完成,如加载模型、预测。

下面三个就是TNN的JNI接口,通过这个接口完成模型加载,预测,当不使用的时候和可以调用deinit()清空对象。

代码语言:javascript
复制
public native int init(String modelPath, String protoPath, int computeUnitType);

public native float[] predict(Bitmap image, int width, int height);

public native int deinit();

通过上面的JNI接口,下面就可以实现图像识别了,WIDTHHEIGHT是模型输入图片的大小。为了兼容图片路径和Bitmap格式的图片预测,这里创建了两个重载方法。

代码语言:javascript
复制
private static final int WIDTH = 224;
private  static final int HEIGHT = 224;

public ImageClassifyUtil() {
    System.loadLibrary("TNN");
    System.loadLibrary("tnn_wrapper");
}

// 重载方法,根据图片路径转Bitmap预测
public float[] predictImage(String image_path) throws Exception {
    if (!new File(image_path).exists()) {
        throw new Exception("image file is not exists!");
    }
    FileInputStream fis = new FileInputStream(image_path);
    Bitmap bitmap = BitmapFactory.decodeStream(fis);
    Bitmap scaleBitmap = Bitmap.createScaledBitmap(bitmap, WIDTH, HEIGHT, false);
    float[] result = predictImage(scaleBitmap);
    if (bitmap.isRecycled()) {
        bitmap.recycle();
    }
    return result;
}

// 重载方法,直接使用Bitmap预测
public float[] predictImage(Bitmap bitmap) {
    Bitmap scaleBitmap = Bitmap.createScaledBitmap(bitmap, WIDTH, HEIGHT, false);
    float[] results = predict(scaleBitmap, WIDTH, HEIGHT);
    int l = getMaxResult(results);
    return new float[]{l, results[l] * 0.01f};
}

这里创建一个获取最大概率值,并把下标返回的方法,其实就是获取概率最大的预测标签。

代码语言:javascript
复制
public static int getMaxResult(float[] result) {
    float probability = 0;
    int r = 0;
    for (int i = 0; i < result.length; i++) {
        if (probability < result[i]) {
            probability = result[i];
            r = i;
        }
    }
    return r;
}

不同的模型,训练的预处理方式可能不一样,TNN 的图像预处理在C++中完成,代码片段

代码语言:javascript
复制
TNN_NS::MatConvertParam input_cvt_param;
input_cvt_param.scale = {1.0 / (255 * 0.229), 1.0 / (255 * 0.224), 1.0 / (255 * 0.225), 0.0};
input_cvt_param.bias  = {-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225, 0.0};
auto status = instance_->SetInputMat(input_mat, input_cvt_param);

选择图片预测

本教程会有两个页面,一个是选择图片进行预测的页面,另一个是使用相机实时预测并显示预测结果。以下为activity_main.xml的代码,通过按钮选择图片,并在该页面显示图片和预测结果。

代码语言:javascript
复制
<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:orientation="vertical"
    tools:context=".MainActivity">

    <ImageView
        android:id="@+id/image_view"
        android:layout_width="match_parent"
        android:layout_height="400dp" />

    <TextView
        android:id="@+id/result_text"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_below="@id/image_view"
        android:text="识别结果"
        android:textSize="16sp" />


    <LinearLayout
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_alignParentBottom="true"
        android:orientation="horizontal">

        <Button
            android:id="@+id/select_img_btn"
            android:layout_width="0dp"
            android:layout_height="wrap_content"
            android:layout_weight="1"
            android:text="选择照片" />


        <Button
            android:id="@+id/open_camera"
            android:layout_width="0dp"
            android:layout_height="wrap_content"
            android:layout_weight="1"
            android:text="实时预测" />

    </LinearLayout>

</RelativeLayout>

MainActivity.java中,进入到页面我们就要先加载模型,我们是把模型放在Android项目的assets目录的,我们需要把模型复制到一个缓存目录,然后再从缓存目录加载模型,同时还有读取标签名,标签名称按照训练的label顺序存放在assets的label_list.txt,以下为实现代码。

代码语言:javascript
复制
classNames = Utils.ReadListFromFile(getAssets(), "label_list.txt");
String protoContent = getCacheDir().getAbsolutePath() + File.separator + "squeezenet_v1.1.tnnproto";
Utils.copyFileFromAsset(MainActivity.this, "squeezenet_v1.1.tnnproto", protoContent);
String modelContent = getCacheDir().getAbsolutePath() + File.separator + "squeezenet_v1.1.tnnmodel";
Utils.copyFileFromAsset(MainActivity.this, "squeezenet_v1.1.tnnmodel", modelContent);

imageClassifyUtil = new ImageClassifyUtil();
int status = imageClassifyUtil.init(modelContent, protoContent, USE_GPU ? 1 : 0);
if (status == 0){
    Toast.makeText(MainActivity.this, "模型加载成功!", Toast.LENGTH_SHORT).show();
}else {
    Toast.makeText(MainActivity.this, "模型加载失败!", Toast.LENGTH_SHORT).show();
    finish();
}

添加两个按钮点击事件,可以选择打开相册读取图片进行预测,或者打开另一个Activity进行调用摄像头实时识别。

代码语言:javascript
复制
Button selectImgBtn = findViewById(R.id.select_img_btn);
Button openCamera = findViewById(R.id.open_camera);
imageView = findViewById(R.id.image_view);
textView = findViewById(R.id.result_text);
selectImgBtn.setOnClickListener(new View.OnClickListener() {
    @Override
    public void onClick(View v) {
        // 打开相册
        Intent intent = new Intent(Intent.ACTION_PICK);
        intent.setType("image/*");
        startActivityForResult(intent, 1);
    }
});
openCamera.setOnClickListener(new View.OnClickListener() {
    @Override
    public void onClick(View v) {
        // 打开实时拍摄识别页面
        Intent intent = new Intent(MainActivity.this, CameraActivity.class);
        startActivity(intent);
    }
});

当打开相册选择照片之后,回到原来的页面,在下面这个回调方法中获取选择图片的Uri,通过Uri可以获取到图片的绝对路径。如果Android8以上的设备获取不到图片,需要在AndroidManifest.xml配置文件中的application添加android:requestLegacyExternalStorage="true"。拿到图片路径之后,调用TFLiteClassificationUtil类中的predictImage()方法预测并获取预测值,在页面上显示预测的标签、对应标签的名称、概率值和预测时间。

代码语言:javascript
复制
@Override
protected void onActivityResult(int requestCode, int resultCode, @Nullable Intent data) {
    super.onActivityResult(requestCode, resultCode, data);
    String image_path;
    if (resultCode == Activity.RESULT_OK) {
        if (requestCode == 1) {
            if (data == null) {
                Log.w("onActivityResult", "user photo data is null");
                return;
            }
            Uri image_uri = data.getData();
            image_path = getPathFromURI(MainActivity.this, image_uri);
            try {
                // 预测图像
                FileInputStream fis = new FileInputStream(image_path);
                imageView.setImageBitmap(BitmapFactory.decodeStream(fis));
                long start = System.currentTimeMillis();
                float[] result = imageClassifyUtil.predictImage(image_path);
                long end = System.currentTimeMillis();
                String show_text = "预测结果标签:" + (int) result[0] +
                        "\n名称:" +  classNames[(int) result[0]] +
                        "\n概率:" + result[1] +
                        "\n时间:" + (end - start) + "ms";
                textView.setText(show_text);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }
}

上面获取的Uri可以通过下面这个方法把Url转换成绝对路径。

代码语言:javascript
复制
// get photo from Uri
public static String getPathFromURI(Context context, Uri uri) {
    String result;
    Cursor cursor = context.getContentResolver().query(uri, null, null, null, null);
    if (cursor == null) {
        result = uri.getPath();
    } else {
        cursor.moveToFirst();
        int idx = cursor.getColumnIndex(MediaStore.Images.ImageColumns.DATA);
        result = cursor.getString(idx);
        cursor.close();
    }
    return result;
}

摄像头实时预测

在调用相机实时预测我就不再介绍了,原理都差不多,具体可以查看https://github.com/yeyupiaoling/ClassificationForAndroid/tree/master/TFLiteClassification中的源代码。核心代码如下,创建一个子线程,子线程中不断从摄像头预览的AutoFitTextureView上获取图像,并执行预测,并在页面上显示预测的标签、对应标签的名称、概率值和预测时间。每一次预测完成之后都立即获取图片继续预测,只要预测速度够快,就可以看成实时预测。

代码语言:javascript
复制
private Runnable periodicClassify =
        new Runnable() {
            @Override
            public void run() {
                synchronized (lock) {
                    if (runClassifier) {
                        // 开始预测前要判断相机是否已经准备好
                        if (getApplicationContext() != null && mCameraDevice != null && mnnClassification != null) {
                            predict();
                        }
                    }
                }
                if (mInferThread != null && mInferHandler != null && mCaptureHandler != null && mCaptureThread != null) {
                    mInferHandler.post(periodicClassify);
                }
            }
        };

// 预测相机捕获的图像
private void predict() {
    // 获取相机捕获的图像
    Bitmap bitmap = mTextureView.getBitmap();
    try {
        // 预测图像
        long start = System.currentTimeMillis();
        float[] result = imageClassifyUtil.predictImage(bitmap);
        long end = System.currentTimeMillis();
        String show_text = "预测结果标签:" + (int) result[0] +
                "\n名称:" +  classNames[(int) result[0]] +
                "\n概率:" + result[1] +
                "\n时间:" + (end - start) + "ms";
        textView.setText(show_text);
    } catch (Exception e) {
        e.printStackTrace();
    }
}

本项目中使用的了读取图片的权限和打开相机的权限,所以不要忘记在AndroidManifest.xml添加以下权限申请。

代码语言:javascript
复制
<uses-permission android:name="android.permission.CAMERA"/>
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>

如果是Android 6 以上的设备还要动态申请权限。

代码语言:javascript
复制
    // check had permission
    private boolean hasPermission() {
        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
            return checkSelfPermission(Manifest.permission.CAMERA) == PackageManager.PERMISSION_GRANTED &&
                    checkSelfPermission(Manifest.permission.READ_EXTERNAL_STORAGE) == PackageManager.PERMISSION_GRANTED &&
                    checkSelfPermission(Manifest.permission.WRITE_EXTERNAL_STORAGE) == PackageManager.PERMISSION_GRANTED;
        } else {
            return true;
        }
    }

    // request permission
    private void requestPermission() {
        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
            requestPermissions(new String[]{Manifest.permission.CAMERA,
                    Manifest.permission.READ_EXTERNAL_STORAGE,
                    Manifest.permission.WRITE_EXTERNAL_STORAGE}, 1);
        }
    }

效果图:

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2020-09-06 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
作者已关闭评论
0 条评论
热度
最新
推荐阅读
目录
  • 前言
  • 编译Android库
  • 模型转换
  • 开发Android项目
    • TNN工具
      • 选择图片预测
        • 摄像头实时预测
        相关产品与服务
        容器服务
        腾讯云容器服务(Tencent Kubernetes Engine, TKE)基于原生 kubernetes 提供以容器为核心的、高度可扩展的高性能容器管理服务,覆盖 Serverless、边缘计算、分布式云等多种业务部署场景,业内首创单个集群兼容多种计算节点的容器资源管理模式。同时产品作为云原生 Finops 领先布道者,主导开源项目Crane,全面助力客户实现资源优化、成本控制。
        领券
        问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档