# PyTorch实例：用ResNet进行交通标志分类

【导读】本文是机器学习工程师Pavel Surmenok撰写的一篇技术博客，用Pytorch实现ResNet网络，并用德国交通标志识别基准数据集进行实验。文中分别介绍了数据集、实验方法、代码、准备工作，并对图像增强、学习率、模型微调、误差分析等步骤进行详细介绍。文章中给出了GitHub代码，本文是一篇学习PyTorch和ResNet的很好的实例教程。

### ResNet for Traffic Sign Classification With PyTorch

▌数据集

http://benchmark.ini.rub.de/?section=gtsrb&subsection=dataset

▌实验方法

▌代码

https://github.com/surmenok/GTSRB/blob/master/german-traffic-signs.ipynb

https://github.com/surmenok/GTSRB

▌准备工作

www.fast.ai/2017/11/13/validation-sets/

▌探索性分析

▌训练

arch = resnet34 learn = ConvLearner.pretrained(arch, data, precompute=False)

▌图像增强

sz = 96 # Look at examples of image augmentation def get_augs(): x,_ = next(iter(data.aug_dl)) return data.trn_ds.denorm(x)[1] aug_tfms = [RandomRotate(20), RandomLighting(0.8, 0.8)] tfms = tfms_from_model(arch, sz, aug_tfms=aug_tfms, max_zoom=1.2) data = ImageClassifierData.from_paths(path, tfms=tfms, test_name='test') ims = np.stack([get_augs() for i in range(6)]) plots(ims, rows=2)

▌学习率

https://towardsdatascience.com/estimating-optimal-learning-rate-for-a-deep-neural-network-ce32f2556ce0

def plot_loss_change(sched, sma=1, n_skip=20, y_lim=(-0.01, 0.01)): """ Plots rate of change of the loss function. Parameters: sched - learning rate scheduler, an instance of LR_Finder class. sma - number of batches for simple moving average to smooth out the curve. n_skip - number of batches to skip on the left. y_lim - limits for the y axis. """ derivatives = [0] * (sma + 1) for i in range(1 + sma, len(learn.sched.lrs)): derivative = (learn.sched.losses[i] - learn.sched.losses[i - sma]) / sma derivatives.append(derivative) plt.ylabel("d/loss") plt.xlabel("learning rate (log scale)") plt.plot(learn.sched.lrs[n_skip:], derivatives[n_skip:]) plt.xscale('log') plt.ylim(y_lim) learn.lr_find()

▌微调最后一层

wd = 5e-4 learn.fit(0.01, 1, wds=wd)

▌微调整个模型

learn.unfreeze() learn.fit(0.01, 3, wds=wd)

learn.fit(lr, 4, cycle_len=1, cycle_mult=2, wds=wd)

▌误差分析

log_preds,y = learn.predict_with_targs() preds = np.exp(log_preds) pred_labels = np.argmax(preds, axis=1) results = ImageModelResults(data.val_ds, log_preds) results.plot_most_incorrect(1)

▌重新训练整个训练集

▌在测试集上进行测试

▌测试时间增加

log_preds,_ = learn.TTA(n_aug=20, is_test=True) preds = np.mean(np.exp(log_preds),0) accuracy_np(preds, y_true)

▌它有多好？

2011年IJCNN竞赛排行榜排名：

• CNN与ÁlvaroArcos-García等人的3个空间变换器99.71％

• DanCireşan等人的CNN。99.46％

• 基于颜色斑点的COSFIRE过滤器，用于由Baris Gecer进行物体识别，98.97％

▌参考链接：

fast.ai最新版本的“深入学习编码器”课程：

course.fast.ai

GitHub：

https://github.com/surmenok/GTSRB

fastai：

https://github.com/fastai/fastai

CNN with 3 spatial transformers：

Committee of CNNs：

https://www.sciencedirect.com/science/article/pii/S0893608012000524?via%3Dihub

Color-blob-based COSFIRE blters for object recognition：

dx.doi.org/10.1016/j.imavis.2016.10.006

https://towardsdatascience.com/resnet-for-traffic-sign-classification-with-pytorch-5883a97bbaa3

0 条评论

• ### 实战｜如何利用深度学习诊断心脏病？

摘要： 本文探讨的是开发一个能够对心脏磁共振成像（MRI）数据集图像中的右心室自动分割的系统。到目前为止，这主要是通过经典的图像处理方法来处理的。而现代深度学习...

• ### 5秒钟内将手绘网站线框图转换为可用的 HTML网站

你可以在 GitHub 上找到这个项目的代码：https://github.com/ashnkumar/sketch-code

• ### 世界公认最健康的作息时间表，今后就照这个来~

IT派 - {技术青年圈} 持续关注互联网、大数据、人工智能领域 ? ? 7:00 迎着清晨的阳光起床 一杯温水是早起之后的必需品，能让你...

• ### 【干货】PyTorch实例：用ResNet进行交通标志分类

【导读】本文是机器学习工程师Pavel Surmenok撰写的一篇技术博客，用Pytorch实现ResNet网络，并用德国交通标志识别基准数据集进行实验。文中分...

• ### 谷歌人工智能“即时”预测局部降水模式

谷歌希望利用人工智能和机器学习来快速预测当地天气。在一篇论文和附带的博客文章中，这家科技巨头详细介绍了一个人工智能系统，该系统利用卫星图像生成“几乎是瞬间”的高...

• ### 在ubuntu使用apt install的fastqc是有bug的

所以我就去了我的生物信息学常见1000个软件的安装代码：https://www.jianshu.com/p/ae28e8e3e9f5 找到了fastqc软件下载...

• ### 如何到top5%？NLP文本分类和情感分析竞赛总结

笔者主要方向是KBQA，深深体会到竞赛是学习一个新领域最好的方式，这些比赛总的来说都属于文本分类领域，因此最近打算一起总结一下。

• ### Java EE基本框架（Struts2+Spring+MyBatis三层，Struts MVC）之间的关系

一个JavaEE的项目，页面用JSP，后台用了Struts2+Spring+MyBatis，数据库用的是Oracle，这么多技术名词，他们之间的关系如何，整体是...

• ### 如何到top5%？NLP文本分类和情感分析竞赛总结

笔者主要方向是KBQA，深深体会到竞赛是学习一个新领域最好的方式，这些比赛总的来说都属于文本分类领域，因此最近打算一起总结一下。