主要内容

使用贝叶斯优化优化分类器拟合

方法优化支持向量机分类fitcsvm功能和OptimizeHyperparameters名称-值参数。

生成数据

分类工作的点的位置从高斯混合模型。在统计学习的要素, Hastie, Tibshirani, and Friedman(2009),第17页描述了该模型。该模型首先为“绿色”类生成10个基点,分布为均值(1,0)和单位方差的二维独立法线。它还为“红色”类生成10个基点,分布为均值(0,1)和单位方差的二维独立法线。对于每个类别(绿色和红色),生成100个随机点,如下所示:

  1. 选择一个基准点随机均匀地涂上适当的颜色。

  2. 生成具有均值的二维正态分布的独立随机点和方差I/5,其中I是2 × 2单位矩阵。在本例中,使用方差I/50来更清楚地显示优化的优势。

为每个类生成10个基点。

rng (“默认”%用于再现性Grnpop = mvnrnd([1,0],eye(2),10);Redpop = mvnrnd([0,1],eye(2),10);

查看基准点。

情节(grnpop (: 1) grnpop (:, 2),“去”)举行情节(redpop (: 1) redpop (:, 2),“罗”)举行

图中包含一个轴对象。axis对象包含2个line类型的对象。

由于一些红色基点接近绿色基点,因此仅根据位置很难对数据点进行分类。

生成每个类的100个数据点。

Redpts = 0 (100,2);GRNPTS = redpts;i = 1:10 0 grnpts(我:)= mvnrnd (grnpop(兰迪(10):)、眼睛(2)* 0.02);redpts(我)= mvnrnd (redpop(兰迪(10):)、眼睛(2)* 0.02);结束

查看数据点。

图绘制(grnpts (: 1), grnpts (:, 2),“去”)举行情节(redpts (: 1) redpts (:, 2),“罗”)举行

图中包含一个轴对象。axis对象包含2个line类型的对象。

为分类准备数据

把数据放到一个矩阵中,然后做一个向量grp标记每个点的类。1表示绿色类,-1表示红色类。

Cdata = [grnpts;redpts];GRP = ones(200,1);Grp (101:200) = -1;

准备交叉验证

为交叉验证设置一个分区。

C = cvpartition(200,“KFold”10);

该步骤是可选的。如果您为优化指定了一个分区,那么您就可以为返回的模型计算实际的交叉验证损失。

优化匹配

为了找到一个良好的拟合,即具有最优超参数的拟合,可以使用贝叶斯优化。属性指定要优化的超参数列表OptimizeHyperparameters名称-值参数,并使用HyperparameterOptimizationOptions名称-值参数。

指定“OptimizeHyperparameters”作为“汽车”.的“汽车”选项包含一组要优化的典型超参数。fitcsvm找到的最优值BoxConstraint而且KernelScale.设置超参数优化选项以使用交叉验证分区c然后选择“expected-improvement-plus”再现性的获取函数。默认的采集函数取决于运行时,因此可以给出不同的结果。

Opts = struct(“CVPartition”c“AcquisitionFunctionName”“expected-improvement-plus”);Mdl = fitcsvm(cdata,grp,“KernelFunction”“rbf”...“OptimizeHyperparameters”“汽车”“HyperparameterOptimizationOptions”选择)
|=====================================================================================================| | Iter | Eval客观客观| | | BestSoFar | BestSoFar | BoxConstraint | KernelScale | | |结果| |运行时| | (estim(观察) .) | | | |=====================================================================================================| | 最好1 | | 0.345 | 0.26612 | 0.345 | 0.345 | 0.00474 | 306.44 | | 2 |最好| 0.115 | 0.16757 | 0.115 | 0.12678 | 430.31 | 1.4864 | | 3 |接受| 0.52 | 0.21336 | 0.115 | 0.1152 | 0.028415 | 0.014369 | | 4 |接受| 0.61 | 0.41833 | 0.115 | 0.11504 | 133.94 | 0.0031427 | | 5 |接受| 0.34 | 0.46056 | 0.115 | 0.11504 | 0.010993 | 5.7742 | | 6 |的| 0.085 | 0.25465 | 0.085 | 0.085039 | 885.63 | 0.68403 | | | 7日接受| 0.105 | 0.25751 | 0.085 | 0.085428 | 0.3057 | 0.58118 | | |接受8 | 0.21 | 0.28915 | 0.085 | 0.09566 | 0.16044 | 0.91824 | | | 9日接受| 0.085 | 0.30816 | 0.085 | 0.08725 | 972.19 | 0.46259 | | 10 |接受| 0.1 |0.34457 | 0.085 | 0.090952 | 990.29 | 0.491 | | 11 | Best | 0.08 | 0.21805 | 0.08 | 0.079362 | 2.5195 | 0.291 | | 12 | Accept | 0.09 | 0.24212 | 0.08 | 0.08402 | 14.338 | 0.44386 | | 13 | Accept | 0.1 | 0.23766 | 0.08 | 0.08508 | 0.0022577 | 0.23803 | | 14 | Accept | 0.11 | 0.24347 | 0.08 | 0.087378 | 0.2115 | 0.32109 | | 15 | Best | 0.07 | 0.30411 | 0.07 | 0.081507 | 910.2 | 0.25218 | | 16 | Best | 0.065 | 0.24431 | 0.065 | 0.072457 | 953.22 | 0.26253 | | 17 | Accept | 0.075 | 0.33287 | 0.065 | 0.072554 | 998.74 | 0.23087 | | 18 | Accept | 0.295 | 0.21231 | 0.065 | 0.072647 | 996.18 | 44.626 | | 19 | Accept | 0.07 | 0.26876 | 0.065 | 0.06946 | 985.37 | 0.27389 | | 20 | Accept | 0.165 | 0.24669 | 0.065 | 0.071622 | 0.065103 | 0.13679 | |=====================================================================================================| | Iter | Eval | Objective | Objective | BestSoFar | BestSoFar | BoxConstraint| KernelScale | | | result | | runtime | (observed) | (estim.) | | | |=====================================================================================================| | 21 | Accept | 0.345 | 0.20097 | 0.065 | 0.071764 | 971.7 | 999.01 | | 22 | Accept | 0.61 | 0.2416 | 0.065 | 0.071967 | 0.0010168 | 0.0010005 | | 23 | Accept | 0.345 | 0.26803 | 0.065 | 0.071959 | 0.0011459 | 995.89 | | 24 | Accept | 0.35 | 0.23608 | 0.065 | 0.071863 | 0.0010003 | 40.628 | | 25 | Accept | 0.24 | 0.39188 | 0.065 | 0.072124 | 996.55 | 10.423 | | 26 | Accept | 0.61 | 0.46697 | 0.065 | 0.072067 | 994.71 | 0.0010063 | | 27 | Accept | 0.47 | 0.28997 | 0.065 | 0.07218 | 993.69 | 0.029723 | | 28 | Accept | 0.3 | 0.24924 | 0.065 | 0.072291 | 993.15 | 170.01 | | 29 | Accept | 0.16 | 0.37085 | 0.065 | 0.072103 | 992.81 | 3.8594 | | 30 | Accept | 0.365 | 0.19017 | 0.065 | 0.072112 | 0.0010017 | 0.044287 |

图中包含一个轴对象。标题为Min objective vs. Number of function的axis对象包含2个类型为line的对象。这些对象代表最小观测目标,估计最小目标。

图中包含一个轴对象。标题为目标函数模型的轴对象包含线、面、等高线类型的5个对象。这些对象代表观测点,模型均值,下一个点,模型最小可行。

__________________________________________________________ 优化完成。最大目标达到30个。总函数评估:30总运行时间:42.2011秒总目标函数评估时间:8.4361最佳观测可行点:BoxConstraint KernelScale _____________ ___________ 953.22 0.26253观测目标函数值= 0.065估计目标函数值= 0.073726函数评估时间= 0.24431最佳估计可行点(根据模型):BoxConstraint KernelScale _____________ ___________ 985.37 0.27389估计目标函数值= 0.072112估计函数评估时间= 0.28413
Mdl = ClassificationSVM ResponseName: 'Y' CategoricalPredictors: [] ClassNames: [-1 1] ScoreTransform: 'none' NumObservations: 200 HyperparameterOptimizationResults: [1x1 bayesioptimization] Alpha: [77x1 double] Bias: -0.2352 KernelParameters: [1x1 struct] BoxConstraints: [200x1 double] ConvergenceInfo: [1x1 struct] IsSupportVector: [200x1 logical] Solver: 'SMO' Properties, Methods

fitcsvm返回一个ClassificationSVM使用最佳估计可行点的模型对象。基于贝叶斯优化过程的底层高斯过程模型,最佳估计可行点是使交叉验证损失的置信上限最小化的超参数集。

贝叶斯优化过程内部维持了目标函数的高斯过程模型。目标函数为交叉验证的分类误分类率。对于每次迭代,优化过程都会更新高斯过程模型,并使用该模型找到一组新的超参数。迭代显示的每一行都显示了新的超参数集和这些列值:

  • 客观的-在新的超参数集上计算目标函数值。

  • 目标运行时-目标函数评价时间。

  • Eval结果—结果报告,指定为接受最好的,或错误接受表示目标函数返回一个有限值,和错误指示目标函数返回一个不是有限实标量的值。最好的指示目标函数返回一个有限值,该值低于先前计算的目标函数值。

  • BestSoFar(观察)-目前计算的最小目标函数值。此值是当前迭代的目标函数值(如果Eval结果当前迭代的值为最好的)或前一个的值最好的迭代。

  • BestSoFar (estim)。-在每次迭代中,软件在迄今为止尝试的所有超参数集上,使用更新的高斯过程模型估计目标函数值的置信上限。然后软件选择上置信界最小的点。的BestSoFar (estim)。方法返回的目标函数值predictObjective函数在最小值处。

迭代显示下面的图显示BestSoFar(观察)而且BestSoFar (estim)。分别用蓝色和绿色表示。

返回的对象Mdl使用最佳估计可行点,即产生的超参数集BestSoFar (estim)。在最终的高斯过程模型的基础上进行最后的迭代。

你可以得到最好的点从HyperparameterOptimizationResults属性或通过使用bestPoint函数。

Mdl.HyperparameterOptimizationResults.XAtMinEstimatedObjective
ans =1×2表BoxConstraint KernelScale  _____________ ___________ 985.37 - 0.27389
[x,CriterionValue,iteration] = bestPoint(mml . hyperparameteroptimizationresults)
x =1×2表BoxConstraint KernelScale  _____________ ___________ 985.37 - 0.27389
CriterionValue = 0.0888
迭代= 19

默认情况下,bestPoint函数使用“min-visited-upper-confidence-interval”标准。该准则选取第19次迭代得到的超参数作为最佳点。CriterionValue为最终高斯过程模型计算的交叉验证损失的上界。通过使用分区计算实际的交叉验证损失c

L_MinEstimated = kfoldLoss(fitcsvm(cdata,grp,“CVPartition”c“KernelFunction”“rbf”...“BoxConstraint”x。BoxConstraint,“KernelScale”, x.KernelScale))
L_MinEstimated = 0.0700

实际交叉验证的损失接近估计值。的目标函数估计值优化结果图下方显示。

您还可以提取最佳观测可行点(即最后最好的点在迭代显示)从HyperparameterOptimizationResults属性或通过指定标准作为“min-observed”

Mdl.HyperparameterOptimizationResults.XAtMinObjective
ans =1×2表BoxConstraint KernelScale  _____________ ___________ 953.22 - 0.26253
[x_observed,CriterionValue_observed,iteration_observed] = bestPoint(Mdl。HyperparameterOptimizationResults,“标准”“min-observed”
x_observed =1×2表BoxConstraint KernelScale  _____________ ___________ 953.22 - 0.26253
CriterionValue_observed = 0.0650
Iteration_observed = 16

“min-observed”Criterion选择第16次迭代得到的超参数作为最佳点。CriterionValue_observed使用所选超参数计算的实际交叉验证损失。有关更多信息,请参见标准的名称-值参数bestPoint

可视化优化后的分类器。

D = 0.02;[x1Grid, x2Grid] = meshgrid (min (cdata (: 1)): d:马克斯(cdata (: 1)),...分钟(cdata (:, 2)): d:马克斯(cdata (:, 2)));xGrid = [x1Grid(:),x2Grid(:)];[~,scores] = predict(Mdl,xGrid);图h (1:2) = gscatter (cdata (: 1), cdata (:, 2), grp,“rg”' + *’);持有h(3) = plot(cdata(mld . issupportvector,1),...cdata (Mdl.IsSupportVector, 2),“柯”);轮廓(x1Grid x2Grid,重塑(分数(:,2),大小(x1Grid)), [0 0),“k”);传奇(h, {' 1 '“+ 1”“支持向量”},“位置”“东南”);

图中包含一个轴对象。坐标轴对象包含线、轮廓等4个对象。这些对象表示-1,+1,支持向量。

评估新数据的准确性

生成并分类新的测试数据点。

Grnobj = gmdistribution(grnpop,.2*eye(2));Redobj = gmdistribution(redpop,.2*eye(2));newData = random(grnobj,10);newData = [newData;random(redobj,10)];grpData = ones(20,1);%绿色= 1grpData(11:20) = -1;%红色= -1v = predict(Mdl,newData);

计算测试数据集上的误分类率。

L_Test = loss(Mdl,newData,grpData)
L_Test = 0.3500

确定哪些新数据点被正确分类。将正确分类的点格式化为红色方块,将错误分类的点格式化为黑色方块。

h(4:5) = gscatter(newData(:,1),newData(:,2),v,“mc”“* *”);mydiff = (v == grpData);%正确分类Ii = mydiff在正确的点周围画红方块h(6) = plot(newData(ii,1),newData(ii,2),“rs”“MarkerSize”12);结束Ii = not(mydiff)在不正确的点周围画黑色方块h(7) = plot(newData(ii,1),newData(ii,2),“ks”“MarkerSize”12);结束传奇(h, {“1”(培训)“+ 1(培训)”“支持向量”...“1”(分类)“+ 1(分类)”...正确分类的“是不是”},...“位置”“东南”);持有

图中包含一个轴对象。坐标轴对象包含8个对象类型的线,轮廓。这些对象表示-1(训练),+1(训练),支持向量,-1(分类),+1(分类),正确分类,错误分类。

另请参阅

|

相关的话题

Baidu
map