主要内容

使用自定义训练循环训练网络

这个例子展示了如何训练一个使用自定义学习速率计划对手写数字进行分类的网络。

你可以训练大多数类型的神经网络使用trainNetwork而且trainingOptions功能。如果trainingOptions函数没有提供所需的选项(例如,自定义学习速率计划),那么您可以使用dlarray而且dlnetwork对象自动区分。举个例子,展示如何使用trainNetwork功能,请参阅使用预训练网络的迁移学习

训练深度神经网络是一项优化任务。把神经网络看作一个函数 f X θ ,在那里 X 是网络输入,和 θ 是否设置了可学习的参数,可以进行优化 θ 根据训练数据,最小化了一些损失值。例如,优化可学习参数 θ 这样对于一个给定的输入 X 有相应的目标 T ,它们将预测之间的误差最小化 Y f X θ 而且 T

所使用的损失函数取决于任务的类型。例如:

  • 对于分类任务,可以最小化预测和目标之间的交叉熵误差。

  • 对于回归任务,您可以最小化预测和目标之间的均方误差。

你可以使用梯度下降来优化目标:减少损失 l 通过迭代更新可学习参数 θ 通过使用损失相对于可学习参数的梯度,逐步向最小值迈进。梯度下降算法通常通过使用表单更新步骤的变体来更新可学习参数 θ t + 1 θ t - ρ l ,在那里 t 为迭代次数, ρ 是学习率,和 l 表示梯度(损失对可学习参数的导数)。

方法训练网络对手写数字进行分类基于时间的衰减学习率计划:对于每一次迭代,求解器使用给定的学习率 ρ t ρ 0 1 + k t ,在那里t为迭代次数, ρ 0 是初始学习率,和k是衰减的。

负荷训练数据

方法将数字数据加载为图像数据存储imageDatastore函数并指定包含图像数据的文件夹。

dataFolder = fullfile (toolboxdir (“nnet”),“nndemos”“nndatasets”“DigitDataset”);imd = imageDatastore (dataFolder,...IncludeSubfolders = true,...LabelSource =“foldernames”);

将数据划分为训练集和验证集。方法预留10%的数据用于验证splitEachLabel函数。

[imdsTrain, imdsValidation] = splitEachLabel (imd, 0.9,“随机”);

本例中使用的网络需要输入大小为28 × 28 × 1的图像。若要自动调整训练图像的大小,请使用增强图像数据存储。指定要在训练图像上执行的附加增强操作:在水平轴和垂直轴上随机地将图像平移到5个像素。数据增强有助于防止网络过拟合和记忆训练图像的精确细节。

inputSize = [28 28 1];pixelRange = [-5 5];imageAugmenter = imageDataAugmenter (...RandXTranslation = pixelRange,...RandYTranslation = pixelRange);augimdsTrain = augmentedImageDatastore (inputSize (1:2), imdsTrain, DataAugmentation = imageAugmenter);

若要自动调整验证图像的大小而不执行进一步的数据增强,请使用增强图像数据存储而不指定任何额外的预处理操作。

augimdsValidation = augmentedImageDatastore (inputSize (1:2), imdsValidation);

确定训练数据中的类的数量。

类=类别(imdsTrain.Labels);numClasses =元素个数(类);

定义网络

为图像分类定义网络。

  • 对于图像输入,指定输入大小与训练数据匹配的图像输入层。

  • 不归一化图像输入,设置归一化输入层的选项“没有”

  • 指定三个卷积-batchnorm- relu块。

  • 将输入填充到卷积层,以便输出具有相同的大小填充选项“相同”

  • 对于第一个卷积层,指定20个大小为5的过滤器。对于剩余的卷积层,指定20个大小为3的过滤器。

  • 对于分类,指定一个完全连接层,其大小与类的数量匹配

  • 为了将输出映射到概率,需要包含一个softmax层。

当使用自定义训练循环训练网络时,不要包含输出层。

layers = [imageInputLayer(inputSize, normalized = .“没有”) convolution2dLayer(5、20、填充=“相同”) batchNormalizationLayer reluLayer convolution2dLayer(3,20,Padding= .“相同”) batchNormalizationLayer reluLayer convolution2dLayer(3,20,Padding= .“相同”) batchNormalizationLayer reluLayer fulllyconnectedlayer (numClasses) softmaxLayer];

创建一个dlnetwork对象。

净= dlnetwork(层)
层:[12×1 nnet.cnn.layer.Layer] Connections: [11×2 table] Learnables: [14×3 table] State: [6×3 table] InputNames: {'imageinput'} OutputNames: {'softmax'} Initialized: 1用summary查看摘要。

定义模型损失函数

训练深度神经网络是一项优化任务。把神经网络看作一个函数 f X θ ,在那里 X 是网络输入,和 θ 是否设置了可学习的参数,可以进行优化 θ 根据训练数据,最小化了一些损失值。例如,优化可学习参数 θ 这样对于一个给定的输入 X 有相应的目标 T ,它们将预测之间的误差最小化 Y f X θ 而且 T

创建函数modelLoss,列于损失函数模型部分,它将作为输入dlnetwork对象,它是具有相应目标的输入数据的小批处理,并返回损耗、损耗相对于可学习参数的梯度和网络状态。

指定培训选项

训练10个周期,小批次大小为128。

numEpochs = 10;miniBatchSize = 128;

指定SGDM优化的选项。指定初始学习速率为0.01,衰减为0.01,动量为0.9。

initialLearnRate = 0.01;衰变= 0.01;动量= 0.9;

火车模型

创建一个minibatchqueue对象,该对象在训练期间处理和管理小批量图像。为每个mini-batch:

  • 使用自定义的小批量预处理功能preprocessMiniBatch(在本例末尾定义)将标签转换为单热编码变量。

  • 用尺寸标签格式化图像数据“SSCB”(空间,空间,渠道,批处理)。默认情况下,minibatchqueue对象将数据转换为dlarray具有基础类型的对象.不要格式化类标签。

  • 如果有GPU,请使用GPU进行训练。默认情况下,minibatchqueue对象将每个输出转换为gpuArray如果有可用的GPU。使用GPU需要并行计算工具箱™和支持的GPU设备。有关支持的设备的信息,请参见GPU计算的需求(并行计算工具箱)

兆贝可= minibatchqueue (augimdsTrain,...MiniBatchSize = MiniBatchSize,...MiniBatchFcn = @preprocessMiniBatch,...MiniBatchFormat = [“SSCB”""]);

初始化SGDM求解器的速度参数。

速度= [];

计算训练进度监视器的总迭代次数。

numObservationsTrain =元素个数(imdsTrain.Files);numIterationsPerEpoch = ceil(numObservationsTrain / miniBatchSize);numIterations = numEpochs * numIterationsPerEpoch;

初始化TrainingProgressMonitor对象。因为计时器在创建监视器对象时开始,所以要确保创建的对象接近训练循环。

监控= = trainingProgressMonitor(指标“损失”信息= (“时代”“LearnRate”),包含=“迭代”);

使用自定义训练循环训练网络。对于每个历元,洗牌数据并遍历小批数据。为每个mini-batch:

  • 评估模型损失,梯度和状态使用dlfeval而且modelLoss功能,并更新网络状态。

  • 为基于时间的衰减学习率时间表确定学习率。

  • 更新网络参数sgdmupdate函数。

  • 更新训练进度监视器中的损耗、学习速率和epoch值。

  • 如果Stop属性为真,则停止。对象的Stop属性值TrainingProgressMonitor对象在单击Stop按钮时更改为true。

时代= 0;迭代= 0;循环遍历各个时代。epoch < numEpochs && ~monitor。停止epoch = epoch + 1;%洗牌数据。洗牌(兆贝可);在小批量上循环。hasdata(兆贝可)& & ~班长。停止迭代=迭代+ 1;读取小批数据。[X, T] =下一个(兆贝可);评估模型的梯度,状态和损失使用dlfeval和% modelLoss函数,并更新网络状态。(损失、渐变、状态)= dlfeval (@modelLoss,净,X, T);网=状态;为基于时间的衰减学习率时间表确定学习率。learnRate = initialLearnRate/(1 + decay*iteration);使用SGDM优化器更新网络参数。(净、速度)= sgdmupdate(净、渐变速度,learnRate动量);更新培训进度监视器。recordMetrics(监控、迭代损失=损失);updateInfo(监控、时代=时代LearnRate = LearnRate);班长。进度= 100 *迭代/numIterations;结束结束

测试模型

通过将验证集上的预测结果与真实标签进行比较,检验模型的分类精度。

训练后,对新数据进行预测不需要标签。创建minibatchqueue对象,该对象只包含测试数据的预测符:

  • 要忽略用于测试的标签,请将迷你批处理队列的输出数量设置为1。

  • 指定与培训时相同的小批大小。

  • 方法对预测器进行预处理preprocessMiniBatchPredictors函数,在示例末尾列出。

  • 对于数据存储的单个输出,指定小批处理格式“SSCB”(空间,空间,渠道,批处理)。

numOutputs = 1;mbqTest = minibatchqueue (augimdsValidation numOutputs,...MiniBatchSize = MiniBatchSize,...MiniBatchFcn = @preprocessMiniBatchPredictors,...MiniBatchFormat =“SSCB”);

循环使用小批次并对图像进行分类modelPredictions函数,在示例末尾列出。

欧美= modelPredictions(净、mbqTest、类);

评估分类的准确性。

tt = imdsValidation.Labels;精度=平均值(TTest == YTest)
精度= 0.9750

在困惑图表中可视化预测。

次图confusionchart (tt)

对角线上的大值表示对对应类的准确预测。非对角线上的大值表示对应类之间存在很强的混淆。

支持功能

损失函数模型

modelLoss函数接受一个dlnetwork对象,一个小批量的输入数据X与相应的目标T并返回损失,损失相对于可学习参数的梯度、网络状态。方法可自动计算梯度dlgradient函数。

函数(损失、渐变、状态)= modelLoss(净,X, T)通过网络转发数据。[Y,状态]=前进(净,X);计算交叉熵损失。损失= crossentropy (Y, T);计算相对于可学习参数的损失梯度。。梯度= dlgradient(损失、net.Learnables);结束

模型的预测函数

modelPredictions函数接受一个dlnetwork对象,一个minibatchqueue的输入数据兆贝可,并通过迭代所有数据来计算模型预测minibatchqueue对象。函数使用onehotdecode函数求出预测分数最高的班级。

函数Y = modelforecasts (net,mbq,classes) Y = [];在小批量上循环。hasdata(mbq) X = next(mbq);%进行预测。成绩=预测(净,X);解码标签并附加到输出。标签= onehotdecode(成绩、类1)';Y = [Y;标签);结束结束

小批量预处理功能

preprocessMiniBatch函数使用以下步骤预处理一小批预测器和标签:

  1. 方法对图像进行预处理preprocessMiniBatchPredictors函数。

  2. 从传入的单元格数组中提取标签数据,并沿着二次元连接到一个分类数组中。

  3. 一热编码类别标签到数字数组。对第一个维度进行编码会生成一个与网络输出形状匹配的编码数组。

函数[X, T] = preprocessMiniBatch (dataX人数()%预处理预测。X = preprocessMiniBatchPredictors (dataX);从单元格中提取标签数据并连接。猫(T = 2,人数({1:结束});单热编码标签。T = onehotencode (T, 1);结束

小批量预测器预处理函数

preprocessMiniBatchPredictors函数通过从输入单元格数组中提取图像数据并连接到数值数组来预处理小批预测器。对于灰度输入,在第四个维度上的连接将为每个图像添加第三个维度,以用作单通道维度。

函数X = preprocessMiniBatchPredictors (dataX)%连接。猫(X = 4, dataX{1:结束});结束

另请参阅

||||||||||

相关的话题

Baidu
map