# 【Keras】基于SegNet和U-Net的遥感图像语义分割

## 数据集

1. 原图和label图都需要旋转：90度，180度，270度
2. 原图和label图都需要做沿y轴的镜像操作
3. 原图做模糊操作
4. 原图做光照调整操作
5. 原图做增加噪声操作（高斯噪声，椒盐噪声）

```img_w = 256  img_h = 256  image_sets = ['1.png','2.png','3.png','4.png','5.png']def gamma_transform(img, gamma):
gamma_table = [np.power(x / 255.0, gamma) * 255.0 for x in range(256)]
gamma_table = np.round(np.array(gamma_table)).astype(np.uint8)    return cv2.LUT(img, gamma_table)def random_gamma_transform(img, gamma_vari):
log_gamma_vari = np.log(gamma_vari)
alpha = np.random.uniform(-log_gamma_vari, log_gamma_vari)
gamma = np.exp(alpha)    return gamma_transform(img, gamma)    def rotate(xb,yb,angle):
M_rotate = cv2.getRotationMatrix2D((img_w/2, img_h/2), angle, 1)
xb = cv2.warpAffine(xb, M_rotate, (img_w, img_h))
yb = cv2.warpAffine(yb, M_rotate, (img_w, img_h))    return xb,yb    def blur(img):
img = cv2.blur(img, (3, 3));    return imgdef add_noise(img):
for i in range(200): #添加点噪声
temp_x = np.random.randint(0,img.shape[0])
temp_y = np.random.randint(0,img.shape[1])
img[temp_x][temp_y] = 255
return img
def data_augment(xb,yb):
if np.random.random() < 0.25:
xb,yb = rotate(xb,yb,90)    if np.random.random() < 0.25:
xb,yb = rotate(xb,yb,180)    if np.random.random() < 0.25:
xb,yb = rotate(xb,yb,270)    if np.random.random() < 0.25:
xb = cv2.flip(xb, 1)  # flipcode > 0：沿y轴翻转
yb = cv2.flip(yb, 1)
if np.random.random() < 0.25:
xb = random_gamma_transform(xb,1.0)
if np.random.random() < 0.25:
xb = blur(xb)
if np.random.random() < 0.2:
return xb,ybdef creat_dataset(image_num = 100000, mode = 'original'):
print('creating dataset...')
image_each = image_num / len(image_sets)
g_count = 0
for i in tqdm(range(len(image_sets))):
count = 0
src_img = cv2.imread('./data/src/' + image_sets[i])  # 3 channels
X_height,X_width,_ = src_img.shape        while count < image_each:
random_width = random.randint(0, X_width - img_w - 1)
random_height = random.randint(0, X_height - img_h - 1)
src_roi = src_img[random_height: random_height + img_h, random_width: random_width + img_w,:]
label_roi = label_img[random_height: random_height + img_h, random_width: random_width + img_w]            if mode == 'augment':
src_roi,label_roi = data_augment(src_roi,label_roi)

visualize = np.zeros((256,256)).astype(np.uint8)
visualize = label_roi *50

cv2.imwrite(('./aug/train/visualize/%d.png' % g_count),visualize)
cv2.imwrite(('./aug/train/src/%d.png' % g_count),src_roi)
cv2.imwrite(('./aug/train/label/%d.png' % g_count),label_roi)
count += 1
g_count += 1```

## 卷积神经网络

### SegNet

SegNet已经出来好几年了，它不是最新、效果最好的语义分割网络，但是它胜在网络结构清晰易懂，训练快速坑少，所以我们也采用它来做同样的任务。SegNet的网络结构是编码器-解码器结构，非常优雅。值得注意的是，SegNet做语义分割时通常在末端加入CRF模块做后处理，旨在进一步精修边缘的分割结果。有兴趣深究的读者可以参阅SegNet的原论文。

```python
def SegNet():
    # Builds, compiles and returns the SegNet model.
    # NOTE(review): this excerpt keeps only the per-stage size comments; the
    # layer definitions themselves are not present here.
    model = Sequential()
    #encoder
    #(128,128)
    #(64,64)
    #(32,32)
    #(16,16)
    #(8,8)
    #decoder
    #(16,16)
    #(32,32)
    #(64,64)
    #(128,128)
    #(256,256)
    # swap axis=1 and axis=2, equivalent to np.swapaxes(layer,1,2)
    model.compile(loss='categorical_crossentropy',optimizer='sgd',metrics=['accuracy'])
    model.summary()
    return model
```

```def get_train_val(val_rate = 0.25):
train_url = []
train_set = []
val_set  = []    for pic in os.listdir(filepath + 'src'):
train_url.append(pic)
random.shuffle(train_url)
total_num = len(train_url)
val_num = int(val_rate * total_num)    for i in range(len(train_url)):        if i < val_num:
val_set.append(train_url[i])
else:
train_set.append(train_url[i])    return train_set,val_set    # data for training  def generateData(batch_size,data=[]):
#print 'generateData...'
while True:
train_data = []
train_label = []
batch = 0
for i in (range(len(data))):
url = data[i]
batch += 1
#print (filepath + 'src/' + url)
#img = load_img(filepath + 'src/' + url, target_size=(img_w, img_h))
img = load_img(filepath + 'src/' + url)
img = img_to_array(img)
# print img
# print img.shape
train_data.append(img)
#label = load_img(filepath + 'label/' + url, target_size=(img_w, img_h),grayscale=True)
label = load_img(filepath + 'label/' + url, grayscale=True)
label = img_to_array(label).reshape((img_w * img_h,))
# print label.shape
train_label.append(label)
if batch % batch_size==0:
#print 'get enough bacth!\n'
train_data = np.array(train_data)
train_label = np.array(train_label).flatten()
train_label = labelencoder.transform(train_label)
train_label = to_categorical(train_label, num_classes=n_label)
train_label = train_label.reshape((batch_size,img_w * img_h,n_label))
yield (train_data,train_label)
train_data = []
train_label = []
batch = 0   # data for validation def generateValidData(batch_size,data=[]):
#print 'generateValidData...'
while True:
valid_data = []
valid_label = []
batch = 0
for i in (range(len(data))):
url = data[i]
batch += 1
#img = load_img(filepath + 'src/' + url, target_size=(img_w, img_h))
img = load_img(filepath + 'src/' + url)            #print img
#print (filepath + 'src/' + url)
img = img_to_array(img)
# print img.shape
valid_data.append(img)
#label = load_img(filepath + 'label/' + url, target_size=(img_w, img_h),grayscale=True)
label = load_img(filepath + 'label/' + url, grayscale=True)
label = img_to_array(label).reshape((img_w * img_h,))
# print label.shape
valid_label.append(label)
if batch % batch_size==0:
valid_data = np.array(valid_data)
valid_label = np.array(valid_label).flatten()
valid_label = labelencoder.transform(valid_label)
valid_label = to_categorical(valid_label, num_classes=n_label)
valid_label = valid_label.reshape((batch_size,img_w * img_h,n_label))
yield (valid_data,valid_label)
valid_data = []
valid_label = []
batch = 0  ```

```def train(args):
EPOCHS = 30
BS = 16
model = SegNet()
modelcheck = ModelCheckpoint(args['model'],monitor='val_acc',save_best_only=True,mode='max')
callable = [modelcheck]
train_set,val_set = get_train_val()
train_numb = len(train_set)
valid_numb = len(val_set)
print ("the number of train data is",train_numb)
print ("the number of val data is",valid_numb)
H = model.fit_generator(generator=generateData(BS,train_set),steps_per_epoch=train_numb//BS,epochs=EPOCHS,verbose=1,
validation_data=generateValidData(BS,val_set),validation_steps=valid_numb//BS,callbacks=callable,max_q_size=1)

# plot the training loss and accuracy
plt.style.use("ggplot")
plt.figure()
N = EPOCHS
plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, N), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, N), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy on SegNet Satellite Seg")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
plt.savefig(args["plot"])```

```def predict(args):
# load the trained convolutional neural network
stride = args['stride']    for n in range(len(TEST_SET)):
path = TEST_SET[n]        #load the image
image = cv2.imread('./test/' + path)        # pre-process the image for classification
#image = image.astype("float") / 255.0
#image = img_to_array(image)
h,w,_ = image.shape
padding_h = (h//stride + 1) * stride
padding_w = (w//stride + 1) * stride
_,ch,cw = crop.shape                if ch != 256 or cw != 256:                    print 'invalid size!'
continue

crop = np.expand_dims(crop, axis=0)                #print 'crop:',crop.shape
pred = model.predict_classes(crop,verbose=2)
pred = labelencoder.inverse_transform(pred[0])
#print (np.unique(pred))
pred = pred.reshape((256,256)).astype(np.uint8)                #print 'pred:',pred.shape

### U-Net

U-Net有很多优点，最大的卖点就是它在小数据集上也能train出一个好的模型，这个优点对于我们这个任务来说非常适合。而且，U-Net的训练速度也非常快，这对于需要在短时间内得出结果的期末project来说同样非常合适。U-Net的网络架构也非常优雅，整体呈U形，故起名U-Net。这里不打算详细介绍U-Net结构，有兴趣深究的读者可以看看论文。

```def unet():
inputs = Input((3, img_w, img_h))

conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(inputs)
conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv1)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(pool1)
conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv2)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(pool2)
conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv3)
pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(pool3)
conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv4)
pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(pool4)
conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(conv5)

up6 = concatenate([UpSampling2D(size=(2, 2))(conv5), conv4], axis=1)
conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(up6)
conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv6)

up7 = concatenate([UpSampling2D(size=(2, 2))(conv6), conv3], axis=1)
conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(up7)
conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv7)

up8 = concatenate([UpSampling2D(size=(2, 2))(conv7), conv2], axis=1)
conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(up8)
conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv8)

up9 = concatenate([UpSampling2D(size=(2, 2))(conv8), conv1], axis=1)
conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(up9)
conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv9)

conv10 = Conv2D(n_label, (1, 1), activation="sigmoid")(conv9)    #conv10 = Conv2D(n_label, (1, 1), activation="softmax")(conv9)

model = Model(inputs=inputs, outputs=conv10)

```# data for training  def generateData(batch_size,data=[]):
#print 'generateData...'
while True:
train_data = []
train_label = []
batch = 0
for i in (range(len(data))):
url = data[i]
batch += 1
img = load_img(filepath + 'src/' + url)
img = img_to_array(img)
train_data.append(img)
label = load_img(filepath + 'label/' + url, grayscale=True)
label = img_to_array(label)            #print label.shape
train_label.append(label)
if batch % batch_size==0:
#print 'get enough bacth!\n'
train_data = np.array(train_data)
train_label = np.array(train_label)

yield (train_data,train_label)
train_data = []
train_label = []
batch = 0   # data for validation def generateValidData(batch_size,data=[]):
#print 'generateValidData...'
while True:
valid_data = []
valid_label = []
batch = 0
for i in (range(len(data))):
url = data[i]
batch += 1
img = load_img(filepath + 'src/' + url)            #print img
img = img_to_array(img)
# print img.shape
valid_data.append(img)
label = load_img(filepath + 'label/' + url, grayscale=True)
valid_label.append(label)
if batch % batch_size==0:
valid_data = np.array(valid_data)
valid_label = np.array(valid_label)
yield (valid_data,valid_label)
valid_data = []
valid_label = []
batch = 0  ```

`python unet.py --model unet_buildings20.h5 --data ./unet_train/buildings/`

```python
def combind_all_mask():
    # NOTE(review): only a fragment of this function is present in the
    # excerpt. `img`, `idx` and `i` are defined by lines that are missing,
    # and the nesting of the omitted enclosing loops cannot be recovered.

    height,width = img.shape
    label_value = idx+1  # corresponding label value
    for j in range(width):
        if img[i,j] == 255:
            if label_value == 2:
                # NOTE(review): the body of this branch is not present in the excerpt.
```

## 模型融合

```python
import numpy as np
import cv2
import argparse

RESULT_PREFIXX = ['./result1/','./result2/','./result3/']

# each mask has 5 classes: 0~4
def vote_per_image(image_id):
    # Per-pixel vote over one result mask per entry of RESULT_PREFIXX.
    # NOTE(review): this excerpt is elided; `im` and `pixel` are defined by
    # lines that are missing, and nothing is written to `vote_mask` here.
    result_list = []
    for j in range(len(RESULT_PREFIXX)):
        # (omitted in this excerpt: the statement that defines `im`)
        result_list.append(im)
    # each pixel
    height,width = result_list[0].shape
    vote_mask = np.zeros((height,width))
    for h in range(height):
        for w in range(width):
            # One counter per class (5 classes) for the current pixel.
            record = np.zeros((1,5))
            for n in range(len(result_list)):
                # (omitted in this excerpt: the statement that defines `pixel`)
                record[0,pixel]+=1

            # The class with the most votes.
            label = record.argmax()
            #print(label)

vote_per_image(3)
```

## 总结

