前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >卫星图像中的船舶检测

卫星图像中的船舶检测

作者头像
代码医生工作室
发布2019-06-21 22:51:14
1.7K1
发布2019-06-21 22:51:14
举报
文章被收录于专栏:相约机器人相约机器人

作者 | Daniel Moraite

来源 | Towards Data Science

编辑 | 代码医生团队

卫星图像是数据科学家可以使用的最丰富的数据源之一。这是选择首先考虑的部分,因为它减少了收集数据的工作,甚至减少了个人项目的附属研究。它也有一个缺点:个人计算机存储大小和计算能力有限。需要查找AWS Amazon Web Services以弥补它。

与此同时发现了一个非常小的数据集:行星卫星图像,可以在个人计算机上运行它。

关于数据:

  • 包括4000个80x80 RGB图像,标记为“ship”或“no-ship”分类,值为1或0。
  • 图像被正射校正为3米像素尺寸
  • 数据集为.png图像,图像文件名遵循特定格式:{label} __ {scene id} __ {longitude} _ {latitude} .png
  • longitude_latitude:图像中心点的经度和纬度坐标
  • dataset也作为JSON格式的文本文件分发,包含:data,label,scene_ids和location list
  • 单个图像的像素值数据存储为19200个整数的列表:第一个6400包含红色通道,下一个6400包含绿色,最后6400包含蓝色。
  • 标签,scene_ids和位置中的索引i处的列表值每个对应于数据列表中的第i个图像
  • 类标签:“船”类包括1000个图像,靠近单个船体的中心。“无船”类包括3000幅图像,1/3是不同土地覆盖特征的随机抽样。 - 不包括船舶的任何部分。下一个1/3是“部分船只”,而1/3是先前被机器学习模型错误标记的图像(由于强大的线性特征)。

想要实现的目标:检测卫星图像中船舶的位置,可用于解决以下问题:监控港口活动和供应链分析。

第一部分

阅读和准备数据

确保导入需要的所有库和模块,除了常规的Keras:顺序,密集,扁平,激活和丢失也将使用Conv2D和MaxPooling2D(参见完整的笔记本文章末尾)。现在下载并研究数据集:

代码语言:javascript
复制
f = open(r'../ships-in-satellite-imagery/shipsnet.json')

  dataset = json.load(f)

  f.close()

  input_data = np.array(dataset['data']).astype('uint8')

  output_data = np.array(dataset['labels']).astype('uint8')

  input_data.shape

  (4000, 19200)

  # and since I was currios to see how the tupple of arrays of arrays look like:

  input_data

  array([[ 82,  89,  91, ...,  86,  88,  89],

         [ 76,  75,  67, ...,  54,  57,  58],

         [125, 127, 129, ..., 111, 109, 115],

         ...,

         [171, 135, 118, ...,  95,  95,  85],

         [ 85,  90,  94, ...,  96,  95,  89],

         [122, 122, 126, ...,  51,  46,  69]], dtype=uint8)

  # now we realize that this is not a photo format that we can visualize, in order to be able to read an image we need to reshape the array/input_data:

  n_spectrum = 3 # the number of color chanels: RGB 

  weight = 80

  height = 80

  X = input_data.reshape([-1, n_spectrum, weight, height])

  X[0].shape



  # let`s pick one channel

  pic = X[3]

  red_spectrum = pic[0]

  green_spectrum = pic[1]

  blue_spectrum = pic[2]

有趣的部分:在所有3个频道上绘制照片:

代码语言:javascript
复制
plt.figure(2, figsize = (5*3, 5*1))

plt.set_cmap('jet')

#show each channel

plt.subplot(1, 3, 1)

plt.imshow(red_spectrum)

plt.subplot(1, 3, 2)

plt.imshow(green_spectrum)

plt.subplot(1, 3, 3)

plt.imshow(blue_spectrum)

plt.show()

如果X [0]中的某些照片可能具有相同的所有3个波段,只需尝试另一个X [3]。

输出是4000个元素的向量:

代码语言:javascript
复制
output_data

  array([1, 1, 1, ..., 0, 0, 0], dtype=uint8)

  np.bincount(output_data)

  array([3000, 1000])

矢量包含3000个零和1000个单位= 1000个图像标有“ship”和3000个图像标有“not ship”。

为keras准备数据

首先对标签进行分类编码:

代码语言:javascript
复制
# output encoding

  y = np_utils.to_categorical(output_data, 2)

第二次洗牌所有索引:

代码语言:javascript
复制
indexes = np.arange(4000)

  np.random.shuffle(indexes)

选择X_train,y_train:

代码语言:javascript
复制
X_train = X[indexes].transpose([0,2,3,1])

  y_train = y[indexes]

当然还有正常化:

代码语言:javascript
复制
X_train = X_train / 255

  # images are type uint8 with values in the [0, 255] interval and we would like to contain values between 0 and 1

第二部分

训练模型/神经网络

代码语言:javascript
复制
np.random.seed(42)

  # network design

  model = Sequential()

  model.add(Conv2D(32, (3, 3), padding='same', input_shape=(80, 80, 3), activation='relu'))

  model.add(MaxPooling2D(pool_size=(2, 2))) #40x40

  model.add(Dropout(0.25))

  model.add(Conv2D(32, (3, 3), padding='same', activation='relu'))

  model.add(MaxPooling2D(pool_size=(2, 2))) #20x20

  model.add(Dropout(0.25))

  model.add(Conv2D(32, (3, 3), padding='same', activation='relu'))

  model.add(MaxPooling2D(pool_size=(2, 2))) #10x10

  model.add(Dropout(0.25))

  model.add(Conv2D(32, (10, 10), padding='same', activation='relu'))

  model.add(MaxPooling2D(pool_size=(2, 2))) #5x5

  model.add(Dropout(0.25))

  model.add(Flatten())

  model.add(Dense(512, activation='relu'))

  model.add(Dropout(0.5))

  model.add(Dense(2, activation='softmax'))

有关relu,softmax和dropout的详细信息,请参阅Github博客文章

https://danielmoraite.github.io/docs/fifth.html

代码语言:javascript
复制
# optimization setup

  sgd = SGD(lr=0.01, momentum=0.9, nesterov=True)

  model.compile(

      loss='categorical_crossentropy',

      optimizer=sgd,

      metrics=['accuracy'])

  # training

  model.fit(

      X_train,

      y_train,

      batch_size=32, # 32 photos at once

      epochs=18,

      validation_split=0.2,

      shuffle=True,

      verbose=2)

请先喝杯茶,因为这可能需要几分钟:

代码语言:javascript
复制
Train on 3200 samples, validate on 800 samples

  Epoch 1/18

   - 67s - loss: 0.4076 - acc: 0.8219 - val_loss: 0.2387 - val_acc: 0.9025

  Epoch 2/18

   - 89s - loss: 0.2227 - acc: 0.9034 - val_loss: 0.1767 - val_acc: 0.9150

  Epoch 3/18

   - 74s - loss: 0.1809 - acc: 0.9278 - val_loss: 0.1481 - val_acc: 0.9425

  Epoch 4/18

   - 72s - loss: 0.1444 - acc: 0.9428 - val_loss: 0.1201 - val_acc: 0.9600

  Epoch 5/18

   - 48s - loss: 0.1334 - acc: 0.9522 - val_loss: 0.1126 - val_acc: 0.9513

  Epoch 6/18

   - 42s - loss: 0.1221 - acc: 0.9591 - val_loss: 0.0879 - val_acc: 0.9637

  Epoch 7/18

   - 40s - loss: 0.1068 - acc: 0.9625 - val_loss: 0.0846 - val_acc: 0.9738

  Epoch 8/18

   - 45s - loss: 0.0820 - acc: 0.9716 - val_loss: 0.0808 - val_acc: 0.9675

  Epoch 9/18

   - 41s - loss: 0.0851 - acc: 0.9728 - val_loss: 0.0626 - val_acc: 0.9838

  Epoch 10/18

   - 40s - loss: 0.0799 - acc: 0.9709 - val_loss: 0.0662 - val_acc: 0.9762

  Epoch 11/18

   - 42s - loss: 0.0672 - acc: 0.9800 - val_loss: 0.0599 - val_acc: 0.9812

  Epoch 12/18

   - 41s - loss: 0.0500 - acc: 0.9813 - val_loss: 0.0729 - val_acc: 0.9738

  Epoch 13/18

   - 42s - loss: 0.0570 - acc: 0.9784 - val_loss: 0.0625 - val_acc: 0.9788

  Epoch 14/18

   - 43s - loss: 0.0482 - acc: 0.9828 - val_loss: 0.0526 - val_acc: 0.9775

  Epoch 15/18

   - 42s - loss: 0.0510 - acc: 0.9822 - val_loss: 0.0847 - val_acc: 0.9762

  Epoch 16/18

   - 44s - loss: 0.0440 - acc: 0.9841 - val_loss: 0.0615 - val_acc: 0.9800

  Epoch 17/18

   - 41s - loss: 0.0411 - acc: 0.9862 - val_loss: 0.0559 - val_acc: 0.9775

  Epoch 18/18

   - 42s - loss: 0.0515 - acc: 0.9834 - val_loss: 0.0597 - val_acc: 0.9775

  Out[31]:

  <keras.callbacks.History at 0xb47f7bfd0>

有关categorical_crossentropy,'准确性',以及损失函数的详细信息,请参阅Github博客文章。

https://danielmoraite.github.io/docs/fifth.html

第三部分

在图像上应用模型和搜索

代码语言:javascript
复制
# download image

  image = Image.open(r'../ships-in-satellite-imagery/scenes/sfbay_1.png')

  pix = image.load()

如果想快速浏览一下:plt.imshow(image),为了能够正确使用它需要创建一个向量:

代码语言:javascript
复制
n_spectrum = 3

  width = image.size[0]

  height = image.size[1]

  # creat vector

  picture_vector = []

  for chanel in range(n_spectrum):

      for y in range(height):

          for x in range(width):

              picture_vector.append(pix[x, y][chanel])

  picture_vector = np.array(picture_vector).astype('uint8')

  picture_tensor = picture_vector.reshape([n_spectrum, height, width]).transpose(1, 2, 0)

  plt.figure(1, figsize = (15, 30))

  plt.subplot(3, 1, 1)

  plt.imshow(picture_tensor)

  plt.show()

在图像上搜索船只

代码语言:javascript
复制
picture_tensor = picture_tensor.transpose(2,0,1)

  # Search on the image

  def cutting(x, y):

      area_study = np.arange(3*80*80).reshape(3, 80, 80)

      for i in range(80):

          for j in range(80):

              area_study[0][i][j] = picture_tensor[0][y+i][x+j]

              area_study[1][i][j] = picture_tensor[1][y+i][x+j]

              area_study[2][i][j] = picture_tensor[2][y+i][x+j]

      area_study = area_study.reshape([-1, 3, 80, 80])

      area_study = area_study.transpose([0,2,3,1])

      area_study = area_study / 255

      sys.stdout.write('\rX:{0} Y:{1}  '.format(x, y))

      return area_study

  def not_near(x, y, s, coordinates):

      result = True

      for e in coordinates:

          if x+s > e[0][0] and x-s < e[0][0] and y+s > e[0][1] and y-s < e[0][1]:

              result = False

      return result



  def show_ship(x, y, acc, thickness=5):   

      for i in range(80):

          for ch in range(3):

              for th in range(thickness):

                  picture_tensor[ch][y+i][x-th] = -1

      for i in range(80):

          for ch in range(3):

              for th in range(thickness):

                  picture_tensor[ch][y+i][x+th+80] = -1

      for i in range(80):

          for ch in range(3):

              for th in range(thickness):

                  picture_tensor[ch][y-th][x+i] = -1

      for i in range(80):

          for ch in range(3):

              for th in range(thickness):

                  picture_tensor[ch][y+th+80][x+i] = -1

可以选择更多的步骤,而不是10或更少:只要有耐心,因为这可能需要一段时间。

代码语言:javascript
复制
step = 10; coordinates = []

  for y in range(int((height-(80-step))/step)):

      for x in range(int((width-(80-step))/step) ):

          area = cutting(x*step, y*step)

          result = model.predict(area)

          if result[0][1] > 0.90 and not_near(x*step,y*step, 88, coordinates):

              coordinates.append([[x*step, y*step], result])

              print(result)

              plt.imshow(area[0])

              plt.show()

正如所看到的那样:它确实分类为具有直线和明亮像素的船舶图像

  • 想这是找到一种方法来改进模型的下一步 - 尽管这是另一次。

或者给它第二次运行:

现在理解标签并在图像上找到它们:

代码语言:javascript
复制
for e in coordinates:

    show_ship(e[0][0], e[0][1], e[1][0][1])

picture_tensor = picture_tensor.transpose(1,2,0)

picture_tensor.shape

(1777, 2825, 3)

plt.figure(1, figsize = (15, 30))

plt.subplot(3,1,1)

plt.imshow(picture_tensor)

plt.show()

可以重新训练模型并给它另一次运行,或者使用当前模型进行第二次搜索,看看可能会得到什么。

资料来源:

Github博客文章

https://danielmoraite.github.io/docs/fifth.html

Kaggle比赛数据下载

https://www.kaggle.com/rhammell/ships-in-satellite-imagery

在GitHub上完整的Jupyter笔记本

https://github.com/DanielMoraite/DanielMoraite.github.io/blob/master/assets/Keras%20for%20search%20ships%20in%20satellite%20image.ipynb

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-04-12,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 相约机器人 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
数据保险箱
数据保险箱(Cloud Data Coffer Service,CDCS)为您提供更高安全系数的企业核心数据存储服务。您可以通过自定义过期天数的方法删除数据,避免误删带来的损害,还可以将数据跨地域存储,防止一些不可抗因素导致的数据丢失。数据保险箱支持通过控制台、API 等多样化方式快速简单接入,实现海量数据的存储管理。您可以使用数据保险箱对文件数据进行上传、下载,最终实现数据的安全存储和提取。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档