前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【YOLOv8】自定义姿态评估模型训练

【YOLOv8】自定义姿态评估模型训练

作者头像
OpenCV学堂
发布2023-12-26 16:31:57
5180
发布2023-12-26 16:31:57
举报

前言

Hello大家好,今天给大家分享一下如何基于YOLOv8姿态评估模型,实现在自定义数据集上,完成自定义姿态评估模型的训练与推理。

01

tiger-pose数据集

YOLOv8官方提供了一个自定义tiger-pose数据集(老虎姿态评估),总计数据有263张图像、其中210张作为训练集、53张作为验证集。

其中YOLOv8-pose的数据格式如下:

解释一下:

代码语言:javascript
复制
Class-index 表示对象类型索引,从0开始
后面的四个分别是对象的中心位置与宽高 xc、yc、width、height
px1,py1表示第一个关键点坐标、p1v表示师傅可见,默认填2即可。
kpt_shape=12x2 表示有12个关键点,每个关键点是x,y

02

模型训练

跟训练YOLOv8对象检测模型类似,直接运行下面的命令行即可:

代码语言:javascript
复制
yolo train model=yolov8n-pose.pt data=tiger_pose_dataset.yaml epochs=100 imgsz=640 batch=1

03

模型导出预测

训练完成以后模型预测推理测试 使用下面的命令行:

代码语言:javascript
复制
yolo predict model=tiger_pose_best.pt source=D:/123.jpg

导出模型为ONNX格式,使用下面命令行即可

代码语言:javascript
复制
yolo export model=tiger_pose_best.pt format=onnx

04

部署推理

基于ONNX格式模型,采用ONNXRUNTIME推理结果如下:

ORT相关的推理演示代码如下:

代码语言:javascript
复制
代码语言:javascript
复制
def ort_pose_demo():

    # initialize the onnxruntime session by loading model in CUDA support
    model_dir = "tiger_pose_best.onnx"
    session = onnxruntime.InferenceSession(model_dir, providers=['CUDAExecutionProvider'])

    # 就改这里, 把RTSP的地址配到这边就好啦,然后直接运行,其它任何地方都不准改!
    # 切记把 yolov8-pose.onnx文件放到跟这个python文件同一个文件夹中!
    frame = cv.imread("D:/123.jpg")
    bgr = format_yolov8(frame)
    fh, fw, fc = frame.shape

    start = time.time()
    image = cv.dnn.blobFromImage(bgr, 1 / 255.0, (640, 640), swapRB=True, crop=False)

    # onnxruntime inference
    ort_inputs = {session.get_inputs()[0].name: image}
    res = session.run(None, ort_inputs)[0]

    # matrix transpose from 1x8x8400 => 8400x8
    out_prob = np.squeeze(res, 0).T

    result_kypts, confidences, boxes = wrap_detection(bgr, out_prob)
    for (kpts, confidence, box) in zip(result_kypts, confidences, boxes):
        cv.rectangle(frame, box, (0, 0, 255), 2)
        cv.rectangle(frame, (box[0], box[1] - 20), (box[0] + box[2], box[1]), (0, 255, 255), -1)
        cv.putText(frame, ("%.2f" % confidence), (box[0], box[1] - 10), cv.FONT_HERSHEY_SIMPLEX, .5, (0, 0, 0))
        cv.circle(frame, (int(kpts[0]), int(kpts[1])), 3, (255, 0, 255), 4, 8, 0)
        cv.circle(frame, (int(kpts[2]), int(kpts[3])), 3, (255, 0, 255), 4, 8, 0)
        cv.circle(frame, (int(kpts[4]), int(kpts[5])), 3, (255, 0, 255), 4, 8, 0)
        cv.circle(frame, (int(kpts[6]), int(kpts[7])), 3, (255, 0, 255), 4, 8, 0)
        cv.circle(frame, (int(kpts[8]), int(kpts[9])), 3, (255, 0, 255), 4, 8, 0)
        cv.circle(frame, (int(kpts[10]), int(kpts[11])), 3, (255, 0, 255), 4, 8, 0)
        cv.circle(frame, (int(kpts[12]), int(kpts[13])), 3, (255, 0, 255), 4, 8, 0)
        cv.circle(frame, (int(kpts[14]), int(kpts[15])), 3, (255, 0, 255), 4, 8, 0)
        cv.circle(frame, (int(kpts[16]), int(kpts[17])), 3, (255, 0, 255), 4, 8, 0)
        cv.circle(frame, (int(kpts[18]), int(kpts[19])), 3, (255, 0, 255), 4, 8, 0)
        cv.circle(frame, (int(kpts[20]), int(kpts[21])), 3, (255, 0, 255), 4, 8, 0)
        cv.circle(frame, (int(kpts[22]), int(kpts[23])), 3, (255, 0, 255), 4, 8, 0)

    cv.imshow("Tiger Pose Demo - gloomyfish", frame)
    cv.waitKey(0)
    cv.destroyAllWindows()
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2023-12-22,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenCV学堂 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档