前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【实操干货】创建一个用在图像内部进行对象检测的Android应用程序

【实操干货】创建一个用在图像内部进行对象检测的Android应用程序

作者头像
用户6543014
发布2019-11-20 15:26:18
1.1K0
发布2019-11-20 15:26:18
举报
文章被收录于专栏:CU技术社区CU技术社区

在移动设备上运行机器学习代码是下一件大事。 PyTorch在最新版本的PyTorch 1.3中添加了PyTorch Mobile,用于在Android和iOS设备上部署机器学习模型。

在这里,我们将研究创建一个用于在图像内部进行对象检测的Android应用程序;如下图所示。

应用程序的演示运行

步骤1:准备模型

在本教程中,我们将使用经过预训练好的ResNet18模型。ResNet18是具有1000个分类类别的最先进的计算机视觉模型。

1.安装Torchvision库

代码语言:javascript
复制
pip install torchvision

2.下载并跟踪ResNet18模型

我们追踪这个模型是因为我们需要一个可执行的ScriptModule来进行即时编译。

代码语言:javascript
复制
import torch
import torchvision
resnet18 = torchvision.models.resnet18(pretrained=True)
resnet18.eval()
example_inputs = torch.rand(1, 3, 224, 224)
resnet18_traced = torch.jit.trace(resnet18, example_inputs = example_inputs)
resnet18_traced.save("resnet18_traced.pt")

注意:

  1. 将resnet18_traced.pt存储在一个已知的位置,在本教程的后续步骤中我们将需要此位置。
  2. 在torch.rand中,我们采用了224 * 224的尺寸,因为ResNet18接受224 * 224的尺寸。

步骤2:制作Android应用程序

1.如果尚未安装,请下载并安装Android Studio,如果是,请单击“是”以下载和安装SDK。链接:https://developer.android.com/studio

2.打开Android Studio,然后单击:启动一个新的Android Studio项目

3.选择清空活动

4.输入应用程序名称:ObjectDetectorDemo,然后按Finish

5.安装NDK运行Android内部运行原生代码:

  • 转到Tools> SDK Manager
  • 单击SDK工具
  • 选中NDK(并排)旁边的框

6.添加依赖项

Insidebuild.gradle(Module:app)。

在依赖项中添加以下内容

代码语言:javascript
复制
dependencies {
    implementation fileTree(dir: 'libs', include: ['*.jar'])
    implementation 'androidx.appcompat:appcompat:1.0.2'
    implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
implementation 'org.pytorch:pytorch_android:1.3.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.3.0'
}

7.添加基本布局以加载图像并显示结果

转到app> res> layout> activity_main.xml,然后添加以下代码

代码语言:javascript
复制
<ImageView
    android:id="@+id/image"
    app:layout_constraintTop_toTopOf="parent"
    android:layout_width="match_parent"
    android:layout_height="400dp"
    android:layout_marginBottom="20dp"
    android:scaleType="fitCenter" />

<TextView
    android:id="@+id/result_text"
    android:layout_width="match_parent"
    android:layout_height="wrap_content"
    android:layout_gravity="top"
    android:text=""
    android:textSize="20dp"
    android:textStyle="bold"
    android:textAllCaps="true"
    android:textAlignment="center"
    app:layout_constraintTop_toTopOf="@id/button"
    app:layout_constraintBottom_toBottomOf="@+id/image" />

<Button
    android:id="@+id/button"
    android:layout_width="match_parent"
    android:layout_height="wrap_content"
    android:text="Load Image"
    app:layout_constraintBottom_toBottomOf="@+id/result_text"
    app:layout_constraintTop_toTopOf="@+id/detect" />

<Button
    android:id="@+id/detect"
    android:layout_width="match_parent"
    android:layout_height="wrap_content"
    android:text="Detect"
    android:layout_marginBottom="50dp"
    app:layout_constraintBottom_toBottomOf="parent" />

您的布局应如下图所示

8.我们需要设置权限以读取设备上的图像存储

转到app> manifests> AndroidManifest.xml,然后在manifest标签内添加以下代码

代码语言:javascript
复制
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>

获取应用程序加载权限(仅在您授予权限之前询问)

—转到Main Activity java。在onCreate()方法中添加以下代码。

代码语言:javascript
复制
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
    requestPermissions(new String[]  {android.Manifest.permission.READ_EXTERNAL_STORAGE}, 1);
}

9.复制模型

现在是时候复制使用python脚本创建的模型了。

从文件资源管理器/查找器中打开您的应用程序。

转到app > src > main

创建一个名为assets的文件夹将模型复制到此文件夹中。打开后,您将在Android Studio中看到如下图所示。(如果没有,请右键单击应用程序文件夹,然后单击“同步应用程序”)

10.我们需要列出模型的输出类

转到app > java

在第一个文件夹中,将新的Java类名称命名为ModelClasses。

将类的列表定义为(整个列表为1000个类,因此可以在此处复制所有内容(检查Json或Git)以获取完整列表,然后在下面的列表内复制):

代码语言:javascript
复制
public static String[] MODEL_CLASSES = new String[]{
        "tench, Tinca tinca",
        "goldfish, Carassius auratus"
        .
        .
        .
}

11.Main Activity Java,这里将定义按钮动作,读取图像并调用PyTorch模型。请参阅代码内的注释以获取解释。

代码语言:javascript
复制
package com.tckmpsi.objectdetectordemo;

import androidx.appcompat.app.AppCompatActivity;

import android.content.Context;
import android.content.Intent;
import android.database.Cursor;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.drawable.BitmapDrawable;
import android.net.Uri;
import android.os.Build;
import android.os.Bundle;
import android.provider.MediaStore;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;

import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

public class MainActivity extends AppCompatActivity {
    private static int RESULT_LOAD_IMAGE = 1;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        Button buttonLoadImage = (Button) findViewById(R.id.button);
        Button detectButton = (Button) findViewById(R.id.detect);


        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
            requestPermissions(new String[]{android.Manifest.permission.READ_EXTERNAL_STORAGE}, 1);
        }
        buttonLoadImage.setOnClickListener(new View.OnClickListener() {

            @Override
            public void onClick(View arg0) {
                TextView textView = findViewById(R.id.result_text);
                textView.setText("");
                Intent i = new Intent(
                        Intent.ACTION_PICK,
                        MediaStore.Images.Media.EXTERNAL_CONTENT_URI);

                startActivityForResult(i, RESULT_LOAD_IMAGE);


            }
        });

        detectButton.setOnClickListener(new View.OnClickListener() {

            @Override
            public void onClick(View arg0) {

                Bitmap bitmap = null;
                Module module = null;

                //Getting the image from the image view
                ImageView imageView = (ImageView) findViewById(R.id.image);

                try {
                    //Read the image as Bitmap
                    bitmap = ((BitmapDrawable)imageView.getDrawable()).getBitmap();

                    //Here we reshape the image into 400*400
                    bitmap = Bitmap.createScaledBitmap(bitmap, 400, 400, true);

                    //Loading the model file.
                    module = Module.load(fetchModelFile(MainActivity.this, "resnet18_traced.pt"));
                } catch (IOException e) {
                    finish();
                }

                //Input Tensor
                final Tensor input = TensorImageUtils.bitmapToFloat32Tensor(
                        bitmap,
                        TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
                        TensorImageUtils.TORCHVISION_NORM_STD_RGB
                );

                //Calling the forward of the model to run our input
                final Tensor output = module.forward(IValue.from(input)).toTensor();


                final float[] score_arr = output.getDataAsFloatArray();

                // Fetch the index of the value with maximum score
                float max_score = -Float.MAX_VALUE;
                int ms_ix = -1;
                for (int i = 0; i < score_arr.length; i++) {
                    if (score_arr[i] > max_score) {
                        max_score = score_arr[i];
                        ms_ix = i;
                    }
                }

                //Fetching the name from the list based on the index
                String detected_class = ModelClasses.MODEL_CLASSES[ms_ix];

                //Writing the detected class in to the text view of the layout
                TextView textView = findViewById(R.id.result_text);
                textView.setText(detected_class);


            }
        });

    }
    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        //This functions return the selected image from gallery
        super.onActivityResult(requestCode, resultCode, data);

        if (requestCode == RESULT_LOAD_IMAGE && resultCode == RESULT_OK && null != data) {
            Uri selectedImage = data.getData();
            String[] filePathColumn = { MediaStore.Images.Media.DATA };

            Cursor cursor = getContentResolver().query(selectedImage,
                    filePathColumn, null, null, null);
            cursor.moveToFirst();

            int columnIndex = cursor.getColumnIndex(filePathColumn[0]);
            String picturePath = cursor.getString(columnIndex);
            cursor.close();

            ImageView imageView = (ImageView) findViewById(R.id.image);
            imageView.setImageBitmap(BitmapFactory.decodeFile(picturePath));

            //Setting the URI so we can read the Bitmap from the image
            imageView.setImageURI(null);
            imageView.setImageURI(selectedImage);


        }


    }

    public static String fetchModelFile(Context context, String modelName) throws IOException {
        File file = new File(context.getFilesDir(), modelName);
        if (file.exists() && file.length() > 0) {
            return file.getAbsolutePath();
        }

        try (InputStream is = context.getAssets().open(modelName)) {
            try (OutputStream os = new FileOutputStream(file)) {
                byte[] buffer = new byte[4 * 1024];
                int read;
                while ((read = is.read(buffer)) != -1) {
                    os.write(buffer, 0, read);
                }
                os.flush();
            }
            return file.getAbsolutePath();
        }
    }

}

12.现在是时候测试应用程序了。两种方法有两种:

  • 在模拟器上运行(https://developer.android.com/studio/run/emulator)。
  • 使用Android设备。(为此,您需要启用USB调试(http://developer.android.com/studio/run/emulator))。
  • 运行应用程序后,它的外观应类似于页面顶部的GIF。

链接到Git存储库:https://github.com/tusharck/Object-Detector-Android-App-Using-PyTorch-Mobile-Neural-Network

好看的人才能点

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-11-19,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 SACC开源架构 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档