有关导入数据集的问题将使我发疯。
这是我的segnet代码的一部分。
我将重点讨论有关图像和掩码数据导入的问题。
print("CNN Model created.")
###training data
data_gen_args = dict()
image_datagen = ImageDataGenerator(**data_gen_args)
mask_datagen = ImageDataGenerator(**data_gen_args)
seed1 = 1
image_datagen.fit(images, augment=True, seed=seed1)
mask_datagen.fit(masks, augment=True, seed=seed1)
train_image_generator = image_datagen.flow_from_directory(TRAIN_im,target_size=(500, 500),batch_size=BATCH_SIZE, class_mode = None)
train_mask_generator = mask_datagen.flow_from_directory(TRAIN_mask,target_size=(500, 500),batch_size=BATCH_SIZE, class_mode = None)
train_generator = zip(train_image_generator,train_mask_generator)
###validation data
valid_gen_args = dict()
val_image_datagen = ImageDataGenerator(**valid_gen_args)
val_mask_datagen = ImageDataGenerator(**valid_gen_args)
seed2 = 5
val_image_datagen.fit(val_images, augment=True, seed=seed2)
val_mask_datagen.fit(val_masks, augment=True, seed=seed2)
val_image_generator = val_image_datagen.flow_from_directory(VAL_im,target_size=(500, 500),batch_size=BATCH_SIZE, class_mode = None)
val_mask_generator = val_mask_datagen.flow_from_directory(VAL_mask,target_size=(500, 500),batch_size=BATCH_SIZE, class_mode = None)
val_generator = zip(val_image_generator,val_mask_generator)
###
model.fit_generator(
train_generator,steps_per_epoch=nb_train_samples//BATCH_SIZE,epochs=EPOCHS,validation_data=val_generator,validation_steps=nb_validation_samples//BATCH_SIZE)
我的问题是:
这是我的目录结构:
Dataset -training----------images----"many images"
| |
| |-----mask-----"ground truth images(mask)"
|
|
validation----------val_images----"many images"
| |
| |------val_mask------"ground truth images(mask)"
|
|
testing---------------test images (no ground truth)
非常感谢!
发布于 2017-12-21 14:04:33
我们开始吧。
images
和masks
是具有(num_imgs, width, height, num_channels)
形状的四维numpy阵列.这些变量从何而来?您必须在前面的步骤中从它们各自的图像文件中读取它们。flow_from_directory
是一个函数,可以与IDG一起使用,以便为您读取图像。非常方便。但是,如果您不需要featurewise_center
、featurewise_std_normalization
和zca_whitening
,那么我们只需要绕过它,因为在本例中,您需要已经可用的numpy数组来执行IDG fit()
函数。顺便说一句,这个fit函数与启动模型培训的fit()
函数无关。它只是使用相同的命名约定。https://stackoverflow.com/questions/47917729
复制相似问题