前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >完美解决keras 读取多个hdf5文件进行训练的问题

完美解决keras 读取多个hdf5文件进行训练的问题

作者头像
砸漏
发布2020-10-21 10:03:24
9640
发布2020-10-21 10:03:24
举报
文章被收录于专栏:恩蓝脚本

用keras进行大数据训练,为了加快训练,需要提前制作训练集。

由于HDF5的特性,所有数据需要一次性读入到内存中,才能保存。

为此,我采用分批次分为2个以上HDF5进行存储。

1、先读取每个标签下的图片,并设置标签

代码语言:python
复制
def load_dataset(path_name, data_path):
    """Step 1 (excerpt): read the images of every class folder and label them.

    Each sub-directory of ``path_name`` is treated as one class, and every
    ``.jpg`` inside it is loaded, resized to IMAGE_SIZE x IMAGE_SIZE and
    tagged with the running ``counter`` value. ``data_path`` and the
    train/valid accumulators are not used within this excerpt.
    """
    images = []
    labels = []
    train_images = []
    valid_images = []
    train_labels = []
    valid_labels = []
    counter = 0
    allpath = os.listdir(path_name)
    nb_classes = len(allpath)
    print("label_num: ", nb_classes)

    for child_dir in allpath:
        child_path = os.path.join(path_name, child_dir)
        for dir_image in os.listdir(child_path):
            if dir_image.endswith('.jpg'):
                img = cv2.imread(os.path.join(child_path, dir_image))
                if img is None:
                    # cv2.imread signals an unreadable file by returning
                    # None instead of raising, so skip it explicitly.
                    continue
                # scipy.misc.imresize is no longer shipped with SciPy;
                # cv2.resize with INTER_LINEAR is the bilinear equivalent.
                image = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE),
                                   interpolation=cv2.INTER_LINEAR)
                images.append(image)
                labels.append(counter)

2、将该标签下的数据集分割为训练集(train images)、验证集(val images)、训练标签(train labels)、验证标签(val labels)

代码语言:python
复制
def split_dataset(images, labels):
    """Randomly hold out 20% of the samples for validation.

    The seed handed to ``train_test_split`` is drawn with
    ``random.randint(0, 100)`` on every call.

    Returns:
        The tuple ``(train_images, valid_images, train_labels, valid_labels)``.
    """
    seed = random.randint(0, 100)
    parts = train_test_split(images, labels, test_size=0.2, random_state=seed)
    train_images, valid_images, train_labels, valid_labels = parts
    return train_images, valid_images, train_labels, valid_labels

3、分割后的数据分别添加到总的训练集,验证集,训练标签,验证标签。

其次,清空当前标签的图片集和标签集,目的是节省内存。假如一次性读入多个标签的图片集与标签集后再进行数据分割,内存占用会达到按标签逐个处理时的两倍以上。

代码语言:python
复制
# Split the class that was just read and fold both halves into the totals.
images = np.array(images)
t_images, v_images, t_labels, v_labels = split_dataset(images, labels)
train_images.extend(t_images)
train_labels.extend(t_labels)
valid_images.extend(v_images)
valid_labels.extend(v_labels)

# Progress message after every 50th class.
if (counter + 1) % 50 == 0:
    print(counter + 1, "is read to the memory!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

# Drop the per-class buffers before the next class is read, then advance
# the class label.
images, labels = [], []
counter += 1

print("train_images num: ", len(train_images), " ", "valid_images num: ", len(valid_images))

4、进行判断,直到读到自己设定的分割点所对应的那个标签。

开始进行写入。写入之前,为了更好地训练模型,需要把对应的图片集和标签打乱顺序。

代码语言:python
复制
# Flush after every 4316 classes, and once more at the very last class.
if (counter % 4316 == 4315) or (counter == nb_classes - 1):

    print("start write images and labels data...................................................................")
    # Dividing by the same 4316 used in the flush test gives successive
    # flushes the indices 0, 1, 2, ... so no output folder is reused.
    num = counter // 4316
    dirs = data_path + "/" + "h5_" + str(num)
    if not os.path.exists(dirs):
        os.makedirs(dirs)
    # Write everything accumulated since the previous flush, not only the
    # split of the class that was read last.
    data2h5(dirs, train_images, valid_images, train_labels, valid_labels)

对应打乱顺序并写入到HDF5

代码语言:python
复制
def data2h5(dirs_path, train_images, valid_images, train_labels, valid_labels):
    """Shuffle the train/val sets in place and write them as HDF5 files.

    ``train.hdf5`` (datasets ``x_train`` / ``y_train``) and ``val.hdf5``
    (datasets ``x_val`` / ``y_val``) are created inside ``dirs_path``,
    replacing any existing files of the same name.

    Args:
        dirs_path: Existing directory that receives the two files.
        train_images, train_labels: Training samples and their labels.
        valid_images, valid_labels: Validation samples and their labels.

    Raises:
        ValueError: If an image set and its label set differ in length,
            because the paired shuffle would then mismatch them.
    """
    datasets = [
        ("train", train_images, train_labels),
        ("val", valid_images, valid_labels),
    ]

    # Validate everything first so nothing is shuffled or written on bad input.
    for dType, images, labels in datasets:
        if len(images) != len(labels):
            raise ValueError(
                "%s set has %d images but %d labels"
                % (dType, len(images), len(labels)))

    for dType, images, labels in datasets:
        # Replay the same RNG state for both shuffles so images and labels
        # are reordered together.
        state = np.random.get_state()
        np.random.shuffle(images)
        np.random.set_state(state)
        np.random.shuffle(labels)

        outputPath = os.path.join(dirs_path, dType + ".hdf5")
        # The context manager closes the file even if a write fails.
        with h5py.File(outputPath, "w") as f:
            f.create_dataset("x_" + dType, data=images)
            f.create_dataset("y_" + dType, data=labels)

5、检查文件是否已全部写入(读取各HDF5文件并打印其中的样本数)

代码语言:python
复制
def read_dataset(dirs):
    """Print the number of samples stored in each HDF5 file under ``dirs``.

    Every entry of ``dirs`` is opened as an HDF5 file. The part of the file
    name before the first dot (e.g. ``train``) selects the datasets
    ``x_<name>`` and ``y_<name>``, whose first-dimension sizes are printed.
    """
    files = os.listdir(dirs)
    print(files)
    for file in files:
        path = os.path.join(dirs, file)
        prefix = file.split('.')[0]
        # Open read-only and close right after the sizes are read.
        with h5py.File(path, "r") as dataset:
            set_x_orig = dataset["x_" + prefix].shape[0]
            set_y_orig = dataset["y_" + prefix].shape[0]

        print(set_x_orig)
        print(set_y_orig)

6、训练中,采用迭代器读入数据

代码语言:python
复制
 def generator(self, datagen, mode):
     """Endlessly yield batches read from the HDF5 folders under ``self.data_path``.

     Every sub-directory of ``self.data_path`` is expected to contain
     ``train.hdf5`` (``x_train`` / ``y_train``) and ``val.hdf5``
     (``x_val`` / ``y_val``). Only the file matching ``mode`` is opened.

     Args:
         datagen: When truthy, each batch is passed through random
             augmentation (rotation, shifts, horizontal flip).
         mode: ``"train"`` or ``"val"``.

     Yields:
         ``({'input_1': images}, {'softmax': labels})`` where ``images`` is
         float32 scaled to [0, 1] and ``labels`` is one-hot encoded when
         ``self.binarize`` is set. Batches hold up to ``self.BS`` samples.

     Raises:
         ValueError: If ``mode`` is neither ``"train"`` nor ``"val"``.
     """
     # An unknown mode would otherwise loop forever without yielding.
     if mode not in ("train", "val"):
         raise ValueError("mode must be 'train' or 'val', got %r" % (mode,))

     passes = np.inf
     aug = ImageDataGenerator(
         featurewise_center=False,
         samplewise_center=False,
         featurewise_std_normalization=False,
         samplewise_std_normalization=False,
         zca_whitening=False,
         rotation_range=20,
         width_shift_range=0.2,
         height_shift_range=0.2,
         horizontal_flip=True,
         vertical_flip=False)

     x_key = "x_" + mode
     y_key = "y_" + mode
     epochs = 0
     # Loops forever by default (passes is infinite).
     while epochs < passes:
         # Visit every HDF5 folder once per pass.
         for file in os.listdir(self.data_path):
             hdf5_path = os.path.join(self.data_path, file, mode + ".hdf5")

             # Read-only, and closed when the folder is exhausted or the
             # generator is closed, so handles do not pile up across passes.
             with h5py.File(hdf5_path, "r") as db:
                 numImages = db[y_key].shape[0]

                 for i in np.arange(0, numImages, self.BS):
                     images = db[x_key][i: i + self.BS]
                     labels = db[y_key][i: i + self.BS]

                     images = images.reshape(
                         images.shape[0], IMAGE_SIZE, IMAGE_SIZE, 3)
                     if K.image_data_format() == 'channels_first':
                         # Assumes the files store (N, H, W, 3), as produced
                         # by cv2-loaded images. Moving the channel axis
                         # needs a transpose; a plain reshape to
                         # (N, 3, H, W) would scramble the pixels.
                         images = images.transpose(0, 3, 1, 2)

                     images = images.astype('float32')
                     images = images / 255

                     if datagen:
                         (images, labels) = next(
                             aug.flow(images, labels, batch_size=self.BS))

                     # One-hot encode the labels.
                     if self.binarize:
                         labels = np_utils.to_categorical(labels, self.classes)

                     yield ({'input_1': images}, {'softmax': labels})

         epochs += 1

7、至此,就大功告成了

完整的代码:

代码语言:python
复制
# -*- coding: utf-8 -*-
"""
Created on Mon Feb 12 20:46:12 2018
@author: william_yue
"""
import os
import numpy as np
import cv2
import random
from scipy import misc
import h5py
from sklearn.model_selection import train_test_split
from keras import backend as K
K.clear_session()
from keras.utils import np_utils
IMAGE_SIZE = 128
def split_dataset(images, labels):
    """Split one class's samples into a training part and a validation part.

    Uses scikit-learn's ``train_test_split`` with 20% of the data held out
    for validation and 80% kept for training. The seed is drawn with
    ``random.randint(0, 100)`` on every call.

    Returns:
        The tuple ``(train_images, valid_images, train_labels, valid_labels)``.
    """
    train_images, valid_images, train_labels, valid_labels = train_test_split(
        images, labels, test_size=0.2, random_state=random.randint(0, 100))
    return train_images, valid_images, train_labels, valid_labels
def data2h5(dirs_path, train_images, valid_images, train_labels, valid_labels):
    """Shuffle the train/val sets in place and write them as HDF5 files.

    ``train.hdf5`` (datasets ``x_train`` / ``y_train``) and ``val.hdf5``
    (datasets ``x_val`` / ``y_val``) are created inside ``dirs_path``,
    replacing any existing files of the same name.

    Args:
        dirs_path: Existing directory that receives the two files.
        train_images, train_labels: Training samples and their labels.
        valid_images, valid_labels: Validation samples and their labels.

    Raises:
        ValueError: If an image set and its label set differ in length,
            because the paired shuffle would then mismatch them.
    """
    datasets = [
        ("train", train_images, train_labels),
        ("val", valid_images, valid_labels),
    ]

    # Validate everything first so nothing is shuffled or written on bad input.
    for dType, images, labels in datasets:
        if len(images) != len(labels):
            raise ValueError(
                "%s set has %d images but %d labels"
                % (dType, len(images), len(labels)))

    for dType, images, labels in datasets:
        # Shuffle images and labels with the same RNG state so that each
        # label stays attached to its image.
        state = np.random.get_state()
        np.random.shuffle(images)
        np.random.set_state(state)
        np.random.shuffle(labels)

        outputPath = os.path.join(dirs_path, dType + ".hdf5")
        # The context manager closes the file even if a write fails.
        with h5py.File(outputPath, "w") as f:
            f.create_dataset("x_" + dType, data=images)
            f.create_dataset("y_" + dType, data=labels)
def read_dataset(dirs):
    """Print the sample counts of every HDF5 file one level below ``dirs``.

    Each entry of ``dirs`` is treated as a folder. Each file inside it is
    opened as HDF5, and the part of its name before the first dot
    (e.g. ``train``) selects the datasets ``x_<name>`` and ``y_<name>``,
    whose first-dimension sizes are printed.
    """
    files = os.listdir(dirs)
    print(files)
    for file in files:
        path = os.path.join(dirs, file)
        for name in os.listdir(path):
            prefix = name.split('.')[0]
            # Open read-only and close right after the sizes are read.
            with h5py.File(os.path.join(path, name), "r") as dataset:
                set_x_orig = dataset["x_" + prefix].shape[0]
                set_y_orig = dataset["y_" + prefix].shape[0]
            print(set_x_orig)
            print(set_y_orig)
def _read_class_images(class_dir):
    """Return every readable ``.jpg`` in ``class_dir`` resized to IMAGE_SIZE x IMAGE_SIZE."""
    resized = []
    for file_name in os.listdir(class_dir):
        if not file_name.endswith('.jpg'):
            continue
        img = cv2.imread(os.path.join(class_dir, file_name))
        if img is None:
            # cv2.imread signals an unreadable file by returning None
            # instead of raising, so skip it explicitly.
            continue
        # scipy.misc.imresize is no longer shipped with SciPy;
        # cv2.resize with INTER_LINEAR is the bilinear equivalent.
        resized.append(cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE),
                                  interpolation=cv2.INTER_LINEAR))
    return resized


def load_dataset(path_name, data_path, classes_per_file=4316):
    """Read all images class by class and store them in several HDF5 folders.

    Each sub-directory of ``path_name`` is one class; its label is its
    position in the ``os.listdir`` result. Every class is split into
    train/valid parts as soon as it is read, and the accumulated parts are
    written to ``<data_path>/h5_<k>`` after every ``classes_per_file``
    classes and once more after the last class. Finally the sample counts
    of all written files are printed.

    Args:
        path_name: Root folder holding one sub-directory per class.
        data_path: Folder that receives the ``h5_<k>`` output folders.
        classes_per_file: Number of classes stored per HDF5 folder.

    Raises:
        ValueError: If ``classes_per_file`` is smaller than 1.
    """
    if classes_per_file < 1:
        raise ValueError("classes_per_file must be at least 1")

    train_images = []
    valid_images = []
    train_labels = []
    valid_labels = []
    allpath = os.listdir(path_name)
    nb_classes = len(allpath)
    print("label_num: ", nb_classes)

    for counter, child_dir in enumerate(allpath):
        images = _read_class_images(os.path.join(path_name, child_dir))
        labels = [counter] * len(images)

        # Splitting per class keeps peak memory far below splitting the
        # whole data set at once.
        t_images, v_images, t_labels, v_labels = split_dataset(
            np.array(images), labels)
        train_images.extend(t_images)
        train_labels.extend(t_labels)
        valid_images.extend(v_images)
        valid_labels.extend(v_labels)

        if counter % 50 == 49:
            print(counter + 1, "is read to the memory!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

        is_last_class = counter == nb_classes - 1
        if (counter % classes_per_file == classes_per_file - 1) or is_last_class:
            print("train_images num: ", len(train_images), "  ", "valid_images num: ", len(valid_images))
            print("start write images and labels data...................................................................")
            # Dividing by the flush interval itself gives successive
            # flushes the indices 0, 1, 2, ... so no folder is overwritten.
            num = counter // classes_per_file
            dirs = os.path.join(data_path, "h5_" + str(num))
            os.makedirs(dirs, exist_ok=True)
            data2h5(dirs, train_images, valid_images, train_labels, valid_labels)
            print("File HDF5_%d " % num, " id done!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
            # Start a fresh batch so already-written samples are released.
            train_images = []
            valid_images = []
            train_labels = []
            valid_labels = []

    print("All File HDF5 done!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    read_dataset(data_path)
def read_name_list(path_name):
    """Return the names of the entries (class folders) directly under ``path_name``."""
    return os.listdir(path_name)
if __name__ == '__main__':
    # Source images: one sub-folder per class under "data".
    path = "data"
    # Destination for the generated HDF5 folders.
    data_path = "data_hdf5_half"
    os.makedirs(data_path, exist_ok=True)
    load_dataset(path, data_path)

以上这篇完美解决keras 读取多个hdf5文件进行训练的问题就是小编分享给大家的全部内容了,希望能给大家一个参考。

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2020-09-11 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档