前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【图像分类】YOLOv5-6.2全新版本:支持图像分类

【图像分类】YOLOv5-6.2全新版本:支持图像分类

作者头像
zstar
发布2022-09-23 11:09:30
1.5K0
发布2022-09-23 11:09:30
举报
文章被收录于专栏:往期博文

前言

众所周知,YOLOv5是一款优秀的目标检测模型,但实际上,它也支持图像分类。在6.1版本中,就暗留了classify这个参数,可以在做检测的同时进行分类。

官方仓库地址:https://github.com/ultralytics/yolov5/releases

更新概览

在几天前刚新出的6.2版本中,直接将分类功能单独剥离开来,使其能够直接训练图像分类数据集。

先看看官网公示的更新说明:

  • 分类功能 新增分类功能,并提供各模型在ImageNet上训练过的预训练模型
  • ClearML日志记录 与开源实验跟踪器ClearML集成。使用pip安装clearml将启用集成,并允许用户跟踪clearml中的每个训练运行。
  • Deci.ai优化 在Deci上单击一次即可自动编译和量化YOLOv5,从而获得更好的性能
  • GPU导出基准 可以使用python utils/benchmarks.py --weights yolov5s.pt --device 0来导出Benchmark (mAP and speed)
  • 训练可完全复现 torch>=1.12.0的单GPU YOLOv5训练现在完全可再现,并且可以使用新的–seed参数(默认seed=0)
  • 优化Apple炼丹体验 Apple Metal Performance Shader(MPS:苹果炼丹工具) 支持Apple M1/M2设备

在这些更新中,我最关注的是图像分类功能,那么本篇就来尝试跑通一下。

分类模型效果

下图是官方贴出来的各分类模型对比图,在可以看到在相同的数据集上,YOLOv5x-cls模型取得了最佳的准确率。下列这些模型官方均提供预训练权重。

在这里插入图片描述
在这里插入图片描述

工程结构

首先看新版本的工程结构,和前几个版本差别不大。主要是多了一个classify文件夹,包含图像分类训练,验证,检测三个函数。

训练结果会保存在runs/train-cls文件夹中。

在这里插入图片描述
在这里插入图片描述

数据集下载

train.py中,提供这段数据集下载程序段:

代码语言:javascript
复制
# Download Dataset
with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(ROOT):
    data_dir = data if data.is_dir() else (DATASETS_DIR / data)
    if not data_dir.is_dir():
        LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
        t = time.time()
        if str(data) == 'imagenet':
            subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
        else:
            url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{data}.zip'
            download(url, dir=data_dir.parent)
        s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
        LOGGER.info(s)

如果指定的数据集路径下没有数据集,会从官方仓库v1.0的版本中进行下载。这里我建议是手动去进行下载,受限于网络情况,自动下载很容易失败。

可以看到,在官方仓库v1.0版本中的Assets中包含了很多数据集,我下载了最经典的mnist的数据集来做测试。

在这里插入图片描述
在这里插入图片描述

下载完之后,需要在本地进行解压,解压之后的数据集格式如下图所示:

在这里插入图片描述
在这里插入图片描述

禁用wandb

训练之前,可以选择禁用wandb,wandb是和tensorboard类似的数据记录平台,为了防止报错,可以用下面的方式进行禁用。

在终端环境中输入wandb disabled

同时在utils/loggers/wandb/__init__.py中添加wandb = None

代码语言:javascript
复制
try:
    import wandb

    assert hasattr(wandb, '__version__')  # verify package import not local dir
    if pkg.parse_version(wandb.__version__) >= pkg.parse_version('0.12.2') and RANK in {0, -1}:
        try:
            wandb_login_success = wandb.login(timeout=30)
        except wandb.errors.UsageError:  # known non-TTY terminal issue
            wandb_login_success = False
        if not wandb_login_success:
            wandb = None
except (ImportError, AssertionError):
    wandb = None
# 添加以下语句
wandb = None

utils/loggers/wandb/wandb_utils.py中同样添加

代码语言:javascript
复制
try:
    import wandb

    assert hasattr(wandb, '__version__')  # verify package import not local dir
except (ImportError, AssertionError):
    wandb = None
# 添加以下语句
wandb = None

开始训练

train.py中主要修改下面一些超参数,基本和目标检测类似,值得注意的是图像分类训练中,并不需要指定模型结构,模型结构完全包含在了预训练模型中,使用起来更为方便。如果需要深入了解YOLO模型是如何引出分类的,可以导出ONNX模型,再使用netron查看。

在这里插入图片描述
在这里插入图片描述

训练完成之后,会自动调用测试程序,绘制测试结果。

在这里插入图片描述
在这里插入图片描述

可以看到,我只使用YOLOv5-cls模型训练了10个epoch,就在mnist上取得了不错的效果。

模型预测

模型预测更简单,指定训练好的权重weights,输入图像source,图像尺寸imgsz即可。

模型会从高到低输出前5个类别的概率值。

在这里插入图片描述
在这里插入图片描述

代码备份

本次实验代码包含YOLOv5-6.2版本提供的所有预训练权重和mnist数据集。 备份地址:https://pan.baidu.com/s/1msi5qaE82nEbCha641lkPA?pwd=8888

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2022-08-20,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 前言
  • 更新概览
  • 分类模型效果
  • 工程结构
  • 数据集下载
  • 禁用wandb
  • 开始训练
  • 模型预测
  • 代码备份
相关产品与服务
日志服务
日志服务(Cloud Log Service,CLS)是腾讯云提供的一站式日志服务平台,提供了从日志采集、日志存储到日志检索,图表分析、监控告警、日志投递等多项服务,协助用户通过日志来解决业务运维、服务监控、日志审计等场景问题。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档