使用实验管理器进行音频迁移学习
此示例展示了如何配置一个实验,该实验将多个预训练网络应用于使用迁移学习的语音命令识别任务时,比较其性能。它突出实验管理器(深度学习工具箱)的能力来调优超参数,并使用内置和用户定义的度量轻松比较不同预训练网络之间的结果。
Audio Toolbox™为音频处理提供了各种预训练的网络,每个网络都由不同的体系结构组成,需要不同的数据预处理。这些差异导致了各种网络的准确性、速度和规模之间的权衡。实验管理器组织训练实验的结果,以突出每个网络的长处和弱点,以便您可以选择最适合您的约束条件的网络。
的性能比较YAMNet而且VGGish预先训练的网络,以及从头训练的定制设计的网络。看到深层网络设计师(深度学习工具箱)以探索Audio Toolbox™支持的其他预训练网络选项。
在本例中,您将下载谷歌语音命令数据集[1]和预先训练的网络,并将它们存储在临时目录中(如果它们还不存在的话)。数据集占用1.96 GB的磁盘空间,网络总共占用470 MB。
开放实验管理器
的方法加载示例开放的例子按钮。这将在MATLAB编辑器的实验管理器中打开项目。
内置训练实验由描述、超参数表、设置函数和度量函数集合组成,用于评估实验结果。有关更多信息,请参见配置内置训练实验(深度学习工具箱).
的描述字段包含实验的文本描述。
的Hyperparameters节指定了用于实验的策略(穷举扫描)和超参数值。当运行实验时,实验管理器使用超参数表中指定的超参数值的每个组合来训练网络。这个例子演示了如何测试不同的网络类型。定义一个hyperparameter,网络,以表示存储为字符串的网络名称。
的设置函数字段包含配置训练数据、网络体系结构和实验训练选项的主要函数的名称。setup函数的输入是一个包含超参数表字段的结构。setup函数将训练数据、网络架构和训练参数作为输出返回。这已经为您实现了。
的指标列表使您能够定义自己的自定义指标,以便在训练实验的不同试验之间进行比较。在本示例的后面,将为您定义两个示例自定义度量函数。实验经理在每个试验中训练的网络上运行列出的每个指标。这里列出了本例中为您定义的指标。您打算使用的任何其他自定义指标都必须在此部分中列出。
定义设置函数
在本例中,设置函数下载数据集,选择所需的网络,执行必要的数据预处理,并设置网络训练选项。该函数的输入是一个结构,其中包含在experimental Manager界面中定义的每个超参数的字段。在设置函数在本例中,输入变量被命名参数个数
输出变量被命名trainingData
,层
,选项
分别表示训练数据、网络结构和训练参数。的关键步骤设置函数下面解释这个例子。在MATLAB中打开示例查看完整的定义compareNetSetup
的名称设置函数本例中使用。
下载和提取数据
要加快示例的速度,请打开compareNetSetup
和切换加速
旗帜真正的
.这减少了数据集的大小,以快速测试实验的基本功能。
加速= false;
辅助函数setupDatastores
下载谷歌语音命令数据集[1],选择网络识别的命令,并将数据随机划分为训练和验证数据存储。
[adsTrain, adsValidation] = setupDatastores(加速);
选择所需的网络和预处理数据
最初根据超参数表中定义的网络类型所需的预处理转换数据存储参数个数。网络
.辅助函数extractSpectrogram
将输入数据处理为每个网络类型所期望的格式。辅助函数getLayers
返回一个layerGraph
(深度学习工具箱)对象,该对象表示所需网络的体系结构。
tdsTrain =变换(adsTrain @ (x) extractSpectrogram (x, params.Network));tdsValidation =变换(adsValidation @ (x) extractSpectrogram (x, params.Network));
层= getLayers(类、classWeights numClasses,网络名);
正确设置了数据存储之后,将数据读入trainingData
而且validationData
变量。
trainingData = readall (tdsTrain UseParallel = canUseParallelPool);validationData = readall (tdsValidation UseParallel = canUseParallelPool);
validationData =表(validationData (: 1), adsValidation.Labels);trainingData =表(trainingData (: 1), adsTrain.Labels);
设置培训选项
设置训练参数trainingOptions
(深度学习工具箱)对象进选项
输出变量。使用Adam优化器训练网络的最大时间为30个时间,耐心为8个时间。设置ExecutionEnvironment
字段设置为“auto”以使用可用的GPU。如果不使用GPU,训练可能会非常耗时。
maxEpochs = 30;miniBatchSize = 256;validationFrequency =地板(元素个数(TTrain) / miniBatchSize);选择= trainingOptions (“亚当”,...GradientDecayFactor = 0.7,...InitialLearnRate =参数。LearnRate,...MaxEpochs = MaxEpochs,...MiniBatchSize = MiniBatchSize,...洗牌=“every-epoch”,...情节=“训练进步”,...Verbose = false,...ValidationData = ValidationData,...ValidationFrequency = ValidationFrequency,...ValidationPatience = 10,...LearnRateSchedule =“分段”,...LearnRateDropFactor = 0.2,...LearnRateDropPeriod =圆(maxEpochs / 3),...ExecutionEnvironment =“汽车”);
自定义指标
实验管理器使您能够定义自定义度量函数,以评估在每次试验中训练的网络的性能。默认情况下计算精度和损失等基本指标。在本例中,您将比较每个模型的大小,因为在将深度神经网络部署到现实应用程序时,内存使用情况是一个重要的度量指标。
自定义度量函数必须接受一个输入参数trialInfo
哪个结构包含字段trainedNetwork
,trainingInfo
,参数
.
trainedNetwork
是SeriesNetwork
(深度学习工具箱)对象或DAGNetwork
(深度学习工具箱)对象返回的trainNetwork
(深度学习工具箱)函数。trainingInfo
控件返回的训练信息的结构是否为trainNetwork
(深度学习工具箱)函数。参数
是否具有超参数表中的字段的结构
度量函数必须返回结果表中显示的标量数、逻辑输出或字符串。在这个实验中为您定义的自定义指标如下所示:
sizeMB
计算分配给存储网络的内存(以兆为单位)numLearnableParams
计算每个模型中可学习参数的数量numIters
计算每个网络在击中任何一个之前所训练的小批量的数量MaxEpochs
或违反ValidationPatience
参数trainingOptions
对象。
运行实验
按“运行”在实验管理器应用程序的顶部窗格运行实验。通过切换模式选项,您可以选择按顺序、同时或批量运行每个试验。在这个实验中,试验是按顺序进行的。
评估结果
当实验结束时,每次试验的结果将出现,并且指标将以表格格式显示。进度条显示了在违反耐心参数之前每个网络训练了多少个周期MaxEpochs
.
通过将鼠标悬停在列名单元格的右侧并单击出现的箭头,可以按每列中的条目对表进行排序。单击右上角的表格图标,选择要显示或隐藏的列。要首先根据精确度比较网络,请按降序对Validation accuracy表进行排序。
在精确度方面,Yamnet
网络表现最好,其次是VGGish
最后是自定义网络。但是,Elapsed Time列显示了这一点Yamnet
训练时间最长。要比较这些网络的大小,请按sizeMB列对表进行排序。
自定义网络最小,Yamnet
是大了几个数量级,还是VGGish
是最大的。
这些结果突出了不同网络设计之间的权衡。的Yamnet
网络在分类任务中表现最好,但需要花费更多的训练时间和中等规模的内存消耗。的VGGish
网络在准确性方面表现略差,但需要超过20倍的内存YAMNet
.最后,自定义网络的准确性最差,但使用的内存最少。
请注意,尽管Yamnet
而且VGGish
是经过预先训练的网络,自定义网络收敛速度最快。看看numites列,自定义网络需要最多的批处理迭代才能收敛,因为它是从头开始学习的。但是,由于自定义网络比深度预训练模型要小得多,也浅得多,所以这些批更新的处理速度要快得多,因此总体训练时间减少了。
要从任何试验中保存一个训练过的网络,右键单击结果表中的相应行并选择出口训练网络.
要进一步分析任何单个试验,单击相应的行,并在审查结果在顶部窗格中的选项卡中,您可以选择调出一个训练进度的图,或结果训练模型的混淆矩阵。的混淆矩阵如下所示Yamnet
模型来自实验的第2次试验。
该模型在区分“off”和“up”这对命令以及“no”和“go”这对命令上最困难,尽管在所有类别中,精确度通常是一致的。此外,该模型在预测“是”命令方面非常有信心,因为该类的假阳性率仅为0.4%。
参考文献
[1]监狱长P。《语音指令:单词语音识别的公共数据集》,2017年。可以从https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz.版权2017年谷歌。语音命令数据集采用创作共用属性4.0许可,可在这里获得:https://creativecommons.org/licenses/by/4.0/legalcode.