主要内容

训练二元分类的广义加性模型

这个例子展示了如何训练一个二元分类的广义加性模型以及如何评估训练后模型的预测性能。该示例首先查找单变量GAM的最佳参数值(线性项的参数),然后查找双变量GAM的值(交互项的参数)。此外,该示例还解释了如何通过检查术语对特定预测的局部影响和计算预测对预测器的部分依赖性来解释训练过的模型。

加载示例数据

加载存储在中的1994年人口普查数据census1994.mat.该数据集由来自美国人口普查局的人口统计数据组成,用来预测一个人的年收入是否超过5万美元。分类任务是根据人们的年龄、劳动阶层、教育程度、婚姻状况、种族等,拟合一个预测工资类别的模型。

负载census1994

census1994包含训练数据集adultdata和测试数据集成人.为减少此示例的运行时间,可以使用datasample函数。

rng (1)%的再现性NumSamples = 5 e2;adultdata = datasample (adultdata NumSamples,“替换”、假);成人= datasample(成人、NumSamples“替换”、假);

用最优超参数训练GAM

用超参数训练GAM,使交叉验证损失最小化OptimizeHyperparameters名称-值参数。

您可以指定OptimizeHyperparameters作为“汽车”“所有”找出单变量和双变量参数的最优超参数值。方法可以为单变量参数找到最优值“auto-univariate”“all-univariate”选项,然后为二元参数找到最优值“auto-bivariate”“all-bivariate”选择。这个示例使用“auto-univariate”而且“auto-bivariate”

训练一个单变量GAM。指定OptimizeHyperparameters作为“auto-univariate”fitcgam的最优值InitialLearnRateForPredictors而且NumTreesPerPredictor名称-值参数。对于可再现性,使用“expected-improvement-plus”采集功能。指定ShowPlots作为而且详细的为0分别禁用图形和消息显示。

Mdl_univariate = fitcgam (adultdata,“工资”“重量”“fnlwgt”...“OptimizeHyperparameters”“auto-univariate”...“HyperparameterOptimizationOptions”结构(“AcquisitionFunctionName”“expected-improvement-plus”...“ShowPlots”假的,“详细”, 0))
Mdl_univariate = ClassificationGAM PredictorNames: {'age' 'workClass' 'education' 'education_num' 'marital_status' 'occupation' 'relationship' 'race' 'sex' 'capital_gain' 'capital_loss' 'hours_per_week' 'native_country'} ResponseName: 'salary' CategoricalPredictors: [2 3 5 6 7 8 9 13] ClassNames: [<=50K >50K] ScoreTransform: 'logit' Intercept: -1.3118 NumObservations: 500 HyperparameterOptimizationResults: [1×1 BayesianOptimization]属性,方法

fitcgam返回一个ClassificationGAM模型对象使用最佳估计可行点。最佳估计可行点是指基于贝叶斯优化过程的底层目标函数模型,使目标函数值的置信上限最小的超参数集。你可以从上面得到最好的点HyperparameterOptimizationResults属性或使用bestPoint函数。

x = Mdl_univariate.HyperparameterOptimizationResults.XAtMinEstimatedObjective
x =1×2表InitialLearnRateForPredictors NumTreesPerPredictor  _____________________________ ____________________ 118 - 0.02257
bestPoint (Mdl_univariate.HyperparameterOptimizationResults)
ans =1×2表InitialLearnRateForPredictors NumTreesPerPredictor  _____________________________ ____________________ 118 - 0.02257

有关优化过程的详细信息,请参见使用OptimizeHyperparameters优化GAM

训练二元GAM。指定OptimizeHyperparameters作为“auto-bivariate”fitcgam的最优值的相互作用InitialLearnRateForInteractions,NumTreesPerInteraction名称-值参数。中的单变量参数值x这样软件就可以根据x值找到交互项的最优参数值。

Mdl = fitcgam (adultdata,“工资”“重量”“fnlwgt”...“InitialLearnRateForPredictors”x。InitialLearnRateForPredictors,...“NumTreesPerPredictor”x。NumTreesPerPredictor,...“OptimizeHyperparameters”“auto-bivariate”...“HyperparameterOptimizationOptions”结构(“AcquisitionFunctionName”“expected-improvement-plus”...“ShowPlots”假的,“详细”, 0))
Mdl = ClassificationGAM PredictorNames: {'age' 'workClass' 'education' 'education_num' 'marital_status' 'occupation' 'relationship' 'race' 'sex' 'capital_gain' 'capital_loss' 'hours_per_week' 'native_country'} ResponseName: 'salary' CategoricalPredictors: [2 3 5 6 7 8 9 13] ClassNames: [<=50K >50K] ScoreTransform: 'logit' Intercept: -1.4587 Interactions: [6×2 double] NumObservations: 500 HyperparameterOptimizationResults: [1×1 BayesianOptimization]属性,方法

显示最优双变量超参数。

Mdl.HyperparameterOptimizationResults.XAtMinEstimatedObjective
ans =1×3表交互InitialLearnRateForInteractions NumTreesPerInteraction  ____________ _______________________________ ______________________ 0.0061954 422

模型显示Mdl显示模型属性的部分列表。要查看模型属性的完整列表,双击变量名Mdl在工作区中。的变量编辑器打开Mdl.或者,您可以使用点表示法在命令窗口中显示属性。例如,显示ReasonForTermination财产。

Mdl。ReasonForTermination
ans =结构体字段:PredictorTrees: '在训练请求数量的树后终止。'InteractionTrees: '在训练请求数量的树后终止。'

您可以使用ReasonForTermination属性,以确定训练的模型是否包含每个线性项和每个交互项的指定数量的树。

中显示交互项Mdl

Mdl。的相互作用
ans =6×25 12 1 6 6 12 1 12 7 9 2 6

每一行的的相互作用表示一个交互项,并包含该交互项的预测变量的列索引。您可以使用的相互作用属性来检查模型中的交互项及其顺序fitcgam将它们添加到模型中。

中显示交互项Mdl使用预测器名称。

Mdl.PredictorNames (Mdl.Interactions)
ans =6×2细胞{' marital_status}{‘hours_per_week}{‘年龄’}{‘占领’}{‘占领’}{‘hours_per_week}{‘年龄’}{‘hours_per_week}{‘关系’}{‘性’}{‘workClass}{“占领”}

评估新观察结果的预测性能

通过使用测试样本来评估训练模型的性能成人目标函数预测损失边缘,保证金.您可以使用带有这些功能的完整或紧凑模型。

  • 预测——分类的观察

  • 损失-计算分类丢失(默认为十进制错误分类率)

  • 保证金-计算分类裕度

  • 边缘-计算分类边缘(分类边缘的平均值)

如果你想评估训练数据集的性能,可以使用替换对象函数:resubPredictresubLossresubMargin,resubEdge.要使用这些函数,必须使用包含训练数据的完整模型。

创建一个紧凑的模型,以减少训练过的模型的大小。

CMdl =紧凑(Mdl);谁(“Mdl”“CMdl”
名称大小字节类属性CMdl 1x1 5126918 classreg.learning.classif.CompactClassificationGAM Mdl 1x1 5272831 ClassificationGAM

预测测试数据集的标签和分数(成人),并使用测试数据集计算模型统计信息(损耗、边际和边缘)。

(标签、分数)=预测(CMdl,成人);L =损失(CMdl,成人,“重量”, adulttest.fnlwgt);M =利润率(CMdl、成人);E =边缘(CMdl成人,“重量”, adulttest.fnlwgt);

预测标签和分数并计算统计数据,而不包括训练的模型中的交互项。

[labels_nointeraction, scores_nointeraction] =预测(CMdl,成人,“IncludeInteractions”、假);L_nointeractions =损失(CMdl,成人,“重量”adulttest.fnlwgt,“IncludeInteractions”、假);M_nointeractions =利润率(CMdl,成人,“IncludeInteractions”、假);E_nointeractions =边缘(CMdl,成人,“重量”adulttest.fnlwgt,“IncludeInteractions”、假);

将包含线性项和相互作用项的结果与只包含线性项的结果进行比较。

从真实的标签创建一个混淆表adulttest.salary以及预测的标签。

tiledlayout(1、2);nexttile confusionchart (adulttest.salary、标签)标题(“线性和交互术语”nexttile (adulttest.salary,labels_nointeraction)“线性条件仅”

显示计算损失和边缘值。

表([L;E]、[L_nointeractions;E_nointeractions),...“VariableNames”, {“线性和交互术语”“只有线性条件”},...“RowNames”, {“损失”“边缘”})
ans =2×2表线性和交互项只有线性项  ____________________________ _________________ 损失0.1748 - 0.17872 0.57902 - 0.54756

当考虑线性项和交互项时,模型的损失值较小,边缘值较高。

使用框图显示边缘的分布。

图箱线图([M M_nointeractions),“标签”, {“线性和交互术语”“线性条件仅”})标题(“测试样本裕度箱线图”

解释预测

解释第一次测试观测结果的预测plotLocalEffects函数。此外,还为模型中的一些重要项创建偏依赖图plotPartialDependence函数。

对第一次观察到的测试数据进行分类,并绘制出其中项的局部效应CMdl在预测。若要在任何预测器名称中显示现有下划线,请更改TickLabelInterpreter坐标轴的值“没有”

(标签,分数)=预测(CMdl成人(1,:))
标签=分类< = 50 k
分数=1×20.9895 - 0.0105
f1 =图;plotLocalEffects (CMdl成人(1:))f1.CurrentAxes。TickLabelInterpreter =“没有”

预测函数对第一个观察结果进行分类:成人(1)作为“< = 50 k”.的plotLocalEffects函数创建一个水平柱状图,显示10个最重要的术语对预测的局部影响。每个局部效应值显示了每个术语对分类分数的贡献“< = 50 k”,即分类的后验概率的logit“< = 50 k”的观察。

为该术语创建一个偏依赖图年龄.指定训练和测试数据集,使用这两个数据集计算部分依赖值。

图plotPartialDependence (CMdl,“年龄”、标签、[adultdata;成人)

标线表示预测因子之间的平均部分关系年龄和班级成绩< = 50 k在训练过的模型中。的x-轴小刻度表示预测器中的唯一值年龄

为这些项创建偏依赖图education_num而且的关系

f2 =图;plotPartialDependence (CMdl [“education_num”“关系”),标签,[adultdata;成人)f2.CurrentAxes。TickLabelInterpreter =“没有”;视图(40 [55])

该图显示了班级分数值的部分依赖性< = 50education_num而且的关系

另请参阅

||||||

相关的话题

Baidu
map