Abstract
机器学习调参的思路都异曲同工,首先确定一个参数池,也就是模型参数值的可选范围。从这个池子中挑选出不同的参数组合,对于每个组合都计算其预测精度,最后选取预测精度最高的参数组合。
机器学习调参的思路都异曲同工,首先确定一个参数池,也就是模型参数值的可选范围。从这个池子中挑选出不同的参数组合,对于每个组合都计算其预测精度,最后选取预测精度最高的参数组合。
调参的过程就像是找人生伴侣的过程,首先我们有一个标准,比如身高、体重等,符合这个标准的异性将进入到参数池中。然后我们跟参数池中的每个异性谈恋爱,找到最适合我们的那个作为终极选择。接下来,介绍两种常见的调参方法:网格搜索与随机搜索。
knitr::include_graphics(here::here("Machine_Learning_and_Causal_Inference/fig/modify parameters.png"))
网格搜索首先会有一个标准,将符合标准的参数放入参数池中,形成不同的参数组合。而随机搜索则不同,随机搜索没有标准,随机地组合参数。依然以找男友为例,假设参数有3个:身高、体重、年龄。
网格搜索会对这3个参数设定一个范围,比如身高>180厘米,体重小于<140斤,年龄在20~40岁之间。但是随机搜索则不同,有些女性觉得如果设定了择偶条件,反而容易错过自己喜欢的,也许适合自己的恰好身高只有179厘米。
这两种不同的搜索方式出来的参数组合是不同的。两者各有优缺点,随机搜索与网格搜索相比,其优点在于能随机地遍历所有参数空间,但是缺点也很明显:不知道随机出来的是什么类型的人。下面分别看看两种搜索方式的实现。
首先来看在caret包
中如何轻轻松松实现网格搜索。
第一步:设置随机种子,保证实验的可重复性;
第二步:利用traincontrol()函数
设置模型训练时用到的参数。其中method
表示重抽样方法。此处,cv
表示交叉验证,number
表示几折交叉验证,本例中是10折交叉验证。10折交叉验证表示,首先将样本分为10个组,每次训练的时候抽取其中9组作为训练集,剩下的1组作为测试集。classProbs参数
表示是否计算类别概率,如果评价指标为AUC,那么这里一定要设置为TRUE。由于因变量为两水平变量,所以summaryFunction
这里设置为twoClassSummary
。
第三步:设置网格搜索的参数池,也就是设定参数的选择范围。这里以机器学习中的gbm(Gradient boosting machine)方法
为例,所以有4个超参数需要设定,分别为迭代次数(n.trees)
,树的复杂度(interaction.depth)
,学习率(shrinkage)
,训练样本的最小数目(n.minobsinnode)
。这里设定了60组参数组合。
set.seed(1234)
fit_control_gbm <- trainControl(method = "cv",
number = 10,
classProbs = TRUE,
summaryFunction = twoClassSummary)
grid_gbm <- expand.grid(interaction.depth = c(1,5,9),
n.trees = (1:20) * 50,
shrinkage = 0.1,
n.minobsinnode = 20)
grid_gbm
train()函数
来进行模型训练及得到最优参数组合。该函数会遍历第三步得到的所有参数组合,并得到使评价指标最大的参数组合作为输出。method
表示使用的模型,本例使用机器学习中的gbm(Gradient boosting machine)
模型,使用的评价指标为ROC曲线面积
(即AUC值)。data <- data.table::fread(here::here("Machine_Learning_and_Causal_Inference/data/相亲数据重新编码.csv"))
data %<>% as.data.frame()
data
## 'data.frame': 8378 obs. of 29 variables:
## $ 决定 : int 1 1 1 1 1 0 1 0 1 1 ...
## $ 性别 : int 0 0 0 0 0 0 0 0 0 0 ...
## $ 吸引力 : num 6 7 5 7 5 4 7 4 7 5 ...
## $ 共同爱好 : num 5 6 7 8 6 4 7 6 8 8 ...
## $ 幽默 : num 7 8 8 7 7 4 4 6 9 8 ...
## $ 真诚 : num 9 8 8 6 6 9 6 9 6 6 ...
## $ 雄心 : num 6 5 5 6 6 6 6 5 8 10 ...
## $ 智力 : num 7 7 9 8 7 7 7 7 8 6 ...
## $ 好感 : num 7 7 7 7 6 6 6 6 7 6 ...
## $ 成功率自估 : num 6 5 NA 6 6 5 5 7 7 6 ...
## $ 日常出门频率 : int 1 1 1 1 1 1 1 1 1 1 ...
## $ 对宗教的看重程度: int 4 4 4 4 4 4 4 4 4 4 ...
## $ 对种族的看重程度: int 2 2 2 2 2 2 2 2 2 2 ...
## $ 年龄 : int 21 21 21 21 21 21 21 21 21 21 ...
## $ 种族 : int 4 4 4 4 4 4 4 4 4 4 ...
## $ 从事领域 : int 1 1 1 1 1 1 1 1 1 1 ...
## $ 对方决定 : int 0 0 1 1 1 1 0 0 1 0 ...
## $ 好感得分 : num 7 8 10 7 8 7 2 7 6.5 6 ...
## $ 对方评估成功率 : num 4 4 10 7 6 6 1 5 8 6 ...
## $ 吸引力得分 : num 6 7 10 7 8 7 3 6 7 6 ...
## $ 共同爱好得分 : num 6 5 10 8 7 7 7 6 9 6 ...
## $ 幽默得分 : num 8 7 10 8 6 8 5 6 8 6 ...
## $ 真诚得分 : num 8 8 10 8 7 7 6 7 7 6 ...
## $ 雄心得分 : num 8 7 10 9 9 7 8 8 8 6 ...
## $ 智力得分 : num 8 10 10 9 9 8 7 5 8 6 ...
## $ 对方年龄 : int 27 22 22 23 24 25 30 27 28 24 ...
## $ 对方种族 : int 2 2 4 2 3 2 2 2 2 2 ...
## $ 是否同一种族 : int 0 0 1 0 0 0 0 0 0 0 ...
## $ 日常约会频率 : int 7 7 7 7 7 7 7 7 7 7 ...
Name | Piped data |
Number of rows | 8378 |
Number of columns | 29 |
_______________________ | |
Column type frequency: | |
numeric | 29 |
________________________ | |
Group variables | None |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
决定 | 0 | 1.00 | 0.42 | 0.49 | 0 | 0 | 0 | 1 | 1.0 | ▇▁▁▁▆ |
性别 | 0 | 1.00 | 0.50 | 0.50 | 0 | 0 | 1 | 1 | 1.0 | ▇▁▁▁▇ |
吸引力 | 202 | 0.98 | 6.19 | 1.95 | 0 | 5 | 6 | 8 | 10.0 | ▁▃▇▇▂ |
共同爱好 | 1067 | 0.87 | 5.47 | 2.16 | 0 | 4 | 6 | 7 | 10.0 | ▂▅▇▆▂ |
幽默 | 350 | 0.96 | 6.40 | 1.95 | 0 | 5 | 7 | 8 | 10.0 | ▁▂▇▇▂ |
真诚 | 277 | 0.97 | 7.18 | 1.74 | 0 | 6 | 7 | 8 | 10.0 | ▁▁▃▇▃ |
雄心 | 712 | 0.92 | 6.78 | 1.79 | 0 | 6 | 7 | 8 | 10.0 | ▁▂▆▇▃ |
智力 | 296 | 0.96 | 7.37 | 1.55 | 0 | 6 | 7 | 8 | 10.0 | ▁▁▃▇▃ |
好感 | 240 | 0.97 | 6.13 | 1.84 | 0 | 5 | 6 | 7 | 10.0 | ▁▃▇▇▂ |
成功率自估 | 309 | 0.96 | 5.21 | 2.13 | 0 | 4 | 5 | 7 | 10.0 | ▂▅▇▅▁ |
日常出门频率 | 79 | 0.99 | 2.16 | 1.11 | 1 | 1 | 2 | 3 | 7.0 | ▇▃▁▁▁ |
对宗教的看重程度 | 79 | 0.99 | 3.65 | 2.81 | 1 | 1 | 3 | 6 | 10.0 | ▇▃▃▂▁ |
对种族的看重程度 | 79 | 0.99 | 3.78 | 2.85 | 0 | 1 | 3 | 6 | 10.0 | ▇▃▂▂▂ |
年龄 | 95 | 0.99 | 26.36 | 3.57 | 18 | 24 | 26 | 28 | 55.0 | ▇▇▁▁▁ |
种族 | 63 | 0.99 | 2.76 | 1.23 | 1 | 2 | 2 | 4 | 6.0 | ▇▁▃▁▁ |
从事领域 | 82 | 0.99 | 7.66 | 3.76 | 1 | 5 | 8 | 10 | 18.0 | ▃▃▇▂▁ |
对方决定 | 0 | 1.00 | 0.42 | 0.49 | 0 | 0 | 0 | 1 | 1.0 | ▇▁▁▁▆ |
好感得分 | 250 | 0.97 | 6.13 | 1.84 | 0 | 5 | 6 | 7 | 10.0 | ▁▃▇▇▂ |
对方评估成功率 | 318 | 0.96 | 5.21 | 2.13 | 0 | 4 | 5 | 7 | 10.0 | ▂▅▇▅▁ |
吸引力得分 | 212 | 0.97 | 6.19 | 1.95 | 0 | 5 | 6 | 8 | 10.5 | ▁▃▇▇▂ |
共同爱好得分 | 1076 | 0.87 | 5.47 | 2.16 | 0 | 4 | 6 | 7 | 10.0 | ▂▅▇▆▂ |
幽默得分 | 360 | 0.96 | 6.40 | 1.95 | 0 | 5 | 7 | 8 | 11.0 | ▁▂▇▇▂ |
真诚得分 | 287 | 0.97 | 7.18 | 1.74 | 0 | 6 | 7 | 8 | 10.0 | ▁▁▃▇▃ |
雄心得分 | 722 | 0.91 | 6.78 | 1.79 | 0 | 6 | 7 | 8 | 10.0 | ▁▂▆▇▃ |
智力得分 | 306 | 0.96 | 7.37 | 1.55 | 0 | 6 | 7 | 8 | 10.0 | ▁▁▃▇▃ |
对方年龄 | 104 | 0.99 | 26.36 | 3.56 | 18 | 24 | 26 | 28 | 55.0 | ▇▇▁▁▁ |
对方种族 | 73 | 0.99 | 2.76 | 1.23 | 1 | 2 | 2 | 4 | 6.0 | ▇▁▃▁▁ |
是否同一种族 | 0 | 1.00 | 0.40 | 0.49 | 0 | 0 | 0 | 1 | 1.0 | ▇▁▁▁▅ |
日常约会频率 | 97 | 0.99 | 5.01 | 1.44 | 1 | 4 | 5 | 6 | 7.0 | ▁▂▅▃▇ |
现实生活中,在数据分析时,经常会碰到缺失值,比如相亲数据中,有些女性不愿意暴露自己的年龄,年龄就会有缺失值。那么对于缺失值,怎么处理呢?处理方式很多,甚至有时候数据缺失本身也暗含一些信息(比如年龄缺失的女性可能是因为年龄比较大),由此引申了许多插补方法。不过这里缺失值处理并不是重点,因此对于缺失值直接删除即可。
零方差或者近零方差的变量传递不了什么信息,因为几乎所有人的取值都一样。可以利用caret包
中的nearZeroVar()函数
,一行代码就能找出近零方差的变量,操作过程非常简单。
## integer(0)
对于完整的观测,首先需要定义变量的类型:属于定性变量还是连续变量。对于定性变量而言,需要给定性变量的各个水平取名,比如性别有两个水平1和0,分别命名为男、女。
data %>%
mutate(决定 = factor(决定,
levels = c(0, 1),
labels = c("拒绝", "接受"))) %>%
mutate(性别 = factor(性别,
levels = c(0, 1),
labels = c("女", "男"))) %>%
mutate(种族 = factor(
种族,
levels = c(1, 2, 3, 4, 5, 6),
labels = c("非洲裔", "欧洲裔", "拉丁裔", "亚裔", "印第安土著", "其他")
)) %>%
mutate(从事领域 = factor(
从事领域,
levels = 1:18,
labels = c(
"法律",
"数学",
"社会科学或心理学",
"医学或药物学或生物技术",
"工程学",
"写作或新闻",
"历史或宗教或哲学",
"商业或经济或金融",
"教育或学术",
"生物科学或化学或物理",
"社会工作",
"大学在读或未择方向",
"政治学或国际事务",
"电影",
"艺术管理",
"语言",
"建筑学",
"其他"
)
)) %>%
mutate(对方决定 = factor(对方决定,
levels = 0:1,
labels = c("拒绝", "接收"))) %>%
mutate(对方种族 = factor(
对方种族,
levels = c(1, 2, 3, 4, 5, 6),
labels = c("非洲裔", "欧洲裔", "拉丁裔", "亚裔", "印第安土著", "其他")
)) %>%
mutate(是否同一种族 = factor(
是否同一种族,
levels = c(0, 1),
labels = c("非同一种族", "同一种族")
)) -> data
data %>% map(unique)
## $决定
## [1] 接受 拒绝
## Levels: 拒绝 接受
##
## $性别
## [1] 女 男
## Levels: 女 男
##
## $吸引力
## [1] 6.0 7.0 5.0 4.0 8.0 9.0 3.0 10.0 2.0 1.0 0.0 6.5 7.5 9.5 8.5
## [16] 9.9 3.5
##
## $共同爱好
## [1] 5.0 6.0 8.0 4.0 7.0 3.0 2.0 9.0 10.0 1.0 0.0 7.5 6.5 8.5 5.5
##
## $幽默
## [1] 7.0 8.0 4.0 6.0 9.0 3.0 5.0 10.0 2.0 1.0 0.0 5.5 6.5 9.5 7.5
## [16] 8.5
##
## $真诚
## [1] 9.0 8.0 6.0 7.0 5.0 10.0 4.0 3.0 2.0 1.0 0.0 8.5 7.5
##
## $雄心
## [1] 6.0 5.0 8.0 10.0 9.0 3.0 7.0 4.0 2.0 1.0 0.0 9.5 7.5 8.5
##
## $智力
## [1] 7.0 8.0 6.0 9.0 10.0 5.0 4.0 3.0 2.0 1.0 0.0 6.5 8.5 7.5 5.5
##
## $好感
## [1] 7.0 6.0 8.0 5.0 9.0 4.0 10.0 2.0 3.0 6.5 1.0 8.5 9.5 0.0 7.5
## [16] 5.5 4.5 9.7
##
## $成功率自估
## [1] 6.0 5.0 7.0 4.0 3.0 8.0 1.0 10.0 2.0 9.0 0.0 6.5 7.5 8.5 9.5
## [16] 5.5 3.5 4.5
##
## $日常出门频率
## [1] 1 4 2 3 5 7 6
##
## $对宗教的看重程度
## [1] 4 5 1 3 2 8 10 6 7 9
##
## $对种族的看重程度
## [1] 2 8 1 4 7 3 9 10 5 6 0
##
## $年龄
## [1] 21 24 25 23 22 26 27 30 28 29 34 35 32 20 19 18 33 36 31 42 38 55
##
## $种族
## [1] 亚裔 欧洲裔 其他 拉丁裔 非洲裔
## Levels: 非洲裔 欧洲裔 拉丁裔 亚裔 印第安土著 其他
##
## $从事领域
## [1] 法律 数学 政治学或国际事务
## [4] 商业或经济或金融 工程学 教育或学术
## [7] 社会科学或心理学 社会工作 大学在读或未择方向
## [10] 医学或药物学或生物技术 历史或宗教或哲学 写作或新闻
## [13] 生物科学或化学或物理 电影 语言
## [16] 艺术管理 建筑学 其他
## 18 Levels: 法律 数学 社会科学或心理学 医学或药物学或生物技术 ... 其他
##
## $对方决定
## [1] 拒绝 接收
## Levels: 拒绝 接收
##
## $好感得分
## [1] 7.0 8.0 2.0 6.5 6.0 10.0 9.0 4.0 5.0 3.0 1.0 9.5 7.5 4.5 8.5
## [16] 5.5 0.0 9.7
##
## $对方评估成功率
## [1] 4.0 7.0 6.0 1.0 5.0 8.0 2.0 10.0 3.0 9.0 4.5 6.5 5.5 0.0 7.5
## [16] 8.5 9.5 3.5
##
## $吸引力得分
## [1] 6.0 7.0 8.0 3.0 10.0 9.0 5.0 4.0 2.0 1.0 0.0 6.5 7.5 8.5 9.5
## [16] 9.9 3.5
##
## $共同爱好得分
## [1] 6.0 5.0 8.0 7.0 9.0 4.0 10.0 3.0 2.0 1.0 7.5 6.5 8.5 0.0 5.5
##
## $幽默得分
## [1] 8.0 7.0 6.0 5.0 9.0 10.0 3.0 4.0 2.0 1.0 5.5 6.5 9.5 0.0 7.5
## [16] 8.5 11.0
##
## $真诚得分
## [1] 8.0 7.0 6.0 10.0 9.0 3.0 5.0 4.0 2.0 1.0 4.5 8.5 0.0 7.5
##
## $雄心得分
## [1] 8.0 7.0 9.0 6.0 10.0 5.0 4.0 2.0 3.0 1.0 5.5 9.5 0.0 7.5 8.5
##
## $智力得分
## [1] 8.0 10.0 9.0 7.0 5.0 6.0 4.0 3.0 1.0 2.0 6.5 8.5 0.0 7.5 5.5
##
## $对方年龄
## [1] 27 22 23 24 25 30 28 21 26 29 32 35 34 18 20 19 33 31 36 42 38 55
##
## $对方种族
## [1] 欧洲裔 拉丁裔 亚裔 其他 非洲裔
## Levels: 非洲裔 欧洲裔 拉丁裔 亚裔 印第安土著 其他
##
## $是否同一种族
## [1] 非同一种族 同一种族
## Levels: 非同一种族 同一种族
##
## $日常约会频率
## [1] 7 5 3 4 6 1 2
set.seed(1234)
data_id <- createDataPartition(y = data$决定,p = 0.7,list = FALSE,times = 1)
data_training <- data[data_id,]
data_testing <- data[-data_id,]
data_training %$% table(决定) %>% prop.table()
data_testing %$% table(决定) %>% prop.table()
data %$% table(决定) %>% prop.table()
## 决定
## 拒绝 接受
## 0.56002 0.43998
## 决定
## 拒绝 接受
## 0.5600233 0.4399767
## 决定
## 拒绝 接受
## 0.560021 0.439979
# model_gbm <- train(决定 ~ .,
# data = data_training,
# method = "gbm", # 方法
# trControl = fit_control_gbm,
# verbose = FALSE,
# tuneGrid = grid_gbm, # 网格搜索结果
# metric = "ROC") # 指标
data_training %>% dim()
## [1] 4007 29
# save(model_gbm,file = "Machine_Learning_and_Causal_Inference/result/model_gbm.RData")
load(file = here::here("Machine_Learning_and_Causal_Inference/result/model_gbm.RData"))
model_gbm
## Stochastic Gradient Boosting
##
## 4007 samples
## 28 predictor
## 2 classes: '拒绝', '接受'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 3607, 3607, 3606, 3606, 3606, 3607, ...
## Resampling results across tuning parameters:
##
## interaction.depth n.trees ROC Sens Spec
## 1 50 0.8458812 0.7803234 0.7271540
## 1 100 0.8547180 0.7852143 0.7430566
## 1 150 0.8585245 0.7914623 0.7481638
## 1 200 0.8601837 0.7923552 0.7521347
## 1 250 0.8608241 0.7896865 0.7481638
## 1 300 0.8607350 0.7883452 0.7493002
## 1 350 0.8603262 0.7914702 0.7470275
## 1 400 0.8607127 0.7874544 0.7470339
## 1 450 0.8607112 0.7910119 0.7464625
## 1 500 0.8603422 0.7914643 0.7487256
## 1 550 0.8599789 0.7928056 0.7510048
## 1 600 0.8601791 0.7932500 0.7447772
## 1 650 0.8606016 0.7932460 0.7504430
## 1 700 0.8601213 0.7945833 0.7527125
## 1 750 0.8597893 0.7950258 0.7481542
## 1 800 0.8591014 0.7972560 0.7475892
## 1 850 0.8583529 0.7941409 0.7481478
## 1 900 0.8583318 0.7928056 0.7470178
## 1 950 0.8578740 0.7914702 0.7453197
## 1 1000 0.8581766 0.7896865 0.7458911
## 5 50 0.8637739 0.8083948 0.7436120
## 5 100 0.8679483 0.8043948 0.7493002
## 5 150 0.8708205 0.8079444 0.7561344
## 5 200 0.8720001 0.8101746 0.7538617
## 5 250 0.8730226 0.8141944 0.7555598
## 5 300 0.8738410 0.8150952 0.7532839
## 5 350 0.8741765 0.8097480 0.7572515
## 5 400 0.8738580 0.8084008 0.7634823
## 5 450 0.8751164 0.8119722 0.7617777
## 5 500 0.8747352 0.8097440 0.7589529
## 5 550 0.8743121 0.8026091 0.7527253
## 5 600 0.8756294 0.8070655 0.7606703
## 5 650 0.8749445 0.8057262 0.7612160
## 5 700 0.8753659 0.8070655 0.7612128
## 5 750 0.8759504 0.8079524 0.7617938
## 5 800 0.8763178 0.8088512 0.7663328
## 5 850 0.8756155 0.8084067 0.7617970
## 5 900 0.8762538 0.8075139 0.7652061
## 5 950 0.8765378 0.8066230 0.7612160
## 5 1000 0.8761628 0.8061726 0.7617842
## 9 50 0.8665025 0.8092976 0.7436441
## 9 100 0.8716481 0.8159901 0.7544331
## 9 150 0.8735060 0.8146369 0.7544331
## 9 200 0.8747870 0.8222143 0.7515858
## 9 250 0.8768251 0.8240040 0.7538392
## 9 300 0.8769973 0.8213294 0.7532775
## 9 350 0.8772060 0.8182103 0.7549692
## 9 400 0.8767967 0.8155357 0.7509951
## 9 450 0.8778824 0.8199921 0.7527029
## 9 500 0.8781857 0.8159861 0.7532743
## 9 550 0.8785753 0.8231171 0.7515858
## 9 600 0.8792980 0.8204444 0.7476181
## 9 650 0.8795248 0.8226627 0.7521475
## 9 700 0.8793808 0.8213254 0.7538521
## 9 750 0.8801409 0.8248889 0.7527221
## 9 800 0.8797453 0.8186508 0.7544171
## 9 850 0.8799191 0.8217639 0.7544203
## 9 900 0.8794603 0.8217659 0.7487705
## 9 950 0.8796739 0.8231032 0.7510112
## 9 1000 0.8797459 0.8199901 0.7487481
##
## Tuning parameter 'shrinkage' was held constant at a value of 0.1
##
## Tuning parameter 'n.minobsinnode' was held constant at a value of 20
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were n.trees = 750, interaction.depth =
## 9, shrinkage = 0.1 and n.minobsinnode = 20.
第五步:模型会自动确定ROC曲线面积最大(即AUC值最高)的参数组合,也就是图3中最高的点对应的参数组合,对应的AUC值为90.14%。
## Loaded gbm 2.1.8
随机搜索与网格搜索相比,参数的选择没有固定的范围,最终的结果可能好也可能坏。它的实现步骤如下:
trainControl()函数
设定模型训练的参数,但是多了一项:search=”random”
。fit_control_gbm_random <- trainControl(method = "cv",
number = 10,
classProbs = TRUE,
summaryFunction = twoClassSummary,
search = "random") # 随机搜索
tuneGrid参数
,只需要设置参数tuneLength
(随机搜索多少组)。# model_gbm_random <- train(决定~.,data =data_training,
# method = "gbm", # 方法
# trControl = fit_control_gbm_random,
# verbose = FALSE,
# metric = "ROC", # 指标
# tuneLength = 30)
# save(model_gbm_random,file = here::here("Machine_Learning_and_Causal_Inference/result/model_gbm_random.RData"))
load(file = here::here("Machine_Learning_and_Causal_Inference/result/model_gbm_random.RData"))
model_gbm_random
## Stochastic Gradient Boosting
##
## 4007 samples
## 28 predictor
## 2 classes: '拒绝', '接受'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 3606, 3607, 3606, 3606, 3605, 3607, ...
## Resampling results across tuning parameters:
##
## shrinkage interaction.depth n.minobsinnode n.trees ROC Sens
## 0.004834566 10 13 4378 0.8830347 0.8204087
## 0.017773097 6 18 2940 0.8801644 0.8088194
## 0.030012312 10 7 939 0.8828093 0.8199623
## 0.065353448 2 12 3466 0.8705157 0.8038948
## 0.076775003 7 24 2421 0.8743527 0.8168313
## 0.098133025 9 13 4482 0.8774781 0.8204087
## 0.120162734 7 6 3408 0.8794093 0.8163889
## 0.170841822 4 18 4080 0.8647263 0.8043631
## 0.211689456 2 5 966 0.8655478 0.7949960
## 0.224400530 5 21 2407 0.8651855 0.8074742
## 0.281793601 5 8 3231 0.8693740 0.8123810
## 0.301029938 4 11 32 0.8619964 0.8088294
## 0.311465749 4 20 1499 0.8621497 0.8101647
## 0.340869527 2 20 3095 0.8544660 0.7950317
## 0.341748821 9 20 1300 0.8642645 0.8101647
## 0.349149748 4 17 4122 0.8613487 0.8030417
## 0.367269076 3 14 1158 0.8583283 0.8034603
## 0.380703725 1 6 1745 0.8504483 0.7838492
## 0.384781755 5 7 143 0.8551641 0.8003790
## 0.402804276 6 6 4050 0.8751762 0.8164107
## 0.433879408 8 7 220 0.8560501 0.7972421
## 0.446740546 5 21 3244 0.8630270 0.8115119
## 0.463837192 2 20 207 0.8489684 0.7909742
## 0.466876172 3 25 3485 0.8535598 0.7981329
## 0.474556008 10 6 1790 0.8714595 0.8190556
## 0.483829035 7 22 1094 0.8482503 0.7874286
## 0.516980422 6 12 4603 0.8675019 0.8088175
## 0.562797687 3 18 2735 0.8544221 0.8047996
## 0.586171088 10 18 4890 0.8635503 0.7901091
## 0.591066128 8 16 547 0.8535618 0.8119365
## Spec
## 0.7617264
## 0.7634438
## 0.7594536
## 0.7571905
## 0.7639927
## 0.7611486
## 0.7640055
## 0.7532101
## 0.7537879
## 0.7452844
## 0.7509470
## 0.7430117
## 0.7481061
## 0.7526451
## 0.7509534
## 0.7430342
## 0.7412975
## 0.7373941
## 0.7356317
## 0.7571841
## 0.7390280
## 0.7458430
## 0.7316866
## 0.7418657
## 0.7464079
## 0.7407550
## 0.7452908
## 0.7362513
## 0.7600122
## 0.7288649
##
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were n.trees = 4378, interaction.depth
## = 10, shrinkage = 0.004834566 and n.minobsinnode = 13.
确定最优参数之后,模型如何进行预测呢?使用predict()函数
,只要输入模型及测试集,就可以预测了。然后利用confusionMatrix()函数
输入真实的Y与预测的Y就可以得到混淆矩阵(Confusion Matrix)
。
网格搜索的参数与随机搜索的参数的预测结果有什么区别呢?下面的操作结果可以明显看出两者的区别。
model_gbm_pre <- predict(model_gbm,newdata = data_testing)
confusionMatrix(model_gbm_pre,data_testing$决定)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 拒绝 接受
## 拒绝 807 172
## 接受 154 583
##
## Accuracy : 0.81
## 95% CI : (0.7906, 0.8283)
## No Information Rate : 0.56
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.6135
##
## Mcnemar's Test P-Value : 0.3464
##
## Sensitivity : 0.8398
## Specificity : 0.7722
## Pos Pred Value : 0.8243
## Neg Pred Value : 0.7910
## Prevalence : 0.5600
## Detection Rate : 0.4703
## Detection Prevalence : 0.5705
## Balanced Accuracy : 0.8060
##
## 'Positive' Class : 拒绝
##
model_gbm_random_pre <- predict(model_gbm_random,newdata = data_testing)
confusionMatrix(model_gbm_random_pre,data_testing$决定)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 拒绝 接受
## 拒绝 814 171
## 接受 147 584
##
## Accuracy : 0.8147
## 95% CI : (0.7955, 0.8328)
## No Information Rate : 0.56
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.6227
##
## Mcnemar's Test P-Value : 0.1971
##
## Sensitivity : 0.8470
## Specificity : 0.7735
## Pos Pred Value : 0.8264
## Neg Pred Value : 0.7989
## Prevalence : 0.5600
## Detection Rate : 0.4744
## Detection Prevalence : 0.5740
## Balanced Accuracy : 0.8103
##
## 'Positive' Class : 拒绝
##
标准化处理是指将数据处理为均值为0、标准差为1的数据。那么为什么要进行标准化处理呢?因为在进行实证分析时,有些变量取值很大,有些变量取值很小,这里需要营造一个公平公正的环境,权重的大小不能被自身变量取值的大小所束缚。比如在判断一个女生是否是美女时,会考虑腿长、脸长、脸宽、腰围等因素,这些因素的学名为特征。显然腿长的取值比脸长的取值大得多,这时为了防止腿长的权重过高,就需要将这些特征进行标准化才能学习各个变量真实的权重。
标准化处理时,只能利用训练集的均值与标准差对训练集和测试集进行标准化。
# model_gbm_std <- train(决定 ~ .,
# data = data_training_std, # 数据已经标准化
# method = "gbm", # 方法
# trControl = fit_control_gbm,
# verbose = FALSE,
# tuneGrid = grid_gbm, # 网格搜索
# metric = "ROC") # 指标
data_training_std %>% dim()
## [1] 4007 29
# save(model_gbm_std,file = here::here("Machine_Learning_and_Causal_Inference/result/model_gbm_std.RData"))
load(file = here::here("Machine_Learning_and_Causal_Inference/result/model_gbm_std.RData"))
model_gbm_std
## Stochastic Gradient Boosting
##
## 4007 samples
## 28 predictor
## 2 classes: '拒绝', '接受'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 3607, 3607, 3606, 3606, 3606, 3607, ...
## Resampling results across tuning parameters:
##
## interaction.depth n.trees ROC Sens Spec
## 1 50 0.8458685 0.7803234 0.7271540
## 1 100 0.8547154 0.7852143 0.7430566
## 1 150 0.8584942 0.7910159 0.7481638
## 1 200 0.8600550 0.7923552 0.7527029
## 1 250 0.8609050 0.7901310 0.7464625
## 1 300 0.8606871 0.7896825 0.7498684
## 1 350 0.8602378 0.7914702 0.7470243
## 1 400 0.8605659 0.7892361 0.7475989
## 1 450 0.8604984 0.7927976 0.7464625
## 1 500 0.8603470 0.7923571 0.7498523
## 1 550 0.8598781 0.7941409 0.7510015
## 1 600 0.8600794 0.7941349 0.7493066
## 1 650 0.8607604 0.7950258 0.7527061
## 1 700 0.8596862 0.7927996 0.7515761
## 1 750 0.8595700 0.7950238 0.7470307
## 1 800 0.8590555 0.7941329 0.7453294
## 1 850 0.8583105 0.7927956 0.7504269
## 1 900 0.8584066 0.7932440 0.7487256
## 1 950 0.8577257 0.7914683 0.7453197
## 1 1000 0.8580634 0.7914742 0.7470339
## 5 50 0.8638778 0.8088413 0.7419074
## 5 100 0.8683209 0.8035099 0.7487352
## 5 150 0.8709084 0.8119563 0.7601053
## 5 200 0.8720089 0.8150794 0.7549884
## 5 250 0.8732528 0.8137440 0.7566801
## 5 300 0.8737551 0.8177659 0.7600764
## 5 350 0.8743140 0.8173254 0.7583686
## 5 400 0.8746825 0.8124206 0.7572355
## 5 450 0.8758394 0.8173234 0.7532743
## 5 500 0.8761987 0.8164246 0.7561216
## 5 550 0.8762954 0.8177659 0.7533032
## 5 600 0.8771053 0.8168651 0.7601149
## 5 650 0.8766423 0.8173115 0.7652189
## 5 700 0.8768278 0.8186488 0.7623844
## 5 750 0.8771631 0.8195437 0.7606574
## 5 800 0.8774133 0.8155417 0.7618034
## 5 850 0.8772425 0.8137440 0.7657646
## 5 900 0.8779780 0.8195337 0.7663232
## 5 950 0.8774379 0.8150754 0.7685927
## 5 1000 0.8781535 0.8195437 0.7680277
## 9 50 0.8667744 0.8106349 0.7425045
## 9 100 0.8714460 0.8159782 0.7510304
## 9 150 0.8730945 0.8141984 0.7533064
## 9 200 0.8743664 0.8231052 0.7510208
## 9 250 0.8756663 0.8239921 0.7566834
## 9 300 0.8756694 0.8195317 0.7561120
## 9 350 0.8749122 0.8173135 0.7538424
## 9 400 0.8741921 0.8164206 0.7583686
## 9 450 0.8743759 0.8155357 0.7623299
## 9 500 0.8747896 0.8150774 0.7606382
## 9 550 0.8756587 0.8186548 0.7629109
## 9 600 0.8761933 0.8190992 0.7578165
## 9 650 0.8760428 0.8204286 0.7566705
## 9 700 0.8762266 0.8204286 0.7515761
## 9 750 0.8766359 0.8213294 0.7549820
## 9 800 0.8765712 0.8266627 0.7527125
## 9 850 0.8764037 0.8235476 0.7566801
## 9 900 0.8766565 0.8253313 0.7538424
## 9 950 0.8765870 0.8217639 0.7510080
## 9 1000 0.8765769 0.8204325 0.7510144
##
## Tuning parameter 'shrinkage' was held constant at a value of 0.1
##
## Tuning parameter 'n.minobsinnode' was held constant at a value of 20
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were n.trees = 1000, interaction.depth =
## 5, shrinkage = 0.1 and n.minobsinnode = 20.
fit_control_gbm_random <- trainControl(method = "cv",
number = 10,
classProbs = TRUE,
summaryFunction = twoClassSummary,
search = "random") # 随机搜索
# model_gbm_random_std <- train(
# 决定 ~ .,
# data = data_training_std, # 数据已经标准化
# method = "gbm",
# trControl = fit_control_gbm_random,
# verbose = FALSE,
# metric = "ROC",
# tuneLength = 30 # 随机搜索
# )
# save(model_gbm_random_std,file = here::here("Machine_Learning_and_Causal_Inference/result/model_gbm_random_std.RData"))
load(file = here::here("Machine_Learning_and_Causal_Inference/result/model_gbm_random_std.RData"))
model_gbm_random_std
## Stochastic Gradient Boosting
##
## 4007 samples
## 28 predictor
## 2 classes: '拒绝', '接受'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 3607, 3606, 3607, 3607, 3606, 3606, ...
## Resampling results across tuning parameters:
##
## shrinkage interaction.depth n.minobsinnode n.trees ROC Sens
## 0.05639147 6 10 213 0.8730341 0.8177183
## 0.06563599 9 17 4530 0.8779457 0.8128056
## 0.08328927 5 5 3468 0.8797575 0.8230575
## 0.09511411 3 8 4758 0.8762998 0.8190456
## 0.12019083 2 11 725 0.8703719 0.8105893
## 0.13293341 6 23 2669 0.8731750 0.8203889
## 0.14990818 3 15 3684 0.8682060 0.8088016
## 0.15548045 10 22 2601 0.8710143 0.8154881
## 0.16083543 5 17 3884 0.8707270 0.8101389
## 0.19761656 3 10 1590 0.8653865 0.8074762
## 0.19882499 7 24 1869 0.8691928 0.8029881
## 0.23788788 5 24 1733 0.8639957 0.8101548
## 0.25513100 3 18 251 0.8588094 0.8070159
## 0.32942565 4 9 3004 0.8660796 0.8016825
## 0.33407670 10 16 3788 0.8720841 0.8074524
## 0.35872628 9 24 2169 0.8684708 0.8114603
## 0.38207893 10 23 4830 0.8708166 0.8141488
## 0.39522148 10 21 4059 0.8665400 0.8016488
## 0.40394650 4 5 1826 0.8627525 0.8047718
## 0.43589702 5 25 3366 0.8602892 0.7972024
## 0.45618750 3 11 2956 0.8581971 0.7981091
## 0.47175568 10 19 658 0.8573075 0.8061230
## 0.48685329 4 19 1259 0.8468197 0.7931905
## 0.50218047 1 19 1828 0.8463778 0.7771647
## 0.53845428 4 17 3248 0.8562245 0.7976726
## 0.54697262 2 14 1908 0.8480733 0.7994286
## 0.54935125 2 7 113 0.8526083 0.7891845
## 0.56671028 9 24 4737 0.8639321 0.8070099
## 0.58348927 5 21 1411 0.8518223 0.7959067
## 0.58480647 3 5 2246 0.8510294 0.7940992
## Spec
## 0.7464529
## 0.7600732
## 0.7521090
## 0.7578197
## 0.7549628
## 0.7492809
## 0.7532646
## 0.7436312
## 0.7521058
## 0.7521475
## 0.7549275
## 0.7379622
## 0.7384919
## 0.7555277
## 0.7538457
## 0.7509791
## 0.7504301
## 0.7419395
## 0.7487288
## 0.7408064
## 0.7419138
## 0.7373652
## 0.7254655
## 0.7294171
## 0.7356895
## 0.7317315
## 0.7345243
## 0.7396540
## 0.7453101
## 0.7379526
##
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were n.trees = 3468, interaction.depth =
## 5, shrinkage = 0.08328927 and n.minobsinnode = 5.
model_gbm_std_pre <- predict(model_gbm_std,newdata = data_testing_std)
confusionMatrix(model_gbm_std_pre,data_testing_std$决定)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 拒绝 接受
## 拒绝 799 168
## 接受 162 587
##
## Accuracy : 0.8077
## 95% CI : (0.7882, 0.8261)
## No Information Rate : 0.56
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.6094
##
## Mcnemar's Test P-Value : 0.7831
##
## Sensitivity : 0.8314
## Specificity : 0.7775
## Pos Pred Value : 0.8263
## Neg Pred Value : 0.7837
## Prevalence : 0.5600
## Detection Rate : 0.4656
## Detection Prevalence : 0.5635
## Balanced Accuracy : 0.8045
##
## 'Positive' Class : 拒绝
##
model_gbm_random_std_pre <- predict(model_gbm_random_std,
newdata = data_testing_std)
confusionMatrix(model_gbm_random_std_pre,data_testing_std$决定)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 拒绝 接受
## 拒绝 819 168
## 接受 142 587
##
## Accuracy : 0.8193
## 95% CI : (0.8003, 0.8373)
## No Information Rate : 0.56
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.6321
##
## Mcnemar's Test P-Value : 0.1556
##
## Sensitivity : 0.8522
## Specificity : 0.7775
## Pos Pred Value : 0.8298
## Neg Pred Value : 0.8052
## Prevalence : 0.5600
## Detection Rate : 0.4773
## Detection Prevalence : 0.5752
## Balanced Accuracy : 0.8149
##
## 'Positive' Class : 拒绝
##
result_df_compare <- tibble(
gbm = confusionMatrix(model_gbm_pre, data_testing$决定)[[3]][[1]],
gbm_std = confusionMatrix(model_gbm_std_pre, data_testing$决定)[[3]][[1]], # 利用混淆矩阵评估模型
gbm_random_std = confusionMatrix(model_gbm_random_std_pre, data_testing$决定)[[3]][[1]], # 利用混淆矩阵评估模型
gbm_random = confusionMatrix(model_gbm_random_pre, data_testing$决定)[[3]][[1]],
)
result_df_compare %>% t() %>% as.data.frame() %>% rownames_to_column(var = "model") %>%
rename(value = V1) %>%
arrange(value)