之前推过一篇深度学习雷达回波短临外推的推文 基于深度学习的多模型雷达回波外推,很多朋友想获取源代码,但因为一些原因这个代码无法开源。
本文给大家开源一个深度学习雷达回波外推的源代码,此外还会与大家简单分享一点经验,以后也会专门写一篇分享数据集构建,模型训练、评估及可解释性等方面的推文。文末获取所有代码。
深度学习在雷达回波外推方面的应用已经比较成熟了。很多时空预测模型(MIM,PredRNN,SA-ConvLSTM等)被提出以不断的去优化雷达回波外推所面临的各种问题。
这些时空预测模型不少已经开源在GitHub,但很少有完整的将这些模型应用到雷达回波外推的开源代码。本文主要以一个基本的U-net变体模型(SmaAt-Unet)为示例,介绍深度学习在雷达回波外推方面的应用,涉及到数据加载,模型训练,评估和推断,以及最终的可视化。
SmaAt-Unet模型是由 Trebing 等提出的降水短临预报模型,主要是在 U-net 模型的基础上加入了注意力模块和深度可分离卷积。
数据集采用的是长三角地区短临预报比赛雷达回波数据。本文仅是提供一个简单示例,因此并没有确保模型完全收敛到最佳解。开源代码包括功能:简单的日志监控、实时训练结果可视化、存储checkpoint和模型、恢复训练、模型推断可视化等。以上功能比较简单,但也都是经常使用的功能。
数据集的加载并没有统一的方式,这个取决于数据集的存储方式。比如长三角地区的短临比赛数据存储方式和华为云雷达回波比赛数据存储方式就有所差异。
长三角大赛:样本和数据分离,即样本信息单独以文件存储,而对应的图片数据统一存储在单独目录下,
样本(examples)和数据(data)目录
样本目录下文件
每个样本目录下文件
每个文件中包含了输入和对应的标签的文件名。
华为云大赛:样本和数据合并,即每个样本的雷达图单独目录存储
样本目录
每个样本目录中存储了每个样本的输入和对应输出的文件。
除了这两种存储方式之外,其他比赛的存储方式也会有所差异。比如SEVIR提供的数据集就是将训练集和测试集的样本分别存储到一个 hdf5 格式文件中。这对于直接加载所有样本到内存中处理而言比较方便。
这里说的数据加载主要侧重于使用别人已经构建好的数据集,如果是自己构建数据集,则要考虑数据集扩充以及加载等问题。
我平时在构建数据集时则是习惯按照样本,每个样本单独存储为 .npy 格式文件,这样对于扩充数据以及加载数据而言也相对比较方便。
一般而言,重要数据存储方式合适,数据集加载并不会很繁琐。
模型训练部分没有太多需要说明的。只需要按照常规训练方式选择好损失函数、优化器以及相应的参数即可。
在模型确定后,训练模型时,损失函数是至关重要的。针对特定的问题一定要选择合适的损失函数。本文开源代码选择了 MAE+MSE 混合损失函数。
模型训练过程可视化结果
这里提及一点,如果刚接触深度学习没多久,对于训练过程的细节不是很清楚的,可以利用 Pytorch Lightning 库来进行模型的训练,这样可以避免由于不明白模型训练过程中的一些细节所导致的问题。
模型训练过程中,通常需要关注随着模型的训练,模型有没有逐渐向着最佳解收敛。这时候我们就需要设置额外的评估指标关注模型的训练过程。对于气象领域而言,尤其是雷达回波外推,通常是通过 CSI(临界成功指数),也就是常说的TS评分来评估模型效果。本文并没有加入CSI等评估指标计算,可根据需要自行添加。
对于模型训练过程的监控,通常需要加入日志模块,可以采用 Logging 或者 Tensorboard 监控模型训练过程。本文开源的代码加入了简单的 Logging 和 Tensorboard 的日志监控。
除了CSI/TS评分指标之外,还有HSS、POD、FAR。这些通常用于评估确定性预报,对于集合概率预报评估,通常使用CRPS、BSS等指标。此外还有很多其他的评估指标,具体的还需要根据对应的问题选择合适的评估指标。关于这些评估指标晚上有很多开源的库,见文末参考链接。
模型训练并调整完成后,即可部署进行模型推断,注意模型推断过程不需要再计算梯度,对于 Pytorch 而言,需要添加相应的命令 (torch.no_grad()),也不需要再更新模型的权重参数 (model.eval())。
代码中加入了模型训练以及推断结果可视化的代码块,可以实时监测模型的训练过程。此部分提供了基于 matplotlib 和 PIL 的可视化功能,两者在效率上有较大差异。训练过程利用 matplotlib 进行可视化监控。
此部分功能相对比较简单,本文的初衷也仅是提供一个demo示例,对于业务而言可以在此基础上进一步拓展。
推断可视化示例
以上仅是对深度学习雷达回波外推模型中涉及的各个部分的简单介绍,并且功能比较简单,但对于入门者学习而言基本足够了,这也是本文的初衷。完整的模型训练和推断代码可以在后台回复「DLNW」获取。
参考资料