前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【猫狗数据集】pytorch训练猫狗数据集之创建数据集

【猫狗数据集】pytorch训练猫狗数据集之创建数据集

作者头像
西西嘛呦
发布2020-08-26 10:26:13
8700
发布2020-08-26 10:26:13
举报

数据集下载地址:

链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw 提取码:2xq4

猫狗数据集的分为训练集25000张,在训练集中猫和狗的图像是混在一起的,pytorch读取数据集有两种方式,第一种方式是将不同类别的图片放于其对应的类文件夹中,另一种是实现读取数据集类,该类继承torch.utils.Dataset,并重写__getitem__和__len__。

先将猫和狗从训练集中区分开来,分别放到dog和cat文件夹下:

代码语言:javascript
复制
import glob
import shutil
import os

#数据集目录
path = "./ml/dogs-vs-cats/train"
#训练集目录
train_path = path+'/train'
#测试集目录
test_path = path+'/test'

#将某类图片移动到该类的文件夹下
def img_to_file(path):
    print("=========开始移动图片============")
    #如果没有dog类和cat类文件夹,则新建
    if not os.path.exists(path+"/dog"):
            os.makedirs(path+"/dog")
    if not os.path.exists(path+"/cat"):
            os.makedirs(path+"/cat")
    print("共:{}张图片".format(len(glob.glob(path+"/*.jpg"))))
    #通过glob遍历到所有的.jpg文件
    for imgPath in glob.glob(path+"/*.jpg"):
        #print(imgPath)
        #使用/划分
        img=imgPath.strip("\n").replace("\\","/").split("/")
        #print(img)
        #将图片移动到指定的文件夹中
        if img[-1].split(".")[0] == "cat":
            shutil.move(imgPath,path+"/cat")
        if img[-1].split(".")[0] == "dog":
            shutil.move(imgPath,path+"/dog")
    print("=========移动图片完成============")    
img_to_file(train_path)
print("训练集猫共:{}张图片".format(len(glob.glob(train_path+"/cat/*.jpg"))))
print("训练集狗共:{}张图片".format(len(glob.glob(train_path+"/dog/*.jpg"))))

然后从dog中和cat中分别抽取1250张,共2500张图片作为测试集。

代码语言:javascript
复制
import random

def split_train_test(fileDir,tarDir):

        if not os.path.exists(tarDir):
            os.makedirs(tarDir)
        pathDir = os.listdir(fileDir)    #取图片的原始路径
        filenumber=len(pathDir)
        rate=0.1    #自定义抽取图片的比例,比方说100张抽10张,那就是0.1
        picknumber=int(filenumber*rate) #按照rate比例从文件夹中取一定数量图片
        sample = random.sample(pathDir, picknumber)  #随机选取picknumber数量的样本图片
        print("=========开始移动图片============")
        for name in sample:
                shutil.move(fileDir+name, tarDir+name)
        print("=========移动图片完成============")
split_train_test(train_path+'/dog/',test_path+'/dog/')  
split_train_test(train_path+'/cat/',test_path+'/cat/')  

最终,我们就有以下结构了:

其中train包含22500张图片,其中dog类和cat类各11250张。test包含2500张图片,其中dog类和cat类各1250张。

发现测试集还是有点少,那就再来一遍了。

最后,train包含20250张图片,其中dog类和cat类各10125张。test包含4750张图片,其中dog类和cat类各2375张。

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2020-03-04 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档