主要内容

训练卷积神经网络用于回归

这个例子展示了如何使用卷积神经网络拟合回归模型来预测手写体数字的旋转角度。

卷积神经网络(CNNs,或ConvNets)是深度学习的基本工具,特别适合分析图像数据。例如,您可以使用cnn对图像进行分类。为了预测连续的数据,比如角度和距离,可以在网络的末端包含一个回归层。

该示例构建了一个卷积神经网络体系结构,训练了一个网络,并使用训练过的网络预测旋转手写数字的角度。这些预测对于光学字符识别是很有用的。

可选地,您可以使用imrotate(图像处理工具箱™)旋转图像,和箱线图(统计和机器学习工具箱™)创建残差箱线图。

加载数据

该数据集包含了手写数字合成图像以及相应的角度(以度为单位),每个图像都被旋转。

将训练和验证图像加载为4-D数组digitTrain4DArrayData而且digitTest4DArrayData.输出YTrain而且YValidation旋转角度的单位是度。训练和验证数据集各包含5000张图像。

[XTrain ~, YTrain] = digitTrain4DArrayData;[XValidation ~, YValidation] = digitTest4DArrayData;

显示20个随机训练图像使用imshow

numTrainImages =元素个数(YTrain);图idx = randperm(numTrainImages,20);i = 1:元素个数(idx)次要情节(4、5、i) imshow (XTrain (:,:,:, idx(我)))结束

图中包含20个轴对象。坐标轴对象1包含一个image类型的对象。Axes对象2包含一个image类型的对象。Axes对象3包含一个image类型的对象。Axes对象4包含一个image类型的对象。Axes对象5包含一个image类型的对象。Axes对象6包含一个image类型的对象。Axes对象7包含一个image类型的对象。Axes对象8包含一个image类型的对象。Axes对象9包含一个image类型的对象。 Axes object 10 contains an object of type image. Axes object 11 contains an object of type image. Axes object 12 contains an object of type image. Axes object 13 contains an object of type image. Axes object 14 contains an object of type image. Axes object 15 contains an object of type image. Axes object 16 contains an object of type image. Axes object 17 contains an object of type image. Axes object 18 contains an object of type image. Axes object 19 contains an object of type image. Axes object 20 contains an object of type image.

检查数据归一化

在训练神经网络时,它通常有助于确保数据在网络的所有阶段都是规范化的。规范化有助于使用梯度下降稳定和加速网络训练。如果您的数据没有很好地伸缩,那么损失可能会变成在训练过程中,网络参数会发生发散。规范化数据的常用方法包括缩放数据,使其范围变为[0,1],或者使其均值为0,标准差为1。可以对以下数据进行规范化:

  • 输入数据。在将预测器输入到网络之前对其进行规范化处理。在本例中,输入图像已经规范化到[0,1]的范围。

  • 层输出。通过使用批处理归一化层,可以归一化每个卷积层和完全连接层的输出。

  • 响应。如果使用批处理归一化层对网络末端的层输出进行归一化,则在训练开始时对网络的预测进行归一化。如果反应的规模与这些预测有很大的不同,那么网络训练可能无法收敛。如果你的反应尺度很差,那么试着将其规范化,看看网络训练是否有所改善。如果在训练前对响应进行归一化,那么必须对训练过的网络的预测进行转换,以获得原始响应的预测。

绘制响应的分布图。响应(旋转角度以度为单位)在-45到45之间近似均匀分布,不需要归一化就能很好地工作。在分类问题中,输出是类概率,它总是归一化的。

图直方图(YTrain)轴ylabel (“计数”)包含(旋转角度的

图中包含一个axes对象。axis对象包含一个直方图类型的对象。

通常,数据不需要完全规范化。但是,如果你在这个例子中训练网络来预测100 * YTrainYTrain + 500而不是YTrain,那么损失就变成了训练开始时网络参数发散。这些结果即使发生在网络预测之间唯一的区别aY + b网络预测Y是最终全连接层的权重和偏差的简单缩放。

如果输入或响应的分布非常不均匀或倾斜,您还可以在训练网络之前对数据执行非线性转换(例如,取对数)。

创建网络层

为了解决回归问题,创建网络的层,并在网络的末端包含一个回归层。

第一层定义输入数据的大小和类型。输入图像是28乘28乘1。创建一个与训练图像相同大小的图像输入层。

网络的中间层定义了网络的核心架构,大部分的计算和学习都发生在这里。

最后一层定义输出数据的大小和类型。对于回归问题,一个完全连接的层必须先于网络末端的回归层。创建一个大小为1的完全连接输出层和一个回归层。

将所有图层组合在一起数组中。

layers = [imageInputLayer([28 28 1]) convolution2dLayer(3,8,“填充”“相同”(2)“步”2) convolution2dLayer(16日“填充”“相同”(2)“步”32岁的,2)convolution2dLayer (3“填充”“相同”) reluLayer卷积2dlayer (3,32,“填充”“相同”) batchNormalizationLayer reluLayer dropoutLayer(0.2) fullconnectedlayer (1) regressionLayer];

列车网络的

创建网络培训选项。训练30次。设置初始学习率为0.001,并在20个周期后降低学习率。通过指定验证数据和验证频率来监控训练过程中的网络准确性。该软件在训练数据上训练网络,并在训练期间定期计算验证数据的准确性。验证数据不用于更新网络权重。打开训练进度图,并关闭命令窗口输出。

miniBatchSize = 128;validationFrequency =地板(元素个数(YTrain) / miniBatchSize);选择= trainingOptions (“个”...“MiniBatchSize”miniBatchSize,...“MaxEpochs”30岁的...“InitialLearnRate”1 e - 3,...“LearnRateSchedule”“分段”...“LearnRateDropFactor”, 0.1,...“LearnRateDropPeriod”, 20岁,...“洗牌”“every-epoch”...“ValidationData”{XValidation, YValidation},...“ValidationFrequency”validationFrequency,...“阴谋”“训练进步”...“详细”、假);

使用以下命令创建网络trainNetwork.该命令使用兼容的图形处理器。使用GPU需要并行计算工具箱™和支持的GPU设备。有关支持的设备的信息,请参见GPU计算的需求(并行计算工具箱).否则,trainNetwork使用CPU。

网= trainNetwork (XTrain、YTrain层,选择);

{

中包含的网络体系结构的详细信息的属性

网层
带有图层的图层数组:1' imageinput'图像输入28x28x1图像与'zerocenter'归一化2' conv_1'二维卷积8 3x3x1卷积与stride[1 1]和填充'same' 3 'batchnorm_1' Batch归一化批处理归一化8通道4 'relu_1' ReLU ReLU 5 'avgpool2d_1'二维平均归一化2x2平均归一化与stride[2 2]和填充[0 0 0 0]6 'conv_2'二维卷积16 3x3x8卷积与stride[1 1]和填充'same' 7 'batchnorm_2' Batch归一化批处理归一化与16channel8 'relu_2' ReLU ReLU 9 'avgpool2d_2' 2d Average Pooling 2x2 Average Pooling with stride[2 2]和padding [0 0 0 0] 10 'conv_3' 2d Convolution 32 3x3x16 convolutions with stride[1 1]和padding 'same' 11 'batchnorm_3' Batch Normalization Batch Normalization with 32 channels 12 'relu_3' ReLU ReLU 13 'conv_4' 2d Convolution 32 3x3x32 convolutions with stride[1 1]和padding 'same' 14 'batchnorm_4' Batch Normalization Batch Normalization with 32 channels 15 'relu_4' ReLU ReLU 16'dropout' dropout' 20% dropout' 17 'fc'全连接1全连接层18 'regressionoutput'回归输出均方误差响应' response '

测试网络

通过评估验证数据的准确性来测试网络的性能。

使用预测对验证图像的旋转角度进行预测。

YPredicted =预测(净,XValidation);

评估性能

通过计算来评估模型的性能:

  1. 在可接受的误差范围内的预测的百分比

  2. 预测和实际旋转角度的均方根误差(RMSE)

计算预测的旋转角度与实际旋转角度之间的预测误差。

predictionError = YValidation - yexpected;

从真实角度计算可接受误差范围内的预测数量。将阈值设置为10度。计算在此阈值内的预测百分比。

用力推= 10;numright = sum(abs(predictionError) < thr);numValidationImages =元素个数(YValidation);= numCorrect / numValidationImages准确性
精度= 0.9672

使用均方根误差(RMSE)来测量预测和实际旋转角度之间的差异。

广场= predictionError。^ 2;rmse =√意味着(广场))
rmse =4.6507

可视化预测

在散点图中可视化预测。把预测值与真实值画出来。

图散射(YPredicted YValidation,“+”)包含(“预测价值”) ylabel (“真正的价值”)举行Plot ([-60 60], [-60 60],“r——”

图中包含一个axes对象。axis对象包含两个类型为scatter、line的对象。

正确的数字旋转

您可以使用图像处理工具箱中的函数来拉直数字并一起显示它们。旋转49个样本数字根据他们的预测旋转角度使用imrotate(图像处理工具箱)。

idx = randperm (numValidationImages, 49);i = 1:元素个数(idx)图像= XValidation (:,:,:, idx(我));predictedAngle = YPredicted (idx (i));imagesRotated (::,:, i) = imrotate(形象,predictedAngle,“双三次的”“作物”);结束

显示原始数字及其修正旋转。您可以使用蒙太奇(图像处理工具箱)将数字一起显示在单个图像中。

figure subplot(1,2,1) montage(XValidation(:,:,:,idx)) title(“原始”副情节(1,2,2)蒙太奇(imagesrotate)标题(“纠正”

图中包含2个轴对象。标题为Original的axis对象1包含一个类型为image的对象。标题为Corrected的Axes对象2包含一个类型为image的对象。

另请参阅

|

相关的话题

Baidu
map