专栏首页大龄程序员的人工智能之路深度学习中超大规模数据集的处理

深度学习中超大规模数据集的处理

在机器学习项目中,如果使用的是比较小的数据集,数据集的处理上可以非常简单:加载每个单独的图像,对其进行预处理,然后输送给神经网络。但是,对于大规模数据集(例如ImageNet),我们需要创建一次只访问一部分数据集的数据生成器(比如mini batch),然后将小批量数据传递给网络。其实,这种方法在我们之前的示例中也有所涉及,在使用数据增强技术提升模型泛化能力一文中,我就介绍了通过数据增强技术批量扩充数据集,虽然那里并没有使用到超大规模的数据集。Keras提供的方法允许使用磁盘上的原始文件路径作为训练输入,而不必将整个数据集存储在内存中。

然而,这种方法的缺点也是很明显,非常低效。加载磁盘上的每个图像都需要I/O操作,学过计算机的同学都知道,I/O操作最耗时,这无疑会在整个训练管道中引入延迟。本来训练深度学习网络就够慢的,I/O瓶颈应尽可能避免。

HDF5

这个时候,该HDF5文件登场了。HDF是用于存储和分发科学数据的一种自我描述、多对象文件格式。HDF最早由美国国家超级计算应用中心NCSA开发,目前在非盈利组织HDF小组维护下继续发展。当前流行的版本是HDF5。HDF5拥有一系列的优异特性,使其特别适合进行大量科学数据的存储和操作,如它支持非常多的数据类型,灵活、通用、跨平台、可扩展、高效的I/O性能,支持几乎无限量(高达EB)的单文件存储等,详见其官方介绍:https://support.hdfgroup.org/HDF5/ 。

HDF5文件格式为何如此牛X?估计你也和我一样有强烈的好奇心。但是当我看到长达200页的spec,还是决定放弃深究其细节,毕竟我们需要聚焦到深度学习上。再说,python提供了hdf5库,让读写hdf5文件简单得如同读写普通文本文件。借助h5py模块,实现一个HDF5数据集读写类非常容易:

class HDF5DatasetWriter:
 def __init__(self, dims, output_path, data_key="images", buf_size=1000):
   # check to see if the output path exists, and if so, raise an exception
   if os.path.exists(output_path):
     raise ValueError("the supplied `output_path` already exists and cannot be overwritten.", output_path)   # open the HDF5 database for writing and create two datasets: one to store images/features
   # and another to store the class labels
   self.db = h5py.File(output_path, "w")
   self.data = self.db.create_dataset(data_key, dims, dtype="float")
   self.labels = self.db.create_dataset("labels", (dims[0],), dtype="int")   self.buf_size = buf_size
   self.buffer = {"data": [], "labels": []}
   self.idx = 0 def add(self, rows, labels):
   self.buffer["data"].extend(rows)
   self.buffer["labels"].extend(labels)   if len(self.buffer["data"]) >= self.buf_size:
     self.flush() def flush(self):
   i = self.idx + len(self.buffer["data"])
   self.data[self.idx:i] = self.buffer["data"]
   self.labels[self.idx:i] = self.buffer["labels"]
   self.idx = i
   self.buffer = {"data": [], "labels": []} def store_class_labels(self, class_labels):
   dt = h5py.special_dtype(vlen=str)
   labelset = self.db.create_dataset("label_names", (len(class_labels),), dtype=dt)
   labelset[:] = class_labels def close(self):
   if len(self.buffer["data"]) > 0:
     self.flush()   self.db.close()

其中主要用到的方法就是h5py.File和create_dataset,前一个方法生成HDF5文件,后一个方法创建数据集。

猫狗数据集

理论掌握再多,还是不如实例来得直接。对于个人开发者而言,收集超大规模数据集几乎是一个不可能完成的任务,幸运的是,由于互联网的开放性以及机器学习领域的共享精神,很多研究机构提供数据集公开下载。我们这里选用kaggle大赛使用的Kaggle: Dogs vs. Cats dataset。你可以前往 http://pyimg.co/xb5lb 下载,也可以在公众号平台对话框中回复”数据集“关键字,获取百度网盘下载链接。

请下载kaggle - dogs vs cats下的train.zip文件。下载train.zip文件后,解开压缩文件,你可以看到train目录下包含猫狗图片文件,从文件名可以推断出其所属的类别:

kaggle_dogs_vs_cats/train/cat.11866.jpg
...
kaggle_dogs_vs_cats/train/dog.11046.jpg

构建数据集

由于Kaggle: Dogs vs. Cats dataset的类别包含在文件名中间,我们很容易写出如下代码提取类别标签:

train_paths = list(paths.list_images(config.IMAGES_PATH))
train_labels = [p.split(os.path.sep)[-1].split(".")[0] for p in train_paths]

接下来划分数据集,学过吴恩达《机器学习》课程的同学可能知道,通常我们将数据集划分为 训练集、验证集和测试集 ,通常比例为6:2:2,但是对于大规模数据集来说,验证集和测试集分配20%,数量太大,也没有必要,这时通常给一个两千左右的固定值即可。

split = train_test_split(train_paths, train_labels, test_size=config.NUM_TEST_IMAGES, stratify=train_labels,
                        random_state=42)
(train_paths, test_paths, train_labels, test_labels) = splitsplit = train_test_split(train_paths, train_labels, test_size=config.NUM_VAL_IMAGES, stratify=train_labels,
                        random_state=42)
(train_paths, val_paths, train_labels, val_labels) = split

接下来就是遍历图片文件,并分别为训练集、验证集和测试集生成HDF5文件。

datasets = [
 ("train", train_paths, train_labels, config.TRAIN_HDF5),
 ("val", val_paths, val_labels, config.VAL_HDF5),
 ("test", test_paths, test_labels, config.TEST_HDF5)
]aap = AspectAwarePreprocessor(256, 256)
(R, G, B) = ([], [], [])for (dtype, paths, labels, output_path) in datasets:
 writer = HDF5DatasetWriter((len(paths), 256, 256, 3), output_path) # loop over the image paths
 for (i, (path, label)) in enumerate(zip(paths, labels)):
   image = cv2.imread(path)
   image = aap.preprocess(image)   if dtype == "train":
     (b, g, r) = cv2.mean(image)[:3]
     R.append(r)
     G.append(g)
     B.append(b)   writer.add([image], [label]) writer.close()

注意到,代码中累计了RGB均值,可以使用以下代码计算RGB均值:

D = {"R": np.mean(R), "G": np.mean(G), "B": np.mean(B)}
f = open(config.DATASET_MEAN, "w")
f.write(json.dumps(D))
f.close()

为啥需要RGB均值呢?这就涉及到深度学习中的一个正则化技巧,在我们之前的代码中,都是RGB值除以255.0进行正则化,但实践表明,将RGB值减去均值,效果更好,所以在此计算RGB的均值。需要注意的是,正则化只针对训练数据集,目的是让训练出的模型具有更强的泛化能力。

构建数据集用时最长的是训练数据集,用时大约两分半,而验证集和测试集则比较快,大约20秒。这额外的3分钟时间是否值得花,在后面的文章中,我们将继续分析。

让我们看看最后生成的HDF5文件:

-rw-rw-r-- 1 alex alex  3932182048 Feb 18 11:33 test.hdf5
-rw-rw-r-- 1 alex alex 31457442048 Feb 18 11:31 train.hdf5
-rw-rw-r-- 1 alex alex  3932182048 Feb 18 11:32 val.hdf5

是的,你没看错,train.hdf5高达30G,害得我不得不删掉硬盘上许多文件,才腾出这么多空间。

为什么这样,要知道原始的图像包train.zip文件才500多M?这是因为,JPEG和PNG等图像文件格式使用了数据压缩算法,以保持较小的图像文件大小。但是,在我们的处理中,将图像存储为原始NumPy阵列(即位图)。虽然这样大大增加了存储成本,但也有助于加快训练时间,因为不必浪费处理器时间解码图像。

在下一篇文章中,我将演示如何读取HDF5文件,进行猫狗识别模型训练。

以上实例均有完整的代码,点击阅读原文,跳转到我在github上建的示例代码。

另外,我在阅读《Deep Learning for Computer Vision with Python》这本书,在微信公众号后台回复“计算机视觉”关键字,可以免费下载这本书的电子版

本文分享自微信公众号 - 云水木石(ourpoeticlife),作者:陈正勇

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-02-20

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • YOLO v3有哪些新特点?

    You only look once(你只需看一次),或者YOLO,是目前比较快的目标对象检测算法之一。虽然它不再是最精确的目标对象检测算法,但是当您需要实时检...

    云水木石
  • [译]GPU加持,TensorFlow Lite更快了

    由于处理器性能和电池容量有限,在移动设备上使用计算密集的机器学习模型进行推断是非常耗资源的。 虽然可以采用一种加速途径:转换为定点数模型,但用户已经要求作为一种...

    云水木石
  • 我的第一个caffe C++程序

    最近一段时间一直在考虑为浏览器添加AI过滤裸露图片的功能,但目前大多数AI相关的教程都是用python语言。如果是训练模型,使用python语言无疑是最合适的,...

    云水木石
  • 关于首页倒计时处理一些细节

    促销商品展示的 Cell 是重用的,开始的时候其他栏目是没有赋值的。导致是不能收到已经停止的消息的,自然也就没办法从列表里面进行移除

    君赏
  • 模型之母:简单线性回归的代码实现

    关于作者:饼干同学,某人工智能公司交付开发工程师/建模科学家。专注于AI工程化及场景落地,希望和大家分享成长中的专业知识与思考感悟。

    木东居士
  • 在 PyQt4 中的菜单和工具栏¶

    QtGui.QMainWindow 类提供了一个应用的主窗口。这使得我们可以创建典型的应用框架,包括状态栏,工具栏和菜单。

    bear_fish
  • 在 PyQt4 中的菜单和工具栏¶

    http://www.cppblog.com/mirguest/archive/2012/02/05/164982.html

    bear_fish
  • 如何用Python3实现12306火车票自动抢票,小白必学

    最近在学Python,所以用Python写了这个12306抢票脚本,分享出来,与大家共同交流和学习,有不对的地方,请大家多多指正。话不多说,进入正题:在进入正题...

    用户7286429
  • PySpark工作原理

    Spark是一个开源的通用分布式计算框架,支持海量离线数据处理、实时计算、机器学习、图计算,结合大数据场景,在各个领域都有广泛的应用。Spark支持多种开发语言...

    Fayson
  • SQL学习笔记之简易ORM

    1 、我在实例化一个user对象的时候,可以user=User(name='lqz',password='123')

    Jetpropelledsnake21

扫码关注云+社区

领取腾讯云代金券