前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Matlab 使用CNN拟合回归模型预测手写数字的旋转角度

Matlab 使用CNN拟合回归模型预测手写数字的旋转角度

作者头像
用户9925864
发布2022-07-27 08:46:58
1.4K0
发布2022-07-27 08:46:58
举报
文章被收录于专栏:算法工程师的学习日志

一个深度学习文档分享一下,很简单,但思路不错,在个人项目上也可以按照需求变化数据集来实现CNN回归计算。

加载数据

代码语言:javascript
复制
clc
close all
clear
%% 加载数据
%% 数据集包含手写数字的合成图像,以及每幅图像旋转的对应角度(以角度为单位)。
%% 使用digitTrain4DArrayData和digitTest4DArrayData将训练和验证图像加载为4D数组。
%% 输出YTrain和YValidation是以角度为单位的旋转角度。每个训练和验证数据集包含5000张图像。
[XTrain, ~, Ytrain] = digitTrain4DArrayData;
[XValidation, ~, YValidation] = digitTest4DArrayData;
%% 随机显示20张训练图像
numTrainImages = numel(Ytrain);
figure;
idx = randperm(numTrainImages, 20);
for i = 1 : numel(idx)
  subplot(4, 5, i);
  imshow(XTrain(:, :, :, idx(i)))
  drawnow
end

数据归一化

当训练神经网络时,确保你的数据在网络的所有阶段都是标准化的。归一化有助于使用梯度下降来稳定和加速网络训练。如果数据规模太小,那么损失可能会变成NaN,并且在培训期间网络参数可能会出现分歧。

标准化数据的常用方法包括重新标定数据,使其范围变为[0,1]或使其均值为0,标准差为1。

标准化以下数据:

1、输入数据。在将预测器输入到网络之前对数据进行规范化。

2、层输出。使用批处理规范化层对每个卷积和完全连接层的输出进行规范化。

3、响应。如果使用批处理规范化层对网络末端的层输出进行规范化,则在开始训练时对网络的预测进行规范化。

代码语言:javascript
复制
%% 绘制响应分布:在分类问题中,输出是类概率,类概率总是归一化的。
figure;
histogram(Ytrain)
axis tight
ylabel('Counts')
xlabel('Rotation Angle')

创建网络层

代码语言:javascript
复制
%% 创建网络层
%% 第一层定义输入数据的大小和类型。输入的图像大小为28×28×1。创建与训练图像大小相同的图像输入层。
%% 网络的中间层定义了网络的核心架构,大部分计算和学习都在这个架构中进行。
%% 最后一层定义输出数据的大小和类型。对于回归问题,全连接层必须先于网络末端的回归层。
layers = [
  imageInputLayer([28 28 1])
  batchNormalizationLayer
  reluLayer
  
  averagePooling2dLayer(2, 'Stride', 2)
  
  convolution2dLayer(3, 16, 'Padding', 'same')
  batchNormalizationLayer
  reluLayer
  
  averagePooling2dLayer(2, 'Stride', 2)
  
  convolution2dLayer(3, 32, 'Padding', 'same')
  batchNormalizationLayer
  reluLayer
  
  convolution2dLayer(3, 32, 'Padding', 'same')
  batchNormalizationLayer
  reluLayer
  
  dropoutLayer(0.2)
  fullyConnectedLayer(1)
  regressionLayer];

训练网络设置

使用 trainNetwork 创建网络。如果存在兼容的 GPU,此命令会使用 GPU。否则,trainNetwork 将使用 CPU。在 GPU 上进行训练需要具有 3.0 或更高计算能力的支持 CUDA® 的 NVIDIA® GPU。

代码语言:javascript
复制
%% 训练网络——Options
%% Train for 30 epochs 学习率0.001 在20个epoch后降低学习率。
%% 通过指定验证数据和验证频率,监控培训过程中的网络准确性。
%% 根据训练数据对网络进行训练,并在训练过程中定期对验证数据进行精度计算。
%% 验证数据不用于更新网络权重。打开训练进度图,并关闭命令窗口输出。
miniBatchSize = 128;
validationFrequency = floor(numel(Ytrain) / miniBatchSize);
options = trainingOptions('sgdm', ...
  'MiniBatchSize', miniBatchSize, ...
  'MaxEpochs', 30, ...
  'InitialLearnRate', 1e-3, ...
  'LearnRateSchedule', 'piecewise', ...
  'LearnRateDropFactor', 0.1, ...
  'LearnRateDropPeriod', 20, ...
  'Shuffle', 'every-epoch', ...
  'ValidationData', {XValidation, YValidation}, ...
  'ValidationFrequency', validationFrequency, ...
  'Plots', 'training-progress', ...
  'Verbose', false);

训练网络

代码语言:javascript
复制
net = trainNetwork(XTrain, Ytrain, layers, options);

预测结果

基于验证数据评估准确度来测试网络性能。使用 predict 预测验证图像的旋转角度。

代码语言:javascript
复制
YPredicted = predict(net,XValidation);

评估性能

通过计算以下值来评估模型性能:

代码语言:javascript
复制
predictionError = YValidation - YPredicted;

计算在实际角度的可接受误差界限内的预测值的数量。将阈值设置为 10 度。计算此阈值范围内的预测值的百分比。

代码语言:javascript
复制
thr = 10;
numCorrect = sum(abs(predictionError) < thr);
numValidationImages = numel(YValidation);
accuracy = numCorrect/numValidationImages

使用均方根误差 (RMSE) 来衡量预测旋转角度和实际旋转角度之间的差异。

代码语言:javascript
复制
squares = predictionError.^2;
rmse = sqrt(mean(squares))
代码语言:javascript
复制
accuracy =
    0.9584


rmse =
  single
    4.8987

显示原始数字以及校正旋转后的数字,使用 montage (Image Processing Toolbox) 将数字显示在同一个图像上。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2021-04-06,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 算法工程师的学习日志 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档