# 【干货】PyTorch实例：用ResNet进行交通标志分类

【导读】本文是机器学习工程师Pavel Surmenok撰写的一篇技术博客，用Pytorch实现ResNet网络，并用德国交通标志识别基准数据集进行实验。文中分别介绍了数据集、实验方法、代码、准备工作，并对图像增强、学习率、模型微调、误差分析等步骤进行详细介绍。文章中给出了GitHub代码，本文是一篇学习PyTorch和ResNet的很好的实例教程。

### ResNet for Traffic Sign Classification With PyTorch

▌数据集

http://benchmark.ini.rub.de/?section=gtsrb&subsection=dataset

▌实验方法

▌代码

https://github.com/surmenok/GTSRB/blob/master/german-traffic-signs.ipynb

https://github.com/surmenok/GTSRB

▌准备工作

www.fast.ai/2017/11/13/validation-sets/

▌探索性分析

▌训练

arch = resnet34
learn = ConvLearner.pretrained(arch, data, precompute=False)

▌图像增强

sz = 96

# Look at examples of image augmentation
def get_augs():
x,_ = next(iter(data.aug_dl))
return data.trn_ds.denorm(x)[1]

aug_tfms = [RandomRotate(20), RandomLighting(0.8, 0.8)]
tfms = tfms_from_model(arch, sz, aug_tfms=aug_tfms, max_zoom=1.2)
data = ImageClassifierData.from_paths(path, tfms=tfms, test_name='test')

ims = np.stack([get_augs() for i in range(6)])
plots(ims, rows=2)

▌学习率

https://towardsdatascience.com/estimating-optimal-learning-rate-for-a-deep-neural-network-ce32f2556ce0

def plot_loss_change(sched, sma=1, n_skip=20, y_lim=(-0.01, 0.01)):
"""
Plots rate of change of the loss function.
Parameters:
sched - learning rate scheduler, an instance of LR_Finder class.
sma - number of batches for simple moving average to smooth out the curve.
n_skip - number of batches to skip on the left.
y_lim - limits for the y axis.
"""
derivatives = [0] * (sma + 1)
for i in range(1 + sma, len(learn.sched.lrs)):
derivative = (learn.sched.losses[i] - learn.sched.losses[i - sma]) / sma
derivatives.append(derivative)

plt.ylabel("d/loss")
plt.xlabel("learning rate (log scale)")
plt.plot(learn.sched.lrs[n_skip:], derivatives[n_skip:])
plt.xscale('log')
plt.ylim(y_lim)

learn.lr_find()

▌微调最后一层

wd = 5e-4
learn.fit(0.01, 1, wds=wd)

▌微调整个模型

learn.unfreeze()
learn.fit(0.01, 3, wds=wd)

learn.fit(lr, 4, cycle_len=1, cycle_mult=2, wds=wd)

▌误差分析

log_preds,y = learn.predict_with_targs()
preds = np.exp(log_preds)
pred_labels = np.argmax(preds, axis=1)

results = ImageModelResults(data.val_ds, log_preds)

results.plot_most_incorrect(1)

▌重新训练整个训练集

▌在测试集上进行测试

▌测试时间增加

log_preds,_ = learn.TTA(n_aug=20, is_test=True)
preds = np.mean(np.exp(log_preds),0)
accuracy_np(preds, y_true)

▌它有多好？

2011年IJCNN竞赛排行榜排名：

• CNN与ÁlvaroArcos-García等人的3个空间变换器99.71％

• DanCireşan等人的CNN。99.46％

• 基于颜色斑点的COSFIRE过滤器，用于由Baris Gecer进行物体识别，98.97％

▌参考链接：

fast.ai最新版本的“深入学习编码器”课程：

course.fast.ai

GitHub：

https://github.com/surmenok/GTSRB

fastai：

https://github.com/fastai/fastai

CNN with 3 spatial transformers：

Committee of CNNs：

https://www.sciencedirect.com/science/article/pii/S0893608012000524?via%3Dihub

Color-blob-based COSFIRE blters for object recognition：

dx.doi.org/10.1016/j.imavis.2016.10.006

https://towardsdatascience.com/resnet-for-traffic-sign-classification-with-pytorch-5883a97bbaa3

0 条评论

• ### 【干货】TensorFlow实战——图像分类神经网络模型

Learn how to classify images with TensorFlow 使用TensorFlow创建一个简单而强大的图像分类神经网络模型 by...

• ### 什么是MAP？ 理解目标检测模型中的性能评估

【导读】近日，机器学习工程师Tarang Shah发布一篇文章，探讨了机器学习中模型的度量指标的相关问题。本文首先介绍了机器学习中两个比较直观和常用的度量指标：...

• ### PyTorch实例：用ResNet进行交通标志分类

【导读】本文是机器学习工程师Pavel Surmenok撰写的一篇技术博客，用Pytorch实现ResNet网络，并用德国交通标志识别基准数据集进行实验。文中分...

• ### 谷歌人工智能“即时”预测局部降水模式

谷歌希望利用人工智能和机器学习来快速预测当地天气。在一篇论文和附带的博客文章中，这家科技巨头详细介绍了一个人工智能系统，该系统利用卫星图像生成“几乎是瞬间”的高...

• ### 在ubuntu使用apt install的fastqc是有bug的

所以我就去了我的生物信息学常见1000个软件的安装代码：https://www.jianshu.com/p/ae28e8e3e9f5 找到了fastqc软件下载...

• ### 如何到top5%？NLP文本分类和情感分析竞赛总结

笔者主要方向是KBQA，深深体会到竞赛是学习一个新领域最好的方式，这些比赛总的来说都属于文本分类领域，因此最近打算一起总结一下。

• ### Java EE基本框架（Struts2+Spring+MyBatis三层，Struts MVC）之间的关系

一个JavaEE的项目，页面用JSP，后台用了Struts2+Spring+MyBatis，数据库用的是Oracle，这么多技术名词，他们之间的关系如何，整体是...

• ### 如何到top5%？NLP文本分类和情感分析竞赛总结

笔者主要方向是KBQA，深深体会到竞赛是学习一个新领域最好的方式，这些比赛总的来说都属于文本分类领域，因此最近打算一起总结一下。