
Keras Series (5): Deep Learning of Spatial Features with ConvLSTM

MeteoAI · Published 2019-07-24

What is ConvLSTM?

LSTM excels at processing sequential data, but when the elements of the sequence are images (three-dimensional tensors), every frame carries rich spatial information and each pixel is strongly correlated with its neighbourhood. An ordinary LSTM has difficulty capturing this spatial structure, so ConvLSTM adds convolution operations on top of the LSTM to capture spatial features, which makes feature extraction from images far more effective.

Figure 1: The classic LSTM

Figure 2: ConvLSTM

The difference between the two is that the classic LSTM uses fully connected state-to-state transitions, whereas ConvLSTM replaces them with convolutions.

Figure 3: Difference between an ordinary LSTM and a ConvLSTM
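For reference, the ConvLSTM state transitions as given in Shi et al. (2015), where $*$ denotes the convolution operator and $\circ$ the Hadamard (element-wise) product, and $X_t$, $H_t$, $C_t$ are the input, hidden state and cell state at time $t$:

$$
\begin{aligned}
i_t &= \sigma\!\left(W_{xi} * X_t + W_{hi} * H_{t-1} + W_{ci} \circ C_{t-1} + b_i\right)\\
f_t &= \sigma\!\left(W_{xf} * X_t + W_{hf} * H_{t-1} + W_{cf} \circ C_{t-1} + b_f\right)\\
C_t &= f_t \circ C_{t-1} + i_t \circ \tanh\!\left(W_{xc} * X_t + W_{hc} * H_{t-1} + b_c\right)\\
o_t &= \sigma\!\left(W_{xo} * X_t + W_{ho} * H_{t-1} + W_{co} \circ C_t + b_o\right)\\
H_t &= o_t \circ \tanh(C_t)
\end{aligned}
$$

Compared with the classic LSTM in Figure 1, the only structural change is that the matrix multiplications on the input and hidden state become convolutions, so the states themselves stay three-dimensional (height, width, channels).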

ConvLSTM in practice

Below we use Keras to build a ConvLSTM that predicts how artificially generated movies evolve over time. If you do not have a GPU, it is recommended to run the code on Colab, Google's free deep-learning platform; a local Jupyter notebook version is also shared at the end of this article.

The code below loads the required libraries and builds and compiles the ConvLSTM model.
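A minimal sketch of that step, following the standard Keras conv_lstm.py example (a stack of four ConvLSTM2D layers with batch normalisation, closed by a Conv3D output layer); the filter counts and kernel sizes below are taken from that example rather than being prescribed here:

import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import ConvLSTM2D, BatchNormalization, Conv3D

# Input: sequences of 40x40 single-channel frames, of arbitrary length (None).
seq = Sequential()
seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
                   input_shape=(None, 40, 40, 1),
                   padding='same', return_sequences=True))
seq.add(BatchNormalization())

seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
                   padding='same', return_sequences=True))
seq.add(BatchNormalization())

seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
                   padding='same', return_sequences=True))
seq.add(BatchNormalization())

seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
                   padding='same', return_sequences=True))
seq.add(BatchNormalization())

# Conv3D collapses the 40 feature maps of every frame back to one channel,
# giving an output movie with the same shape as the target.
seq.add(Conv3D(filters=1, kernel_size=(3, 3, 3),
               activation='sigmoid',
               padding='same', data_format='channels_last'))
seq.compile(loss='binary_crossentropy', optimizer='adadelta')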

Generating sample data

# Artificial data generation:
# Generate movies with 3 to 7 moving squares inside.
# The squares are of shape 1x1 or 2x2 pixels,
# which move linearly over time.
# For convenience we first create movies with bigger width and height (80x80)
# and at the end we select a 40x40 window.
def generate_movies(n_samples=1200, n_frames=15):
    row = 80
    col = 80
    noisy_movies = np.zeros((n_samples, n_frames, row, col, 1), dtype=np.float)
    shifted_movies = np.zeros((n_samples, n_frames, row, col, 1), dtype=np.float)

    for i in range(n_samples):
        # Add 3 to 7 moving squares
        n = np.random.randint(3, 8)

        for j in range(n):
            # Initial position
            xstart = np.random.randint(20, 60)
            ystart = np.random.randint(20, 60)
            # Direction of motion
            directionx = np.random.randint(0, 3) - 1
            directiony = np.random.randint(0, 3) - 1

            # Size of the square
            w = np.random.randint(2, 4)

            for t in range(n_frames):
                x_shift = xstart + directionx * t
                y_shift = ystart + directiony * t
                noisy_movies[i, t, x_shift - w: x_shift + w,
                             y_shift - w: y_shift + w, 0] += 1

                # Make it more robust by adding noise.
                # The idea is that if during inference,
                # the value of the pixel is not exactly one,
                # we need to train the network to be robust and still
                # consider it as a pixel belonging to a square.
                if np.random.randint(0, 2):
                    noise_f = (-1)**np.random.randint(0, 2)
                    noisy_movies[i, t,
                                 x_shift - w - 1: x_shift + w + 1,
                                 y_shift - w - 1: y_shift + w + 1,
                                 0] += noise_f * 0.1

                # Shift the ground truth by 1
                x_shift = xstart + directionx * (t + 1)
                y_shift = ystart + directiony * (t + 1)
                shifted_movies[i, t, x_shift - w: x_shift + w,
                               y_shift - w: y_shift + w, 0] += 1

    # Cut to a 40x40 window
    noisy_movies = noisy_movies[::, ::, 20:60, 20:60, ::]
    shifted_movies = shifted_movies[::, ::, 20:60, 20:60, ::]
    noisy_movies[noisy_movies >= 1] = 1
    shifted_movies[shifted_movies >= 1] = 1
    return noisy_movies, shifted_movies
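As a quick sanity check, the generated arrays can be inspected like this (a small illustrative snippet, not part of the Keras example):

# Generate a handful of movies just to look at the array shapes.
noisy_movies, shifted_movies = generate_movies(n_samples=10)
print(noisy_movies.shape)    # (10, 15, 40, 40, 1) after the 40x40 crop
print(shifted_movies.shape)  # (10, 15, 40, 40, 1); frame t holds the squares at time t + 1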

Training and prediction

# Train the network
noisy_movies, shifted_movies = generate_movies(n_samples=1200)
seq.fit(noisy_movies[:1000], shifted_movies[:1000], batch_size=10,
        epochs=300, validation_split=0.05)

# Testing the network on one movie:
# feed it with the first 7 positions and then
# predict the new positions
which = 1004
track = noisy_movies[which][:7, ::, ::, ::]

for j in range(16):
    new_pos = seq.predict(track[np.newaxis, ::, ::, ::, ::])
    new = new_pos[::, -1, ::, ::, ::]
    track = np.concatenate((track, new), axis=0)
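Each pass of the loop above appends the network's last predicted frame to `track`, so every frame beyond the seventh is predicted from the model's own previous output. If you want to keep the trained network around, the standard Keras save/load calls work as usual (the file name below is arbitrary):

from keras.models import load_model

seq.save('conv_lstm_movies.h5')          # saves architecture, weights and optimizer state
seq = load_model('conv_lstm_movies.h5')  # restores the compiled model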

Plotting and comparison

# And then compare the predictions
# to the ground truth
track2 = noisy_movies[which][::, ::, ::, ::]
for i in range(15):
    fig = plt.figure(figsize=(10, 5))

    ax = fig.add_subplot(121)
    if i >= 7:
        ax.text(1, 3, 'Predictions !', fontsize=20, color='w')
    else:
        ax.text(1, 3, 'Initial trajectory', fontsize=20)

    toplot = track[i, ::, ::, 0]
    plt.imshow(toplot)

    ax = fig.add_subplot(122)
    plt.text(1, 3, 'Ground truth', fontsize=20)

    toplot = track2[i, ::, ::, 0]
    if i >= 2:
        toplot = shifted_movies[which][i - 1, ::, ::, 0]

    plt.imshow(toplot)
    plt.savefig('%i_animate.png' % (i + 1))
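To view the saved frames as an animation, one option (not used in the original example) is to stitch the PNG files into a GIF with imageio; the file names below match the savefig pattern above:

import imageio

# Read the 15 frames written by the loop above and write them out as a GIF.
frames = [imageio.imread('%i_animate.png' % (i + 1)) for i in range(15)]
imageio.mimsave('convlstm_animation.gif', frames, duration=0.3)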

Local notebook version: https://pan.baidu.com/s/1V4_eTYV7vi2UNg_XdiP5mg (extraction code: 8v9c)

Run the code online on Google Colab: https://colab.research.google.com/drive/1XmlpMzQK1REHbjVy61WPSh-oli08JZeZ (remember to enable GPU acceleration under Edit → Notebook settings)

Also shared: an OpenCV + TensorFlow video course, useful for image processing and prediction: https://pan.baidu.com/s/1bolCSMyVYanxPMf9TpQWBg (extraction code: m8jl)

E-book, OpenCV For ML: https://pan.baidu.com/s/1ErAHu1PNibTBrzJTtsjovg (extraction code: estj)

For corrections or feedback, contact Tiezhu directly at deepwind@aliyun.com.

This article is shared from the WeChat public account MeteoAI as part of the Tencent Cloud self-media syndication program. Originally published 2019-02-20. For any copyright concerns, contact cloudcommunity@tencent.com for removal.
