# 【Keras】基于SegNet和U-Net的遥感图像语义分割

## 数据集

1. 原图和label图都需要旋转：90度，180度，270度
2. 原图和label图都需要做沿y轴的镜像操作
3. 原图做模糊操作
4. 原图做光照调整操作
5. 原图做增加噪声操作（高斯噪声，椒盐噪声）

# Patch size and source image list for dataset generation.
img_w = 256
img_h = 256
image_sets = ['1.png', '2.png', '3.png', '4.png', '5.png']


def gamma_transform(img, gamma):
    """Apply gamma correction to *img* through a 256-entry lookup table.

    img: uint8 image (any channel count accepted by cv2.LUT).
    gamma: exponent applied to the normalised intensity.
    """
    gamma_table = [np.power(x / 255.0, gamma) * 255.0 for x in range(256)]
    gamma_table = np.round(np.array(gamma_table)).astype(np.uint8)
    return cv2.LUT(img, gamma_table)


def random_gamma_transform(img, gamma_vari):
    """Gamma-correct *img* with gamma drawn log-uniformly from [1/gamma_vari, gamma_vari]."""
    log_gamma_vari = np.log(gamma_vari)
    alpha = np.random.uniform(-log_gamma_vari, log_gamma_vari)
    gamma = np.exp(alpha)
    return gamma_transform(img, gamma)


def rotate(xb, yb, angle):
    """Rotate image *xb* and its label mask *yb* by *angle* degrees about the patch centre.

    Both are warped with the same affine matrix so image and label stay aligned.
    """
    M_rotate = cv2.getRotationMatrix2D((img_w / 2, img_h / 2), angle, 1)
    xb = cv2.warpAffine(xb, M_rotate, (img_w, img_h))
    yb = cv2.warpAffine(yb, M_rotate, (img_w, img_h))
    return xb, yb


def blur(img):
    """Box-blur *img* with a 3x3 kernel."""
    img = cv2.blur(img, (3, 3))
    return img


def add_noise(img):
    """Add salt noise: set 200 randomly chosen pixels of *img* to 255, in place.

    Returns the same (mutated) array for call-chaining convenience.
    """
    for _ in range(200):
        temp_x = np.random.randint(0, img.shape[0])
        temp_y = np.random.randint(0, img.shape[1])
        img[temp_x][temp_y] = 255
    return img
def data_augment(xb, yb):
    """Randomly augment an image patch *xb* and its label patch *yb* in lock-step.

    Geometric transforms (rotations, horizontal flip) are applied to both the
    image and the label; photometric transforms (gamma, blur, noise) only to
    the image. Each transform fires independently with the given probability.
    """
    if np.random.random() < 0.25:
        xb, yb = rotate(xb, yb, 90)
    if np.random.random() < 0.25:
        xb, yb = rotate(xb, yb, 180)
    if np.random.random() < 0.25:
        xb, yb = rotate(xb, yb, 270)
    if np.random.random() < 0.25:
        xb = cv2.flip(xb, 1)  # flipcode > 0: mirror along the y axis
        yb = cv2.flip(yb, 1)
    if np.random.random() < 0.25:
        # NOTE(review): gamma_vari=1.0 gives log(1.0)=0, so the sampled gamma
        # is always exactly 1 and this branch is a no-op; a value > 1 (e.g.
        # 1.5) is probably intended — confirm against the original repo.
        xb = random_gamma_transform(xb, 1.0)
    if np.random.random() < 0.25:
        xb = blur(xb)
    if np.random.random() < 0.2:
        # NOTE(review): the noise application was lost in extraction; restored
        # here so the 0.2-probability branch actually does something.
        xb = add_noise(xb)
    # BUG FIX: the return must be unconditional — in the scraped version it sat
    # inside the last `if`, so the function returned None most of the time.
    return xb, yb


def creat_dataset(image_num=100000, mode='original'):
    """Cut `image_num` random 256x256 patches from the source images.

    image_num: total number of patches to generate (split evenly per image).
    mode: 'augment' applies data_augment to every patch; anything else keeps
          the raw crop.
    Writes src/label/visualize PNGs under ./aug/train/ with a global counter
    as the file name.
    """
    print('creating dataset...')
    image_each = image_num // len(image_sets)  # integer patch count per source image
    g_count = 0
    for i in tqdm(range(len(image_sets))):
        count = 0
        src_img = cv2.imread('./data/src/' + image_sets[i])  # 3 channels
        # NOTE(review): the label read was lost in extraction; restored from
        # the matching label directory — confirm the path and greyscale flag.
        label_img = cv2.imread('./data/label/' + image_sets[i], cv2.IMREAD_GRAYSCALE)
        X_height, X_width, _ = src_img.shape
        while count < image_each:
            random_width = random.randint(0, X_width - img_w - 1)
            random_height = random.randint(0, X_height - img_h - 1)
            src_roi = src_img[random_height: random_height + img_h,
                              random_width: random_width + img_w, :]
            label_roi = label_img[random_height: random_height + img_h,
                                  random_width: random_width + img_w]
            if mode == 'augment':
                src_roi, label_roi = data_augment(src_roi, label_roi)

            # Scale small class ids up so they are visible in a PNG preview.
            # (The original also created a zeroed array first — a dead store.)
            visualize = label_roi * 50

            cv2.imwrite('./aug/train/visualize/%d.png' % g_count, visualize)
            cv2.imwrite('./aug/train/src/%d.png' % g_count, src_roi)
            cv2.imwrite('./aug/train/label/%d.png' % g_count, label_roi)
            count += 1
            g_count += 1

## 卷积神经网络

### SegNet

SegNet已经出来好几年了，这不是一个最新、效果最好的语义分割网络，但是它胜在网络结构清晰易懂，训练快速坑少，所以我们也采取它来做同样的任务。SegNet网络结构是编码器-解码器的结构，非常优雅，值得注意的是，SegNet做语义分割时通常在末端加入CRF模块做后处理，旨在进一步精修边缘的分割结果。有兴趣深究的可以看看这里

```def SegNet():
# NOTE(review): the encoder/decoder layer definitions of this Sequential
# SegNet were lost when the article was scraped — only the feature-map-size
# marker comments below survive. Restore the Conv/BN/Pool/Upsample stacks
# from the original repository before using this function; as written it
# compiles an empty Sequential model.
model = Sequential()
#encoder
# Each marker below records the spatial size after a pooling step
# (input patches are 256x256).
#(128,128)
#(64,64)
#(32,32)
#(16,16)
#(8,8)
#decoder
# Mirror of the encoder: each marker records the size after an upsample step.
#(16,16)
#(32,32)
#(64,64)
#(128,128)
#(256,256)
# swapping axis=1 and axis=2, equivalent to np.swapaxes(layer, 1, 2)
model.compile(loss='categorical_crossentropy',optimizer='sgd',metrics=['accuracy'])
model.summary()
return model  ```

def get_train_val(val_rate=0.25):
    """Split the filenames under `filepath`/src into train and validation lists.

    val_rate: fraction of the shuffled files assigned to validation.
    Returns (train_set, val_set) as lists of bare filenames.
    Depends on module globals: filepath.
    """
    train_url = []
    train_set = []
    val_set = []
    for pic in os.listdir(filepath + 'src'):
        train_url.append(pic)
    random.shuffle(train_url)
    total_num = len(train_url)
    val_num = int(val_rate * total_num)
    for i in range(len(train_url)):
        if i < val_num:
            val_set.append(train_url[i])
        else:
            train_set.append(train_url[i])
    return train_set, val_set


# data for training
def generateData(batch_size, data=None):
    """Yield (images, one-hot labels) batches forever for SegNet training.

    data: list of filenames present in both `filepath`/src and `filepath`/label.
    Each label is flattened to (img_w*img_h,), mapped through the global
    `labelencoder`, one-hot encoded to `n_label` classes and reshaped to
    (batch_size, img_w*img_h, n_label) to match the SegNet output layout.
    Leftover samples that do not fill a final batch are carried into the next
    pass of the infinite loop.
    """
    if data is None:  # avoid the shared-mutable-default pitfall of `data=[]`
        data = []
    while True:
        train_data = []
        train_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            train_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label).reshape((img_w * img_h,))
            train_label.append(label)
            if batch % batch_size == 0:
                train_data = np.array(train_data)
                train_label = np.array(train_label).flatten()
                train_label = labelencoder.transform(train_label)
                train_label = to_categorical(train_label, num_classes=n_label)
                train_label = train_label.reshape((batch_size, img_w * img_h, n_label))
                yield (train_data, train_label)
                train_data = []
                train_label = []
                batch = 0


# data for validation
def generateValidData(batch_size, data=None):
    """Validation twin of generateData: identical pipeline and batch layout."""
    if data is None:
        data = []
    while True:
        valid_data = []
        valid_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            valid_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label).reshape((img_w * img_h,))
            valid_label.append(label)
            if batch % batch_size == 0:
                valid_data = np.array(valid_data)
                valid_label = np.array(valid_label).flatten()
                valid_label = labelencoder.transform(valid_label)
                valid_label = to_categorical(valid_label, num_classes=n_label)
                valid_label = valid_label.reshape((batch_size, img_w * img_h, n_label))
                yield (valid_data, valid_label)
                valid_data = []
                valid_label = []
                batch = 0

def train(args):
    """Train the SegNet model and save a loss/accuracy plot.

    args: dict-like with 'model' (checkpoint output path) and 'plot'
          (figure output path).
    Uses get_train_val / generateData / generateValidData defined above.
    """
    EPOCHS = 30
    BS = 16
    model = SegNet()
    # Keep only the best weights as measured by validation accuracy.
    modelcheck = ModelCheckpoint(args['model'], monitor='val_acc',
                                 save_best_only=True, mode='max')
    callbacks_list = [modelcheck]  # renamed: `callable` shadowed the builtin
    train_set, val_set = get_train_val()
    train_numb = len(train_set)
    valid_numb = len(val_set)
    print("the number of train data is", train_numb)
    print("the number of val data is", valid_numb)
    H = model.fit_generator(
        generator=generateData(BS, train_set),
        steps_per_epoch=train_numb // BS,
        epochs=EPOCHS,
        verbose=1,
        validation_data=generateValidData(BS, val_set),
        validation_steps=valid_numb // BS,
        callbacks=callbacks_list,
        # NOTE(review): max_q_size is the Keras 1 spelling; newer Keras
        # renamed it to max_queue_size — adjust to your installed version.
        max_q_size=1)

    # plot the training loss and accuracy
    plt.style.use("ggplot")
    plt.figure()
    N = EPOCHS
    plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")
    plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")
    plt.plot(np.arange(0, N), H.history["acc"], label="train_acc")
    plt.plot(np.arange(0, N), H.history["val_acc"], label="val_acc")
    plt.title("Training Loss and Accuracy on SegNet Satellite Seg")
    plt.xlabel("Epoch #")
    plt.ylabel("Loss/Accuracy")
    plt.legend(loc="lower left")
    plt.savefig(args["plot"])

```def predict(args):
# NOTE(review): this function is incomplete in the scrape — the model
# loading, the padded-canvas construction, and the sliding-window loop that
# defines `crop` (and brings `model` into scope) were lost in extraction.
# Restore them from the original repository; as written the block is not
# runnable. Note also the Python-2 `print` statement further down.
# load the trained convolutional neural network
stride = args['stride']    for n in range(len(TEST_SET)):
path = TEST_SET[n]        #load the image
image = cv2.imread('./test/' + path)        # pre-process the image for classification
#image = image.astype("float") / 255.0
#image = img_to_array(image)
h,w,_ = image.shape
# Round the canvas up to a whole number of stride steps so every window fits.
padding_h = (h//stride + 1) * stride
padding_w = (w//stride + 1) * stride
# Each window fed to the network must be exactly 256x256 (channels-first).
_,ch,cw = crop.shape                if ch != 256 or cw != 256:                    print 'invalid size!'
continue

crop = np.expand_dims(crop, axis=0)                #print 'crop:',crop.shape
pred = model.predict_classes(crop,verbose=2)
# Map predicted class indices back to the original label values.
pred = labelencoder.inverse_transform(pred[0])
#print (np.unique(pred))
pred = pred.reshape((256,256)).astype(np.uint8)                #print 'pred:',pred.shape

### U-Net

U-Net有很多优点，最大卖点就是它可以在小数据集上也能train出一个好的模型，这个优点对于我们这个任务来说真的非常适合。而且，U-Net在训练速度上也是非常快的，这对于需要短时间就得出结果的期末project来说也是非常合适。U-Net在网络架构上还是非常优雅的，整个呈现U形，故起名U-Net。这里不打算详细介绍U-Net结构，有兴趣的深究的可以看看论文。

def unet():
    """Build a channels-first 256x256 U-Net with `n_label` output maps.

    NOTE(review): Input((3, img_w, img_h)) together with
    concatenate(..., axis=1) assumes the Keras backend is configured for
    channels_first — confirm image_data_format before training. The
    model.compile(...) call appears to have been lost in extraction; compile
    the returned model before fitting.
    """
    inputs = Input((3, img_w, img_h))

    # Contracting path: two 3x3 convs per level, then 2x2 max-pool halves
    # the spatial resolution while the channel count doubles.
    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(inputs)
    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(pool1)
    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(pool2)
    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(pool3)
    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    # Bottleneck.
    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(pool4)
    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(conv5)

    # Expanding path: upsample, concatenate the matching encoder feature map
    # (skip connection) on the channel axis (axis=1, channels_first), then
    # two 3x3 convs.
    up6 = concatenate([UpSampling2D(size=(2, 2))(conv5), conv4], axis=1)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(up6)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv6)

    up7 = concatenate([UpSampling2D(size=(2, 2))(conv6), conv3], axis=1)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(up7)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv7)

    up8 = concatenate([UpSampling2D(size=(2, 2))(conv7), conv2], axis=1)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(up8)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv8)

    up9 = concatenate([UpSampling2D(size=(2, 2))(conv8), conv1], axis=1)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(up9)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv9)

    # Per-pixel classification head (sigmoid; the softmax variant was the
    # commented-out alternative in the original).
    conv10 = Conv2D(n_label, (1, 1), activation="sigmoid")(conv9)
    #conv10 = Conv2D(n_label, (1, 1), activation="softmax")(conv9)

    model = Model(inputs=inputs, outputs=conv10)
    # BUG FIX: the scraped block ended without returning, so unet() yielded None.
    return model

# data for training
def generateData(batch_size, data=None):
    """Yield (images, label-images) batches forever for U-Net training.

    Unlike the SegNet generator earlier in the article, labels are kept as
    2-D image arrays (no flattening / one-hot encoding) to match the U-Net's
    per-pixel sigmoid output.
    """
    if data is None:  # avoid the shared-mutable-default pitfall of `data=[]`
        data = []
    while True:
        train_data = []
        train_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            train_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label)
            train_label.append(label)
            if batch % batch_size == 0:
                train_data = np.array(train_data)
                train_label = np.array(train_label)
                yield (train_data, train_label)
                train_data = []
                train_label = []
                batch = 0


# data for validation
def generateValidData(batch_size, data=None):
    """Validation twin of the U-Net generateData: identical pipeline."""
    if data is None:
        data = []
    while True:
        valid_data = []
        valid_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            valid_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            # BUG FIX: the scraped version appended the raw PIL image here;
            # convert to an array as generateData does, otherwise
            # np.array(valid_label) builds an object array of PIL images.
            label = img_to_array(label)
            valid_label.append(label)
            if batch % batch_size == 0:
                valid_data = np.array(valid_data)
                valid_label = np.array(valid_label)
                yield (valid_data, valid_label)
                valid_data = []
                valid_label = []
                batch = 0

`python unet.py --model unet_buildings20.h5 --data ./unet_train/buildings/`

```def combind_all_mask():
# NOTE(review): only a fragment of this function survived extraction — the
# loops that define `img`, `idx` and `i`, and the per-class merge logic that
# combines the single-class masks, were lost. Restore the full body from the
# original repository before use; as written the block is not runnable.

height,width = img.shape
label_value = idx+1  # corresponding label value for this class mask
for j in range(width):                    if img[i,j] == 255:                        if label_value == 2:

## 模型融合

import numpy as np
import cv2
import argparse

# One result directory per model whose predictions are being fused.
RESULT_PREFIXX = ['./result1/', './result2/', './result3/']


# each mask has 5 classes: 0~4
def vote_per_image(image_id):
    """Majority-vote fuse the per-model masks for one test image.

    Reads `<prefix><image_id>.png` (greyscale class-id mask) from every
    directory in RESULT_PREFIXX, takes the per-pixel argmax of the class
    tallies, and writes the fused mask to `vote_mask<image_id>.png`.

    NOTE(review): the mask loading, the per-pixel tally index and the final
    write were partially lost in extraction and have been reconstructed from
    the surviving variable names — verify against the original repository.
    """
    result_list = []
    for j in range(len(RESULT_PREFIXX)):
        im = cv2.imread(RESULT_PREFIXX[j] + str(image_id) + '.png', 0)
        result_list.append(im)
    # each pixel
    height, width = result_list[0].shape
    vote_mask = np.zeros((height, width))
    for h in range(height):
        for w in range(width):
            record = np.zeros((1, 5))  # one tally slot per class 0..4
            for n in range(len(result_list)):
                pixel = result_list[n][h, w]
                record[0, pixel] += 1
            label = record.argmax()
            vote_mask[h, w] = label
    cv2.imwrite('vote_mask' + str(image_id) + '.png', vote_mask)


vote_per_image(3)

## 总结

0 条评论

• ### 用深度学习keras的cnn做图像识别分类，准确率达97%

Keras是一个简约，高度模块化的神经网络库。 可以很容易和快速实现原型（通过总模块化，极简主义，和可扩展性） 同时支持卷积网络（vision）和复发性的网络...

• ### 文本挖掘：手把手教你分析携程网评论数据

作者：飘雪 http://www.itongji.cn/cms/article/articledetails?articleid=1114 中文文本挖掘包tm...

• ### 数据挖掘算法（logistic回归，随机森林，GBDT和xgboost）

面网易数据挖掘工程师岗位，第一次面数据挖掘的岗位，只想着能够去多准备一些，体验面这个岗位的感觉，虽然最好心有不甘告终，不过继续加油。 不过总的来看，面试前有准备...

• ### 独家 | 10分钟搭建你的第一个图像识别模型（附步骤、代码）

本文介绍了图像识别的深度学习模型的建立过程，通过陈述实际比赛的问题、介绍模型框架和展示解决方案代码，为初学者提供了解决图像识别问题的基础框架。

• ### 10分钟搭建你的第一个图像识别模型（附步骤、代码）

导读：本文介绍了图像识别的深度学习模型的建立过程，通过陈述实际比赛的问题、介绍模型框架和展示解决方案代码，为初学者提供了解决图像识别问题的基础框架。

• ### 10分钟搭建你的第一个图像识别模型 | 附完整代码

【导读】本文介绍了图像识别的深度学习模型的建立过程，通过陈述实际比赛的问题、介绍模型框架和展示解决方案代码，为初学者提供了解决图像识别问题的基础框架。

• ### java中的运算 ^, << , >>,&

那么这个1是怎么来的,我们要知道^、<<、>>等位运算符主要针对二进制,算异或的时候相同的为0,不同的为1 2转换成二进制是0010 3转换成二进制是001...

• ### 简单易学的机器学习算法——协同过滤推荐算法(2)

一、基于协同过滤的推荐系统     协同过滤(Collaborative Filtering)的推荐系统的原理是通过将用户和其他用户的数据进行比对来实现推荐的。...

• ### php 图片转成base64 原

function base64EncodeImage (\$image_file) {

• ### 简单易学的机器学习算法——协同过滤推荐算法(2)

协同过滤(Collaborative Filtering)的推荐系统的原理是通过将用户和其他用户的数据进行比对来实现推荐的。比对的具体方法就是通过计算两...