前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Yolov5 Android torchscript方式集成

Yolov5 Android torchscript方式集成

作者头像
obaby
发布2023-02-22 11:04:08
6240
发布2023-02-22 11:04:08
举报
文章被收录于专栏:obaby@mars

搜索了一下,目前要在手机端实现yolov5检测,找到了如下三种方式:

到目前为止,上面提到的三种方式,前两种已经测试完了,第二种没有成功。最终问题处在修改如下代码的地方:

代码语言:javascript
复制
// stride 16
{
    ncnn::Mat out;
    ex.extract("515", out);

    ncnn::Mat anchors(6);
    anchors[0] = 30.f;
    anchors[1] = 61.f;
    anchors[2] = 62.f;
    anchors[3] = 45.f;
    anchors[4] = 59.f;
    anchors[5] = 119.f;

    std::vector<Object> objects16;
    generate_proposals(anchors, 16, in_pad, out, prob_threshold, objects16);

    proposals.insert(proposals.end(), objects16.begin(), objects16.end());
}

// stride 32
{
    ncnn::Mat out;
    ex.extract("628", out);

    ncnn::Mat anchors(6);
    anchors[0] = 116.f;
    anchors[1] = 90.f;
    anchors[2] = 156.f;
    anchors[3] = 198.f;
    anchors[4] = 373.f;
    anchors[5] = 326.f;

    std::vector<Object> objects32;
    generate_proposals(anchors, 32, in_pad, out, prob_threshold, objects32);

    proposals.insert(proposals.end(), objects32.begin(), objects32.end());
}

原始的yolov5s.pt有三个output,但是我训练的模型由于是single class只有一个输出,所以没有搞明白这个地方该怎么修改。于是就卡住了,最终的检测效果目前不清楚好坏。

重新回到torchscript方式,这种方式相对来说比较简单,直接根据github https://github.com/pytorch/android-demo-app/tree/master/ObjectDetection里面的指导一步一步操作即可。需要注意的是最新版的yolov5中已经继承了export.py文件,但是不是在models目录下,而是在项目的根目录下。并且这个文件对于上面链接中提到的修改已经基本都修改过了,可以直接使用。如果导出有问题可以对比上文链接中的需要修改的地方进行修改。

1.导出torchscript模型,可以通过以下命令导出:

代码语言:javascript
复制
python export.py --weights weights/best.pt

或者修改export.py中的下面的代码,然后直接不带参数运行:

代码语言:javascript
复制
# 修改data weights参数
def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, default=ROOT / 'data/ads.yaml', help='dataset.yaml path')
    parser.add_argument('--weights', type=str, default=ROOT / 'best.pt', help='weights path')
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)')
    parser.add_argument('--batch-size', type=int, default=16, help='batch size')
    parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
    parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
    parser.add_argument('--train', action='store_true', help='model.train() mode')
    parser.add_argument('--optimize',default=True, action='store_true', help='TorchScript: optimize for mobile')
    parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
    parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
    parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
    parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version')
    parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
    parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
    parser.add_argument('--include', nargs='+',
                        default=['torchscript', 'onnx'],
                        help='available formats are (torchscript, onnx, coreml, saved_model, pb, tflite, tfjs)')
    opt = parser.parse_args()
    print_args(FILE.stem, opt)
    return opt

导出之后的文件:

2.下载https://github.com/pytorch/android-demo-app/tree/master/ObjectDetection对应的android工程项目,将导出之后的best.torchscript.pt复制到安卓工程的assets目录下,名称修改为yolov5s.torchscript.ptl。

如果想保持原有的文件名则需要修改安卓工程中的Mainactivity.java中的如下代码:

代码语言:javascript
复制
try {
            mModule = LiteModuleLoader.load(MainActivity.assetFilePath(getApplicationContext(), "yolov5s.torchscript.ptl"));
            BufferedReader br = new BufferedReader(new InputStreamReader(getAssets().open("classes.txt")));
            String line;
            List<String> classes = new ArrayList<>();
            while ((line = br.readLine()) != null) {
                classes.add(line);
            }
            PrePostProcessor.mClasses = new String[classes.size()];
            classes.toArray(PrePostProcessor.mClasses);
        } catch (IOException e) {
            Log.e("Object Detection", "Error reading assets", e);
            finish();
        }

3.修改assets目录下的class.txt文件,将其中的分类名称改成自己的名称:

修改 PrePostProcessor.java中的mOutputColumn:

代码语言:javascript
复制
private static int mOutputColumn = 6; // left, top, right, bottom, score and 80 class probability
// 这里为left, top, right, bottom, score + 分类数量,因为是单分类所以是5+1=6 原来为5+80

如果要修改置信度可以修改下面的代码:

代码语言:javascript
复制
private static float mThreshold = 0.05f; // score above which a detection is generated

修改下面的计算代码:

代码语言:javascript
复制
static ArrayList<Result> outputsToNMSPredictions(float[] outputs, float imgScaleX, float imgScaleY, float ivScaleX, float ivScaleY, float startX, float startY) {
        ArrayList<Result> results = new ArrayList<>();
        for (int i = 0; i< mOutputRow; i++) {
            if (outputs[i* mOutputColumn +4] > mThreshold) {
                float x = outputs[i* mOutputColumn];
                float y = outputs[i* mOutputColumn +1];
                float w = outputs[i* mOutputColumn +2];
                float h = outputs[i* mOutputColumn +3];

                float left = imgScaleX * (x - w/2);
                float top = imgScaleY * (y - h/2);
                float right = imgScaleX * (x + w/2);
                float bottom = imgScaleY * (y + h/2);

                float max = outputs[i* mOutputColumn +5];
                int cls = 0;
                for (int j = 0; j < mOutputColumn -5; j++) {
                    if (outputs[i* mOutputColumn +5+j] > max) {
                        max = outputs[i* mOutputColumn +5+j];
                        cls = j;
                    }
                }

                Rect rect = new Rect((int)(startX+ivScaleX*left), (int)(startY+top*ivScaleY), (int)(startX+ivScaleX*right), (int)(startY+ivScaleY*bottom));
                Result result = new Result(cls, outputs[i*6+4], rect);// 修改i*后面的值,与mOutputColumn是一样的。
                results.add(result);
            }
        }
        return nonMaxSuppression(results, mNmsLimit, mThreshold);
    }

4.编译运行,到这里就可以在设备上运行检测了效果如下:

置信度只有0.05,这个是有问题的,因为直接通过python版本运行会发现置信度0.91:

这个状况是由于什么导致的目前不太清除,哪位大神知道怎么解决还望不吝赐教。

参考链接:

https://blog.csdn.net/djstavaV/article/details/118078013

https://github.com/pytorch/android-demo-app/tree/master/ObjectDetection

https://github.com/ultralytics/yolov5/issues/251

https://cdmana.com/2021/07/20210721095026166r.html

☆文章版权声明☆

* 网站名称:obaby@mars

* 网址:https://h4ck.org.cn/

* 本文标题: 《Yolov5 Android torchscript方式集成》

* 本文链接:https://h4ck.org.cn/2021/09/yolov5-android-torchscript%e6%96%b9%e5%bc%8f%e9%9b%86%e6%88%90/

* 转载文章请标明文章来源,原文标题以及原文链接。请遵从 《署名-非商业性使用-相同方式共享 2.5 中国大陆 (CC BY-NC-SA 2.5 CN) 》许可协议。


分享文章:

相关文章:

  1. Android Skip Ads Yolov5 Project
  2. Windows 10 yolov5 GPU环境
  3. Yolov5 安卓检测效果让人无语
  4. Android Skip Ads Android Project【截屏部分】
  5. yolov5 训练速度对比
  6. Android Studio动态调试smali
  7. Android 检测坐标是否在国内
  8. Android App Ads Remove
  9. 美女图库(Android)
  10. IDA Android Remote Debug
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2021年9月29日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 相关文章:
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档