专栏首页Jack-CuiPytorch深度学习实战教程(四):必知必会的炼丹法宝

Pytorch深度学习实战教程(四):必知必会的炼丹法宝

本文 GitHub https://github.com/Jack-Cherish/PythonPark 已收录,有技术干货文章,整理的学习资料,一线大厂面试经验分享等,欢迎 Star 和 完善。

一、前言

训练深度学习模型,就像“炼丹”,模型可能需要训练很多天。

我们不可能像「太上老君」那样,拿着浮尘,24 小时全天守在「八卦炉」前,更何况人家还有炼丹童、天兵天将,轮流值守。

人手不够,“法宝”来凑。

本文就盘点一下,我们可以使用的「炼丹法宝」。

PS:文中出现的所有代码,均可在我的 Github 上下载:点击查看

二、初级“法宝”,sys.stdout

训练模型,最常看的指标就是 Loss。我们可以根据 Loss 的收敛情况,初步判断模型训练的好坏。

如果,Loss 值突然上升了,那说明训练有问题,需要检查数据和代码。

如果,Loss 值趋于稳定,那说明训练完毕了。

观察 Loss 情况,最直观的方法,就是绘制 Loss 曲线图。

通过绘图,我们可以很清晰的看到,左图还有收敛空间,而右图已经完全收敛。

通过 Loss 曲线,我们可以分析模型训练的好坏,模型是否训练完成,起到一个很好的“监控”作用。

绘制 Loss 曲线图,第一步就是需要保存训练过程中的 Loss 值。

一个最简单的方法是使用,sys.stdout 标准输出重定向,简单好用,实乃“炼丹”必备“良宝”。

import os
import sys
class Logger():
    def __init__(self, filename="log.txt"):
        self.terminal = sys.stdout
        self.log = open(filename, "w")
 
    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
 
    def flush(self):
        pass
 
sys.stdout = Logger()
 
print("Jack Cui")
print("https://cuijiahua.com")
print("https://mp.weixin.qq.com/s/OCWwRVDFNslIuKyiCVUoTA")

代码很简单,创建一个 log.py 文件,自己写一个 Logger 类,并采用 sys.stdout 重定向输出。

在 Terminal 中,不仅可以使用 print 打印结果,同时也会将结果保存到 log.txt 文件中。

运行 log.py,打印 print 内容的同时,也将内容写入了 log.txt 文件中。

使用这个代码,就可以在打印 Loss 的同时,将结果保存到指定的 txt 中,比如保存上篇文章训练 UNet 的 Loss。

三、中级“法宝”,matplotlib

Matplotlib 是一个 Python 的绘图库,简单好用。

简单几行命令,就可以绘制曲线图、散点图、条形图、直方图、饼图等等。

深度学习中,一般就是绘制曲线图,比如 Loss 曲线、Acc 曲线。

举一个,简单的例子。

使用 sys.stdout 保存的 train_loss.txt,绘制 Loss 曲线。

train_loss.txt 下载地址:点击查看

思路非常简单,读取 txt 内容,解析 txt 内容,使用 Matplotlib 绘制曲线。

import matplotlib.pyplot as plt
# Jupyter notebook 中开启
# %matplotlib inline
with open('train_loss.txt', 'r') as f:
    train_loss = f.readlines()
    train_loss = list(map(lambda x:float(x.strip()), train_loss))
x = range(len(train_loss))
y = train_loss
plt.plot(x, y, label='train loss', linewidth=2, color='r', marker='o', markerfacecolor='r', markersize=5)
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.legend()
plt.show()

指定 x 和 y 对应的值,就可以绘制。

是不是很简单?

关于 Matplotlib 更多的详细教程,可以查看官方手册:点击查看

四、中级“法宝”,Logging

说到保存日志,那不得不提 Python 的内置标准模块 Logging,它主要用于输出运行日志,可以设置输出日志的等级、日志保存路径、日志文件回滚等,同时,我们也可以设置日志的输出格式。

import logging

def get_logger(LEVEL, log_file = None):
    head = '[%(asctime)-15s] [%(levelname)s] %(message)s'
    if LEVEL == 'info':
        logging.basicConfig(level=logging.INFO, format=head)
    elif LEVEL == 'debug':
        logging.basicConfig(level=logging.DEBUG, format=head)
    logger = logging.getLogger()
    if log_file != None:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)
    return logger

logger = get_logger('info')

logger.info('Jack Cui')
logger.info('https://cuijiahua.com')
logger.info('https://mp.weixin.qq.com/s/OCWwRVDFNslIuKyiCVUoTA')

只需要几行代码,进行一个简单的封装使用。使用函数 get_logger 创建一个级别为 info 的 logger,如果指定 log_file,则会对日志进行保存。

logging 默认支持的日志一共有 5 个等级:

日志级别等级 CRITICAL > ERROR > WARNING > INFO > DEBUG。

默认的日志级别设置为 WARNING,也就是说如果不指定日志级别,只会显示大于等于 WARNING 级别的日志。

例如:

import logging
logging.debug("debug_msg")
logging.info("info_msg")
logging.warning("warning_msg")
logging.error("error_msg")
logging.critical("critical_msg")

运行结果:

WARNING:root:warning_msg
ERROR:root:error_msg
CRITICAL:root:critical_msg

可以看到 info 和 debug 级别的日志不会输出,默认的日志格式也比较简单。

默认的日志格式为日志级别:Logger名称:用户输出消息

当然,我们可以通过,logging.basicConfig 的 format 参数,设置日志格式。

字段有很多,可谓应有尽有,足以满足我们定制化的需求。

五、高级“法宝”,TensorboardX

上文介绍的“法宝”,并非针对深度学习“炼丹”使用的工具。

而 TensorboardX 则不同,它是专门用于深度学习“炼丹”的高级“法宝”。

早些时候,很多人更喜欢用 Tensorflow 的原因之一,就是 Tensorflow 框架有个一个很好的可视化工具 Tensorboard。

Pytorch 要想使用 Tensorboard 配置起来费劲儿不说,还有很多 Bug。

Pytorch 1.1.0 版本发布后,打破了这个局面,TensorBoard 成为了 Pytorch 的正式可用组件。

在 Pytorch 中,这个可视化工具叫做 TensorBoardX,其实就是针对 Tensorboard 的一个封装,使得 PyTorch 用户也能够调用 Tensorboard。

TensorboardX 安装也非常简单,使用 pip 即可安装。

pip install tensorboardX

tensorboardX 使用也很简单,编写如下代码。

from tensorboardX import SummaryWriter

# 创建 writer1 对象
# log 会保存到 runs/exp 文件夹中
writer1 = SummaryWriter('runs/exp')

# 使用默认参数创建 writer2 对象
# log 会保存到 runs/日期_用户名 格式的文件夹中
writer2 = SummaryWriter()

# 使用 commet 参数,创建 writer3 对象
# log 会保存到 runs/日期_用户名_resnet 格式的文件中
writer3 = SummaryWriter(comment='_resnet')

使用的时候,创建一个 SummaryWriter 对象即可,以上展示了三种初始化 SummaryWriter 的方法:

  • 提供一个路径,将使用该路径来保存日志
  • 无参数,默认将使用 runs/日期_用户名 路径来保存日志
  • 提供一个 comment 参数,将使用 runs/日期_用户名+comment 路径来保存日志

运行结果:

有了 writer 我们就可以往日志里写入数字、图片、甚至声音等数据。

数字 (scalar)

这个是最简单的,使用 add_scalar 方法来记录数字常量。

add_scalar(tag, scalar_value, global_step=None, walltime=None)

总共 4 个参数。

  • tag (string): 数据名称,不同名称的数据使用不同曲线展示
  • scalar_value (float): 数字常量值
  • global_step (int, optional): 训练的 step
  • walltime (float, optional): 记录发生的时间,默认为 time.time()

需要注意,这里的 scalar_value 一定是 float 类型,如果是 PyTorch scalar tensor,则需要调用 .item() 方法获取其数值。我们一般会使用 add_scalar 方法来记录训练过程的 loss、accuracy、learning rate 等数值的变化,直观地监控训练过程。

运行如下代码:

from tensorboardX import SummaryWriter    
writer = SummaryWriter('runs/scalar_example')
for i in range(10):
    writer.add_scalar('quadratic', i**2, global_step=i)
    writer.add_scalar('exponential', 2**i, global_step=i)
writer.close()

通过 add_scalar 往日志里写入数字,日志保存到 runs/scalar_example中,writer 用完要记得 close,否则无法保存数据。

在 cmd 中使用如下命令:

tensorboard --logdir=runs/scalar_example --port=8088

指定日志地址,使用端口号,在浏览器中,就可以使用如下地址,打开 Tensorboad。

http://localhost:8088/

省去了我们自己写代码可视化的麻烦。

图片 (image)

使用 add_image 方法来记录单个图像数据。注意,该方法需要 pillow 库的支持

add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW')

参数:

  • tag (string):数据名称
  • img_tensor (torch.Tensor / numpy.array):图像数据
  • global_step (int, optional):训练的 step
  • walltime (float, optional):记录发生的时间,默认为 time.time()
  • dataformats (string, optional):图像数据的格式,默认为 'CHW',即 Channel x Height x Width,还可以是 'CHW'、'HWC' 或 'HW' 等

我们一般会使用 add_image 来实时观察生成式模型的生成效果,或者可视化分割、目标检测的结果,帮助调试模型。

from tensorboardX import SummaryWriter
from urllib.request import urlretrieve
import cv2

urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/0.png',filename = '1.jpg')
urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/1.png',filename = '2.jpg')
urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/2.png',filename = '3.jpg')

writer = SummaryWriter('runs/image_example')
for i in range(1, 4):
    writer.add_image('UNet_Seg',
                     cv2.cvtColor(cv2.imread('{}.jpg'.format(i)), cv2.COLOR_BGR2RGB),
                     global_step=i,
                     dataformats='HWC')
writer.close()

代码就是下载上篇文章数据集里的三张图片,然后使用 Tensorboard 可视化处理来,使用 8088 端口开打 Tensorboard:

tensorboard --logdir=runs/image_example --port=8088

运行结果:

试想一下,一边训练,一边输出图片结果,是不是很酸爽呢?

Tensorboard 中常用的 Scalar 和 Image,直方图、运行图、嵌入向量等,可以查看官方手册进行学习,方法都是类似的,简单好用。

官方文档:点击查看

六、总结

工欲善其事,必先利其器。

本文讲解了深度学习中,常用的“炼丹法宝”的使用方法,sys.stdout、matplotlib、logging、tensorboardX 你更喜欢哪一款?

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • Python3网络爬虫(四):使用User Agent和代理IP隐藏身份

    运行平台:Windows Python版本:Python3.x IDE:Sublime text3 一、为何要设置User Agent     有一...

    Jack_Cui
  • Jetson TX1开发笔记(四):使用Caffe对摄像机视频流中的目标进行识别

    版权声明:本文为博主原创文章,未经博主允许不得转载。个人网站:http://cuijiahua.com。 ...

    Jack_Cui
  • AI算法让图片动起来,特朗普和蒙娜丽莎深情合唱《Unravel》

    First Order Motion,也就是一阶运动模型,来自 NeurIPS 2019 论文。

    Jack_Cui
  • Nginx 使用 logrotate 进行日志滚动

    USR1 信号量被 Nginx 自定义了,为重新打开日志;当 kill 命令发送 USR1时,nginx 会重新打开日志文件,并重新创建进程。

    叨叨软件测试
  • Android绘制(三):Path结合属性动画, 让图标动起来!

    sean_yang
  • Android 弹出Dialog时隐藏状态栏和底部导航栏的方法

    我们的app是全屏沉浸式的,发现activity在设置了全屏后,弹出dialog底部会跳出导航栏虚拟键。具体原因是因为Dialog的Window抢走了焦点,Wi...

    砸漏
  • 案例分析|简析腾讯体育用户成长体系的搭建

    用户1756920
  • 知乎高赞答案!你看过/写过哪些有意思的代码?

    腾讯大讲堂
  • Tensorflow | 斯坦福cs20si | lecture1

    上图的代码存在问题,目前不知道怎么修改。若是有朋友有修改方案,请在下方留言,谢谢!

    努力在北京混出人样
  • [原创译文] 10种帮你改善UI设计质量的窍门

    静电说:这几天静电在公众号里做了预告,我们将会花几天时间来翻译国外作者Danny Sapio的一系列文章。前天的不可不知的10条UI设计经验法则 不知道大家有没...

    用户5009027

扫码关注云+社区

领取腾讯云代金券