用卷积神经网络对文本数据进行分类
这个例子展示了如何使用卷积神经网络对文本数据进行分类。
要使用卷积对文本数据进行分类,请使用在输入的时间维度上进行卷积的1-D卷积层。
这个例子用不同宽度的一维卷积滤波器训练一个网络。每个过滤器的宽度对应于过滤器可以看到的单词数(n-gram长度)。该网络具有卷积层的多个分支,因此可以使用不同的n克长度。
加载数据
中的数据创建表格文本数据存储factoryReports.csv
并查看前几个报告。
数据= readtable (“factoryReports.csv”);头(数据)
ans =8×5表类别描述紧急解决成本 _______________________________________________________________________ ______________________ __________ ______________________ _____ {' 项目是偶尔陷于扫描器线轴。'}{'机械故障'}{'中等'}{'重新调整机器'}45{'组装活塞发出响亮的咔啦咔啦声和砰砰声。'}{'机械故障'}{'中等'}{'重新调整机器'}35{'开机时电源被切断。'}{'电子故障'}{'高'}{'完全更换'}16200{'汇编器中的电容器烧坏。'}{'电子故障'}{'高'}{'更换元件'}352{'混合器触发熔断器。'}{'电子故障'}{'低'}{'加入监视列表'}55{'施工剂中的管道破裂正在喷冷却剂。'}{'泄漏'}{'高'}{'更换元件'}371{'混合器保险丝熔断。'}{'电子故障'}{'低'}{'更换元件'}441{'东西继续从传送带上掉下来。'}{'机械故障'}{'低'}{'重新调整机器'}38
将数据划分为训练和验证分区。使用80%的数据进行训练,其余数据进行验证。
本量利= cvpartition (data.Category,坚持= 0.2);dataTrain =数据(训练(cvp):);dataValidation =数据(测试(cvp):);
预处理文本数据
方法提取文本数据“描述”
列,并对其进行预处理preprocessText
函数,在本节中列出文本预处理功能的例子。
documentsTrain = preprocessText (dataTrain.Description);
方法提取标签“类别”
列,并将它们转换为类别。
TTrain =分类(dataTrain.Category);
查看类名和观察数。
一会=独特(TTrain)
一会=4×1分类电子故障泄漏机械故障软件故障
numObservations =元素个数(TTrain)
numObservations = 384
使用相同的步骤提取和预处理验证数据。
documentsValidation = preprocessText (dataValidation.Description);TValidation =分类(dataValidation.Category);
将文档转换为序列
要将文档输入到神经网络中,请使用字编码将文档转换为数字索引序列。
根据文档创建一个单词编码。
内附= wordEncoding (documentsTrain);
查看单词编码的词汇量大小。词汇量大小是单词编码中唯一单词的数量。
numWords = enc.NumWords
numWords = 436
方法将文档转换为整数序列doc2sequence
函数。
documentsTrain XTrain = doc2sequence (enc);
使用从训练数据创建的字编码将验证文档转换为序列。
documentsValidation XValidation = doc2sequence (enc);
定义网络体系结构
为分类任务定义网络体系结构。
下面的步骤描述了网络架构。
指定输入大小为1,它对应于整数序列输入的通道尺寸。
使用维数为100的单词嵌入输入。
对于长度为2、3、4和5的n克,创建包含卷积层、批处理归一化层、ReLU层、dropout层和最大池化层的层块。
对于每个块,指定200个大小为1 × -的卷积过滤器N和全局最大池化层。
将输入层连接到每个块,并使用连接层连接块的输出。
要对输出进行分类,需要包含一个具有输出大小的完全连接层K,一个softmax层和一个分类层,其中K是类的数量。
指定网络超参数。
embeddingDimension = 100;ngramlength = [2 3 4 5];numFilters = 200;
首先,创建一个包含输入层和维度为100的单词嵌入层的层图。要帮助将单词嵌入层连接到卷积层,请将单词嵌入层名称设置为“循证”
.若要检查卷积层在训练期间不将序列卷积为长度为零,请设置最小长度
选项设置为训练数据中最短序列的长度。
最小长度= min (doclength (documentsTrain));layers = [sequenceInputLayer(1,MinLength= MinLength) wordEmbeddingLayer(embeddingDimension,numWords,Name= .“循证”));lgraph = layerGraph(层);
对于每个n克长度,创建一个一维卷积、批处理归一化、ReLU、dropout和一维全局最大池化层的块。将每个块连接到单词嵌入层。
numBlocks =元素个数(ngramLengths);为j = 1:numBlocks N = ngramlength (j);block = [convolution1dLayer(N,numFilters,Name= .“conv”+ N,填充=“相同”) batchNormalizationLayer (Name =“bn”+ N) reluLayer (Name =“relu”+ N) dropoutLayer (0.2, Name =“下降”+ N) globalMaxPooling1dLayer (Name =“马克斯”+ N));lgraph = addLayers (lgraph块);lgraph = connectLayers (lgraph,“循证”,“conv”+ N);结束
添加连接层、全连接层、softmax层和分类层。
numClasses =元素个数(类名);layers = [concatenationLayer(1,numBlocks,Name= .“猫”) fullyConnectedLayer (numClasses Name =“俱乐部”) softmaxLayer (Name =“软”) classificationLayer (Name =“分类”));lgraph = addLayers (lgraph层);
将全局最大池化层连接到连接层,在图中查看网络架构。
为j = 1:numBlocks N = ngramlength (j);lgraph = connectLayers (lgraph,“马克斯”+ N,“猫/”+ j);结束图绘制(lgraph)标题(“网络架构”)
列车网络的
指定培训选项:
用128个小批次进行训练。
使用验证数据验证网络。
返回验证损耗最小的网络。
显示训练进度图,抑制冗长输出。
选择= trainingOptions (“亚当”,...MiniBatchSize = 128,...ValidationData = {XValidation, TValidation},...OutputNetwork =“best-validation-loss”,...情节=“训练进步”,...Verbose = false);
训练网络使用trainNetwork
函数。
网= trainNetwork (XTrain TTrain、lgraph选项);
测试网络
利用训练过的网络对验证数据进行分类。
YValidation =分类(净,XValidation);
在困惑图表中可视化预测。
图confusionchart (TValidation YValidation)
计算分类精度。准确率是预测正确的标签的比例。
精度=平均值(TValidation == YValidation)
精度= 0.9375
使用新数据进行预测
对三个新报告的事件类型进行分类。创建一个包含新报告的字符串数组。
reportsNew = [“冷却剂在分拣机下面聚集。”“分拣机在启动时熔断保险丝。”“组装机发出一些非常响亮的咔啦咔啦的声音。”];
使用预处理步骤作为训练和验证文档对文本数据进行预处理。
documentsNew = preprocessText (reportsNew);documentsNew XNew = doc2sequence (enc);
利用训练过的网络对新序列进行分类。
XNew YNew =分类(净)
YNew =3×1分类泄漏电子故障机械故障
文本预处理功能
的preprocessTextData
函数接受文本数据作为输入,并执行以下步骤:
在标记文本。
将文本转换为小写。
函数documents = preprocessText(textData) documents = tokenizedDocument(textData);文件=低(文件);结束
另请参阅
fastTextWordEmbedding
|wordcloud
|wordEmbedding
|layerGraph
(深度学习工具箱)|convolution2dLayer
(深度学习工具箱)|batchNormalizationLayer
(深度学习工具箱)|trainingOptions
(深度学习工具箱)|trainNetwork
(深度学习工具箱)|doc2sequence
|tokenizedDocument