利用长短时记忆网络对心电信号进行分类
这个例子展示了如何使用深度学习和信号处理来分类来自PhysioNet 2017挑战赛的心跳心电图(ECG)数据。特别是,该示例使用了长短期记忆网络和时频分析。
有关使用GPU和并行计算工具箱™复制和加速此工作流的示例,请参见基于GPU加速的长短时记忆网络心电信号分类.
简介
心电图记录一个人在一段时间内的心电活动。医生使用心电图来直观地检测病人的心跳是正常还是不规则。
心房颤动(AFib)是一种不规则心跳,发生在心脏的上腔,即心房,与下腔,即心室的跳动不协调时。
本例使用了来自PhysioNet 2017挑战赛的心电图数据[1]、[2]、[3.]下载,网址为https://physionet.org/challenge/2017/.该数据由一组以300hz采样的心电信号组成,由一组专家分为四个不同的类别:正常(N)、AFib (a)、其他节律(O)和噪声记录(~)。这个例子展示了如何使用深度学习自动化分类过程。该程序探索了一个二元分类器,可以区分正常心电信号和显示AFib迹象的信号。
长短时记忆(LSTM)网络是一种适合于研究序列和时间序列数据的递归神经网络(RNN)。LSTM网络可以学习序列时间步之间的长期依赖关系。LSTM层(lstmLayer
(深度学习工具箱))可以看到正向的时间序列,而双向LSTM层(bilstmLayer
(深度学习工具箱))可以从正反两个方向看时间序列。本例使用了双向LSTM层。
这个例子展示了在解决人工智能(AI)问题时使用以数据为中心的方法的优势。使用原始数据训练LSTM网络的初步尝试给出了不合格的结果。使用提取的特征训练相同的模型体系结构可以显著提高分类性能。
为了加速训练过程,请在具有GPU的机器上运行此示例。如果您的机器有一个GPU和并行计算工具箱™,那么MATLAB®自动使用GPU进行训练;否则,它将使用CPU。
加载和检查数据
运行ReadPhysionetData
脚本,从PhysioNet网站下载数据,并生成一个mat文件(PhysionetData.mat
),以适当的格式载有心电讯号。下载数据可能需要几分钟。使用条件语句,只在以下情况下运行脚本PhysionetData.mat
在当前文件夹中不存在。
如果~ isfile (“PhysionetData.mat”) ReadPhysionetData结束负载PhysionetData
加载操作将两个变量添加到工作区:信号
而且标签
.信号
是一个保存心电信号的单元阵列。标签
是一个类别数组,它保存着信号的相应的地面真值标签。
信号(1:5)
ans =5×1单元格数组{1×9000 double} {1×9000 double} {1×18000 double} {1×9000 double} {1×18000 double}
标签(1:5)
ans =5×1分类N N N a a
使用总结
函数查看数据中包含多少AFib信号和Normal信号。
总结(标签)
A 738 n 5050
生成信号长度的直方图。大多数信号的样本长度为9000个。
L = cellfun(@length,Signals);h =直方图(L);xticks (0:3000:18000);xticklabels (0:3000:18000);标题(“信号长度”)包含(“长度”) ylabel (“数”)
想象每一类信号的一个片段。AFib的心跳间隔不规律,而正常的心跳有规律。AFib心跳信号通常也缺少P波,P波在正常心跳信号的QRS复合体之前跳动。正常信号图显示P波和QRS复波。
正常=信号{1};aFib = Signals{4};次要情节(2,1,1)情节(正常)标题(“正常的节奏”) xlim([4000,5200]) ylabel(“振幅(mV)”)文本(4330、150、“P”,“HorizontalAlignment”,“中心”)文本(4370、850、“QRS”,“HorizontalAlignment”,“中心”副图(2,1,2)副图(aFib)心房纤维性颤动的) xlim([4000,5200])“样本”) ylabel (“振幅(mV)”)
为培训准备数据
在培训期间,trainNetwork
函数将数据分成小批。然后,函数在同一个小批处理中填充或截断信号,使它们都具有相同的长度。过多的填充或截断会对网络的性能产生负面影响,因为网络可能会根据添加或删除的信息错误地解释信号。
若要避免过度填充或截断,请应用segmentSignals
功能,所以它们都是9000个样本长。该函数忽略小于9000个样本的信号。如果一个信号有超过9000个样本,segmentSignals
将其分解为尽可能多的9000个样本段,并忽略其余的样本。例如,一个有18500个样本的信号变成两个9000个样本的信号,剩下的500个样本被忽略。
[Signals,Labels] = segmentSignals(Signals,Labels);
的前五个元素信号
数组,以验证每个条目现在有9000个样本长度。
信号(1:5)
ans =5×1单元格数组{1×9000 double} {1×9000 double} {1×9000 double} {1×9000 double} {1×9000 double}
第一次尝试:用原始信号数据训练分类器
要设计分类器,请使用上一节生成的原始信号。将信号分为训练集和测试集,分别训练分类器和测试集,测试分类器在新数据上的准确性。
使用总结
函数表示AFib信号与正常信号的比值为78:4937,约为1:7。
总结(标签)
A 718 n 4937
因为大约7/8的信号是Normal,分类器会知道它可以通过简单地将所有信号分类为Normal来获得较高的精度。为了避免这种偏差,可以通过复制数据集中的AFib信号来增强AFib数据,以便有相同数量的Normal和AFib信号。这种重复,通常称为过采样,是深度学习中使用的数据增强的一种形式。
根据他们的类别来分解信号。
afibX = Signals(Labels==“一个”);afibY = Labels(Labels==“一个”);normalX =信号(标签==“N”);normalY =标签(标签==“N”);
下一步,使用dividerand
将每一类目标随机分为训练集和测试集。
[trainIndA,~,testIndA] = dividerand(718,0.9,0.0,0.1);[trainIndN,~,testIndN] = dividerand(4937,0.9,0.0,0.1);XTrainA = afibX(trainIndA);YTrainA = afibY(trainIndA);XTrainN = normalX(trainIndN);YTrainN = normalY(trainIndN);XTestA = afibX(testIndA);YTestA = afibY(testIndA);XTestN = normalX(testIndN);YTestN = normalY(testIndN);
目前有646个AFib信号和4443个Normal信号用于训练。为了在每个类中实现相同数量的信号,使用第一个4438正常信号,然后使用repmat
重复第一个634 AFib信号7次。
测试中有72个AFib信号和494个Normal信号。使用前490个正常信号,然后使用repmat
重复前70个AFib信号7次。默认情况下,神经网络在训练前随机打乱数据,确保相邻的信号不都有相同的标签。
XTrain = [repmat(XTrainA(1:634),7,1);XTrainN (1:4438)];YTrain = [repmat(YTrainA(1:634),7,1);YTrainN (1:4438)];XTest = [repmat(XTestA(1:70),7,1);XTestN (1:490)];YTest = [repmat(YTestA(1:70),7,1);YTestN (1:490);];
Normal和AFib信号之间的分布现在在训练集和测试集都是均匀平衡的。
总结(YTrain)
A 4438 n 4438
总结(欧美)
A 490 n 490
定义LSTM网络架构
LSTM网络可以学习序列数据时间步之间的长期依赖关系。本例使用双向LSTM层bilstmLayer
,因为它从正向和反向两个方向查看序列。
因为每个输入信号都有一个维度,所以将输入大小指定为大小为1的序列。指定一个输出大小为100的双向LSTM层,并输出序列的最后一个元素。该命令指示双向LSTM层将输入的时间序列映射为100个特征,然后为全连接层准备输出。最后,通过包含一个大小为2的完全连接层,然后是一个softmax层和一个分类层,指定两个类。
图层= [...sequenceInputLayer (1) bilstmLayer (100“OutputMode”,“最后一次”) fullyConnectedLayer(2)
2”BiLSTM BiLSTM 100个隐藏单元3”全连接2全连接层4”Softmax Softmax 5”分类输出crossentropyex
接下来为分类器指定训练选项。设置“MaxEpochs”
到10,允许网络在训练数据中进行10次传递。一个“MiniBatchSize”
Of 150指示网络一次观察150个训练信号。一个“InitialLearnRate”
0.01有助于加速训练过程。指定一个“SequenceLength”
将信号分解成更小的片段,这样机器就不会因为一次查看太多的数据而耗尽内存。设置”GradientThreshold
’到1,通过防止梯度变得太大来稳定训练过程。指定“阴谋”
作为“训练进步”
生成图形,显示随着迭代次数增加的训练进度。集“详细”
来假
抑制与图中显示的数据相对应的表输出。如果你想看这张桌子,就开始吧“详细”
来真正的
.
本例使用自适应力矩估计(ADAM)求解器。ADAM与LSTMs等rnn相比,默认的随机动量梯度下降(SGDM)求解器的性能更好。
options = trainingOptions(“亚当”,...“MaxEpochs”10...“MiniBatchSize”, 150,...“InitialLearnRate”, 0.01,...“SequenceLength”, 1000,...“GradientThreshold”, 1...“ExecutionEnvironment”,“汽车”,...“阴谋”,“训练进步”,...“详细”、假);
培训LSTM网络
使用指定的训练选项和层架构训练LSTM网络trainNetwork
.由于训练集很大,训练过程可能需要几分钟。
net = trainNetwork(XTrain,YTrain,layers,options);
训练进度图的顶部子图表示训练精度,即每个小批上的分类精度。当训练成功进行时,这个值通常会增加到100%。底部的子图显示了训练损失,这是每个小批上的交叉熵损失。当训练成功进行时,这个值通常趋近于零。
如果训练不收敛,图可能在值之间振荡,而没有某种向上或向下的趋势。这种振荡意味着训练精度没有提高,训练损失没有减少。这种情况可能在训练开始时就发生,或者在训练精度得到一些初步改善后,图可能趋于稳定。在许多情况下,改变培训选项可以帮助网络实现收敛。减少MiniBatchSize
或减少InitialLearnRate
可能会导致更长的训练时间,但它可以帮助网络更好地学习。
分类器的训练准确率在50%到60%之间波动,在10个周期结束时,它已经花费了几分钟的训练时间。
可视化训练和测试准确性
计算训练精度,它表示分类器对受训信号的准确率。首先,对训练数据进行分类。
trainPred = class (net,XTrain,“SequenceLength”, 1000);
在分类问题中,混淆矩阵被用来可视化分类器在一组真实值已知的数据上的性能。目标类是信号的地真标签,输出类是网络分配给信号的标签。坐标轴标签表示类标签,AFib (A)和Normal (N)。
使用confusionchart
命令来计算测试数据预测的总体分类精度。指定“RowSummary”
作为“row-normalized”
在行摘要中显示真阳性率和假阳性率。同时,指定“ColumnSummary”
作为“column-normalized”
在列摘要中显示阳性预测值和错误发现率。
LSTMAccuracy = sum(trainPred == YTrain)/numel(YTrain)*100
LSTMAccuracy = 61.7283
图confusionchart (YTrain trainPred,“ColumnSummary”,“column-normalized”,...“RowSummary”,“row-normalized”,“标题”,“LSTM困惑图”);
现在对同一网络的测试数据进行分类。
testPred = class (net,XTest,“SequenceLength”, 1000);
计算测试精度,并将分类性能可视化为混淆矩阵。
LSTMAccuracy = sum(testPred == YTest)/numel(YTest)*100
LSTMAccuracy = 66.2245
图confusionchart(欧美、testPred“ColumnSummary”,“column-normalized”,...“RowSummary”,“row-normalized”,“标题”,“LSTM困惑图”);
第二次尝试:使用特征提取提高性能
从数据中提取特征有助于提高分类器的训练和测试精度。为了决定提取哪些特征,本示例采用了一种计算时频图像(如谱图)的方法,并使用它们训练卷积神经网络(CNNs) [4]、[5]。
将每种信号的光谱图可视化。
Fs = 300;图次要情节(2,1,1);pspectrum(正常,fs,的谱图,“TimeResolution”, 0.5)标题(“正常信号”次要情节(2,1,2);pspectrum (aFib fs,的谱图,“TimeResolution”, 0.5)标题(“AFib信号”)
由于本例使用LSTM而不是CNN,因此必须对该方法进行转换,使其适用于一维信号。时频矩从谱图中提取信息。每一个时刻都可以作为一个一维特征输入到LSTM中。
探索时域中的两个TF矩:
瞬时频率(
instfreq
)谱熵(
pentropy
)
的instfreq
函数估计信号的随时间变化的频率,作为功率谱图的一阶矩。该函数使用时间窗上的短时傅里叶变换计算光谱图。在本例中,函数使用255个时间窗口。函数的时间输出对应于时间窗口的中心。
可视化每一种信号的瞬时频率。
[instFreqA,tA] = instfreq(aFib,fs);[instFreqN,tN] = instfreq(normal,fs);图次要情节(2,1,1);情节(tN, instFreqN)标题(“正常信号”)包含(“时间(s)”) ylabel (瞬时频率的次要情节(2,1,2);instFreqA情节(tA)标题(“AFib信号”)包含(“时间(s)”) ylabel (瞬时频率的)
使用cellfun
要应用instfreq
对训练和测试集中的每个单元都有作用。
instfreqTrain = cellfun(@(x)instfreq(x,fs)',XTrain,“UniformOutput”、假);instfreqTest = cellfun(@(x)instfreq(x,fs)',XTest,“UniformOutput”、假);
频谱熵测量信号的频谱有多平坦。具有尖峰频谱的信号,如正弦信号的和,具有较低的频谱熵。具有平坦频谱的信号,如白噪声,具有很高的频谱熵。的pentropy
函数根据功率谱图估计谱熵。与瞬时频率估计的情况一样,pentropy
使用255个时间窗口计算谱图。函数的时间输出对应于时间窗口的中心。
可视化每种信号的光谱熵。
[pentropyA,tA2] = pentropy(aFib,fs);[pentropyN,tN2] = pentropy(normal,fs);图副图(2,1,1)图(tN2,pentropyN)标题(“正常信号”) ylabel (“谱熵”(2,1,2) (tA2,pentropyA)“AFib信号”)包含(“时间(s)”) ylabel (“谱熵”)
使用cellfun
要应用pentropy
对训练和测试集中的每个单元都有作用。
pentropyTrain = cellfun(@(x)pentropy(x,fs)',XTrain,“UniformOutput”、假);pentropyTest = cellfun(@(x)pentropy(x,fs)',XTest,“UniformOutput”、假);
连接这些特征,使新的训练和测试集中的每个单元都具有两个维度或两个特征。
XTrain2 = cellfun(@(x,y)[x;y],instfreqTrain,pentropyTrain,“UniformOutput”、假);XTest2 = cellfun(@(x,y)[x;y],instfreqTest,pentropyTest,“UniformOutput”、假);
可视化新输入的格式。每个细胞不再包含一个9000个样本长的信号;现在它包含两个255个样本长的特征。
XTrain2 (1:5)
ans =5×1单元格数组{2×255 double} {2×255 double} {2×255 double} {2×255 double} {2×255 double}
标准化数据
瞬时频率和谱熵的含义相差几乎一个数量级。此外,瞬时频率平均值可能过高,使LSTM无法有效学习。当网络拟合的数据具有较大的均值和较大的值范围时,大量的输入可能会减慢网络的学习和收敛速度[6]。
意思是(instFreqN)
Ans = 5.5615
意思是(pentropyN)
Ans = 0.6326
利用训练集均值和标准差对训练集和测试集进行标准化。标准化,或z评分,是在训练中提高网络性能的一种流行方法。
XV = [XTrain2{:}];=均值(XV,2);sg = std(XV,[],2);XTrainSD = XTrain2;XTrainSD = cellfun(@(x)(x-mu)./sg,XTrainSD,“UniformOutput”、假);XTestSD = XTest2;= cellfun(@(x)(x-mu)./sg,XTestSD,“UniformOutput”、假);
给出了标准化瞬时频率和谱熵的方法。
instFreqNSD = XTrainSD{1}(1,:);pentropyNSD = XTrainSD{1}(2,:);意思是(instFreqNSD)
Ans = -0.3211
意思是(pentropyNSD)
Ans = -0.2416
修改LSTM网络架构
既然每个信号都有两个维度,就有必要通过指定输入序列大小为2来修改网络架构。指定一个输出大小为100的双向LSTM层,并输出序列的最后一个元素。通过包含一个大小为2的完全连接层,然后是一个softmax层和一个分类层,指定两个类。
图层= [...sequenceInputLayer (2) bilstmLayer (100“OutputMode”,“最后一次”) fullyConnectedLayer(2)
2”BiLSTM BiLSTM与100个隐藏单元3”全连接2全连接层4”Softmax Softmax 5”分类输出crossentropyex
指定培训选项。将最大时数设置为30,允许网络通过30次训练数据。
options = trainingOptions(“亚当”,...“MaxEpochs”30岁的...“MiniBatchSize”, 150,...“InitialLearnRate”, 0.01,...“GradientThreshold”, 1...“ExecutionEnvironment”,“汽车”,...“阴谋”,“训练进步”,...“详细”、假);
训练具有时频特性的LSTM网络
使用指定的训练选项和层架构训练LSTM网络trainNetwork
.
net2 = trainNetwork(XTrainSD,YTrain,layers,options);
在训练精度上有了很大的提高。交叉熵损失趋向于0。此外,由于TF矩比原始序列短,因此训练所需的时间减少。
可视化训练和测试准确性
使用更新后的LSTM网络对训练数据进行分类。将分类性能可视化为一个混淆矩阵。
trainPred2 = category (net2,XTrainSD);LSTMAccuracy = sum(trainPred2 == YTrain)/numel(YTrain)*100
LSTMAccuracy = 83.5962
图confusionchart (YTrain trainPred2,“ColumnSummary”,“column-normalized”,...“RowSummary”,“row-normalized”,“标题”,“LSTM困惑图”);
用更新后的网络对测试数据进行分类。绘制混淆矩阵来检验测试精度。
testPred2 = category (net2,XTestSD);LSTMAccuracy = sum(testPred2 == YTest)/numel(YTest)*100
LSTMAccuracy = 80.1020
图confusionchart(欧美、testPred2“ColumnSummary”,“column-normalized”,...“RowSummary”,“row-normalized”,“标题”,“LSTM困惑图”);
结论
本例展示了如何利用LSTM网络建立检测心电信号中房颤的分类器。该程序使用过采样,以避免在主要由健康患者组成的人群中检测异常情况时出现的分类偏差。使用原始信号数据训练LSTM网络,分类精度较差。利用每个信号的两个时频矩特征对网络进行训练,显著提高了分类性能,同时减少了训练时间。
参考文献
[1]基于短导联心电图记录的房颤分类:心内科物理网络/计算挑战,2017。https://physionet.org/challenge/2017/
[2] Clifford, Gari,刘成玉,Benjamin Moody, Li-wei H. Lehman, Ikaro Silva, Li Qiao, Alistair Johnson, Roger G. Mark。从短单导联心电图记录中进行AF分类:2017年心脏病学挑战中的物理网络计算。心脏病学中的计算(雷恩:IEEE)。2017年第44卷第1-4页。
[3]戈德伯格,a.l., l.a.n.阿马拉尔,L.格拉斯,J. M.豪斯多夫,P. Ch.伊万诺夫,R. G.马克,J. E.米耶图斯,G. B.穆迪,c . k。彭和h·e·斯坦利。“PhysioBank, PhysioToolkit和PhysioNet:复杂生理信号新研究资源的组成部分”。循环.第101卷第23期,2000年6月13日,页e215-e220。http://circ.ahajournals.org/content/101/23/e215.full
Pons, Jordi, Thomas Lidy和Xavier Serra。“音乐驱动卷积神经网络实验”。第十四届基于内容的多媒体索引国际研讨会.2016年6月。
[5]王,D。“深度学习重塑助听器,”IEEE频谱2017年3月,第54卷第3期,第32-37页。doi: 10.1109 / MSPEC.2017.7864754。
布朗利,杰森。如何在Python中为长短期记忆网络扩展数据.2017年7月7日。https://machinelearningmastery.com/how-to-scale-data-for-long-short-term-memory-networks-in-python/.
另请参阅
功能
instfreq
|pentropy
|trainingOptions
(深度学习工具箱)|trainNetwork
(深度学习工具箱)|bilstmLayer
(深度学习工具箱)|lstmLayer
(深度学习工具箱)
相关的话题
- 长-短时记忆网络(深度学习工具箱)