1 加载经常用的R包

library(pacman)
# 读数据
p_load(readxl,writexl,data.table,openxlsx,haven,rvest)
# 数据探索
p_load(tidyverse,DT,skimr,DataExplorer,explore,vtable,stringr,kableExtra,lubridate)
# 模型
p_load(grf,glmnet,caret,tidytext,fpp2,forecast,car,tseries,hdm,tidymodels,broom)
# 可视化
p_load(patchwork,ggrepel,ggcorrplot,gghighlight,ggthemes,shiny)
# 其它常用包
p_load(magrittr,listviewer,devtools,here,janitor,reticulate,jsonlite)

2 模型调参

机器学习调参的思路都异曲同工,首先确定一个参数池,也就是模型参数值的可选范围。从这个池子中挑选出不同的参数组合,对于每个组合都计算其预测精度,最后选取预测精度最高的参数组合。

2.1 什么是调参?

调参的过程就像是找人生伴侣的过程,首先我们有一个标准,比如身高体重等,符合这个标准的异性将进入到参数池中。然后我们跟参数池中的每个异性谈恋爱,找到最适合我们的那个作为终极选择。接下来,介绍两种常见的调参方法:网格搜索随机搜索

knitr::include_graphics(here::here("Machine_Learning_and_Causal_Inference/fig/modify parameters.png"))
调参流程

Figure 2.1: 调参流程

网格搜索首先会有一个标准,将符合标准的参数放入参数池中,形成不同的参数组合。而随机搜索则不同,随机搜索没有标准,随机地组合参数。依然以找男友为例,假设参数有3个:身高体重年龄

网格搜索会对这3个参数设定一个范围,比如身高>180厘米体重小于<140斤年龄在20~40岁之间。但是随机搜索则不同,有些女性觉得如果设定了择偶条件,反而容易错过自己喜欢的,也许适合自己的恰好身高只有179厘米

这两种不同的搜索方式出来的参数组合是不同的。两者各有优缺点,随机搜索与网格搜索相比,其优点在于能随机地遍历所有参数空间,但是缺点也很明显:不知道随机出来的是什么类型的人。下面分别看看两种搜索方式的实现。

首先来看在caret包中如何轻轻松松实现网格搜索

  • 第一步:设置随机种子,保证实验的可重复性;

  • 第二步:利用traincontrol()函数设置模型训练时用到的参数。其中method表示重抽样方法。此处,cv表示交叉验证,number表示几折交叉验证,本例中是10折交叉验证10折交叉验证表示,首先将样本分为10个组,每次训练的时候抽取其中9组作为训练集,剩下的1组作为测试集。classProbs参数表示是否计算类别概率,如果评价指标为AUC,那么这里一定要设置为TRUE。由于因变量为两水平变量,所以summaryFunction这里设置为twoClassSummary

  • 第三步:设置网格搜索的参数池,也就是设定参数的选择范围。这里以机器学习中的gbm(Gradient boosting machine)方法为例,所以有4个超参数需要设定,分别为迭代次数(n.trees)树的复杂度(interaction.depth)学习率(shrinkage)训练样本的最小数目(n.minobsinnode)。这里设定了60组参数组合。

set.seed(1234)
fit_control_gbm <- trainControl(method = "cv",
                                number = 10,
                                classProbs = TRUE,
                                summaryFunction = twoClassSummary)

grid_gbm <- expand.grid(interaction.depth = c(1,5,9),
                        n.trees = (1:20) * 50,
                        shrinkage = 0.1,
                        n.minobsinnode = 20)
grid_gbm
  • 第四步:利用train()函数来进行模型训练及得到最优参数组合。该函数会遍历第三步得到的所有参数组合,并得到使评价指标最大的参数组合作为输出。method表示使用的模型,本例使用机器学习中的gbm(Gradient boosting machine)模型,使用的评价指标为ROC曲线面积即AUC值)。

3 导入数据

data <- data.table::fread(here::here("Machine_Learning_and_Causal_Inference/data/相亲数据重新编码.csv"))
data %<>% as.data.frame()
data

4 数据预处理

4.1 数据查看

data %>% str()  # 全是数值型
## 'data.frame':    8378 obs. of  29 variables:
##  $ 决定            : int  1 1 1 1 1 0 1 0 1 1 ...
##  $ 性别            : int  0 0 0 0 0 0 0 0 0 0 ...
##  $ 吸引力          : num  6 7 5 7 5 4 7 4 7 5 ...
##  $ 共同爱好        : num  5 6 7 8 6 4 7 6 8 8 ...
##  $ 幽默            : num  7 8 8 7 7 4 4 6 9 8 ...
##  $ 真诚            : num  9 8 8 6 6 9 6 9 6 6 ...
##  $ 雄心            : num  6 5 5 6 6 6 6 5 8 10 ...
##  $ 智力            : num  7 7 9 8 7 7 7 7 8 6 ...
##  $ 好感            : num  7 7 7 7 6 6 6 6 7 6 ...
##  $ 成功率自估      : num  6 5 NA 6 6 5 5 7 7 6 ...
##  $ 日常出门频率    : int  1 1 1 1 1 1 1 1 1 1 ...
##  $ 对宗教的看重程度: int  4 4 4 4 4 4 4 4 4 4 ...
##  $ 对种族的看重程度: int  2 2 2 2 2 2 2 2 2 2 ...
##  $ 年龄            : int  21 21 21 21 21 21 21 21 21 21 ...
##  $ 种族            : int  4 4 4 4 4 4 4 4 4 4 ...
##  $ 从事领域        : int  1 1 1 1 1 1 1 1 1 1 ...
##  $ 对方决定        : int  0 0 1 1 1 1 0 0 1 0 ...
##  $ 好感得分        : num  7 8 10 7 8 7 2 7 6.5 6 ...
##  $ 对方评估成功率  : num  4 4 10 7 6 6 1 5 8 6 ...
##  $ 吸引力得分      : num  6 7 10 7 8 7 3 6 7 6 ...
##  $ 共同爱好得分    : num  6 5 10 8 7 7 7 6 9 6 ...
##  $ 幽默得分        : num  8 7 10 8 6 8 5 6 8 6 ...
##  $ 真诚得分        : num  8 8 10 8 7 7 6 7 7 6 ...
##  $ 雄心得分        : num  8 7 10 9 9 7 8 8 8 6 ...
##  $ 智力得分        : num  8 10 10 9 9 8 7 5 8 6 ...
##  $ 对方年龄        : int  27 22 22 23 24 25 30 27 28 24 ...
##  $ 对方种族        : int  2 2 4 2 3 2 2 2 2 2 ...
##  $ 是否同一种族    : int  0 0 1 0 0 0 0 0 0 0 ...
##  $ 日常约会频率    : int  7 7 7 7 7 7 7 7 7 7 ...
data %>% skimr::skim()
Table 4.1: Data summary
Name Piped data
Number of rows 8378
Number of columns 29
_______________________
Column type frequency:
numeric 29
________________________
Group variables None

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
决定 0 1.00 0.42 0.49 0 0 0 1 1.0 ▇▁▁▁▆
性别 0 1.00 0.50 0.50 0 0 1 1 1.0 ▇▁▁▁▇
吸引力 202 0.98 6.19 1.95 0 5 6 8 10.0 ▁▃▇▇▂
共同爱好 1067 0.87 5.47 2.16 0 4 6 7 10.0 ▂▅▇▆▂
幽默 350 0.96 6.40 1.95 0 5 7 8 10.0 ▁▂▇▇▂
真诚 277 0.97 7.18 1.74 0 6 7 8 10.0 ▁▁▃▇▃
雄心 712 0.92 6.78 1.79 0 6 7 8 10.0 ▁▂▆▇▃
智力 296 0.96 7.37 1.55 0 6 7 8 10.0 ▁▁▃▇▃
好感 240 0.97 6.13 1.84 0 5 6 7 10.0 ▁▃▇▇▂
成功率自估 309 0.96 5.21 2.13 0 4 5 7 10.0 ▂▅▇▅▁
日常出门频率 79 0.99 2.16 1.11 1 1 2 3 7.0 ▇▃▁▁▁
对宗教的看重程度 79 0.99 3.65 2.81 1 1 3 6 10.0 ▇▃▃▂▁
对种族的看重程度 79 0.99 3.78 2.85 0 1 3 6 10.0 ▇▃▂▂▂
年龄 95 0.99 26.36 3.57 18 24 26 28 55.0 ▇▇▁▁▁
种族 63 0.99 2.76 1.23 1 2 2 4 6.0 ▇▁▃▁▁
从事领域 82 0.99 7.66 3.76 1 5 8 10 18.0 ▃▃▇▂▁
对方决定 0 1.00 0.42 0.49 0 0 0 1 1.0 ▇▁▁▁▆
好感得分 250 0.97 6.13 1.84 0 5 6 7 10.0 ▁▃▇▇▂
对方评估成功率 318 0.96 5.21 2.13 0 4 5 7 10.0 ▂▅▇▅▁
吸引力得分 212 0.97 6.19 1.95 0 5 6 8 10.5 ▁▃▇▇▂
共同爱好得分 1076 0.87 5.47 2.16 0 4 6 7 10.0 ▂▅▇▆▂
幽默得分 360 0.96 6.40 1.95 0 5 7 8 11.0 ▁▂▇▇▂
真诚得分 287 0.97 7.18 1.74 0 6 7 8 10.0 ▁▁▃▇▃
雄心得分 722 0.91 6.78 1.79 0 6 7 8 10.0 ▁▂▆▇▃
智力得分 306 0.96 7.37 1.55 0 6 7 8 10.0 ▁▁▃▇▃
对方年龄 104 0.99 26.36 3.56 18 24 26 28 55.0 ▇▇▁▁▁
对方种族 73 0.99 2.76 1.23 1 2 2 4 6.0 ▇▁▃▁▁
是否同一种族 0 1.00 0.40 0.49 0 0 0 1 1.0 ▇▁▁▁▅
日常约会频率 97 0.99 5.01 1.44 1 4 5 6 7.0 ▁▂▅▃▇

4.2 删除缺失值

现实生活中,在数据分析时,经常会碰到缺失值,比如相亲数据中,有些女性不愿意暴露自己的年龄,年龄就会有缺失值。那么对于缺失值,怎么处理呢?处理方式很多,甚至有时候数据缺失本身也暗含一些信息(比如年龄缺失的女性可能是因为年龄比较大),由此引申了许多插补方法。不过这里缺失值处理并不是重点,因此对于缺失值直接删除即可

data <- data %>% drop_na()
data

4.3 删除近零方差

零方差或者近零方差的变量传递不了什么信息,因为几乎所有人的取值都一样。可以利用caret包中的nearZeroVar()函数,一行代码就能找出近零方差的变量,操作过程非常简单。

data %>% nearZeroVar() # 没有
## integer(0)

4.4 转换数据类型

对于完整的观测,首先需要定义变量的类型:属于定性变量还是连续变量。对于定性变量而言,需要给定性变量的各个水平取名,比如性别有两个水平1和0,分别命名为男、女

data %>%
  mutate(决定   = factor(决定,
                         levels = c(0, 1),
                         labels = c("拒绝", "接受"))) %>%
  mutate(性别   = factor(性别,
                         levels = c(0, 1),
                         labels = c("女", "男"))) %>%
  mutate(种族   = factor(
    种族,
    levels = c(1, 2, 3, 4, 5, 6),
    labels = c("非洲裔", "欧洲裔", "拉丁裔", "亚裔", "印第安土著", "其他")
  )) %>%
  mutate(从事领域 = factor(
    从事领域,
    levels = 1:18,
    labels = c(
      "法律",
      "数学",
      "社会科学或心理学",
      "医学或药物学或生物技术",
      "工程学",
      "写作或新闻",
      "历史或宗教或哲学",
      "商业或经济或金融",
      "教育或学术",
      "生物科学或化学或物理",
      "社会工作",
      "大学在读或未择方向",
      "政治学或国际事务",
      "电影",
      "艺术管理",
      "语言",
      "建筑学",
      "其他"
    )
  )) %>%
  mutate(对方决定  = factor(对方决定,
                            levels = 0:1,
                            labels = c("拒绝", "接收"))) %>%
  mutate(对方种族  = factor(
    对方种族,
    levels = c(1, 2, 3, 4, 5, 6),
    labels = c("非洲裔", "欧洲裔", "拉丁裔", "亚裔", "印第安土著", "其他")
  )) %>%
  mutate(是否同一种族  = factor(
    是否同一种族,
    levels = c(0, 1),
    labels = c("非同一种族", "同一种族") 
  )) -> data

data %>% map(unique)
## $决定
## [1] 接受 拒绝
## Levels: 拒绝 接受
## 
## $性别
## [1] 女 男
## Levels: 女 男
## 
## $吸引力
##  [1]  6.0  7.0  5.0  4.0  8.0  9.0  3.0 10.0  2.0  1.0  0.0  6.5  7.5  9.5  8.5
## [16]  9.9  3.5
## 
## $共同爱好
##  [1]  5.0  6.0  8.0  4.0  7.0  3.0  2.0  9.0 10.0  1.0  0.0  7.5  6.5  8.5  5.5
## 
## $幽默
##  [1]  7.0  8.0  4.0  6.0  9.0  3.0  5.0 10.0  2.0  1.0  0.0  5.5  6.5  9.5  7.5
## [16]  8.5
## 
## $真诚
##  [1]  9.0  8.0  6.0  7.0  5.0 10.0  4.0  3.0  2.0  1.0  0.0  8.5  7.5
## 
## $雄心
##  [1]  6.0  5.0  8.0 10.0  9.0  3.0  7.0  4.0  2.0  1.0  0.0  9.5  7.5  8.5
## 
## $智力
##  [1]  7.0  8.0  6.0  9.0 10.0  5.0  4.0  3.0  2.0  1.0  0.0  6.5  8.5  7.5  5.5
## 
## $好感
##  [1]  7.0  6.0  8.0  5.0  9.0  4.0 10.0  2.0  3.0  6.5  1.0  8.5  9.5  0.0  7.5
## [16]  5.5  4.5  9.7
## 
## $成功率自估
##  [1]  6.0  5.0  7.0  4.0  3.0  8.0  1.0 10.0  2.0  9.0  0.0  6.5  7.5  8.5  9.5
## [16]  5.5  3.5  4.5
## 
## $日常出门频率
## [1] 1 4 2 3 5 7 6
## 
## $对宗教的看重程度
##  [1]  4  5  1  3  2  8 10  6  7  9
## 
## $对种族的看重程度
##  [1]  2  8  1  4  7  3  9 10  5  6  0
## 
## $年龄
##  [1] 21 24 25 23 22 26 27 30 28 29 34 35 32 20 19 18 33 36 31 42 38 55
## 
## $种族
## [1] 亚裔   欧洲裔 其他   拉丁裔 非洲裔
## Levels: 非洲裔 欧洲裔 拉丁裔 亚裔 印第安土著 其他
## 
## $从事领域
##  [1] 法律                   数学                   政治学或国际事务      
##  [4] 商业或经济或金融       工程学                 教育或学术            
##  [7] 社会科学或心理学       社会工作               大学在读或未择方向    
## [10] 医学或药物学或生物技术 历史或宗教或哲学       写作或新闻            
## [13] 生物科学或化学或物理   电影                   语言                  
## [16] 艺术管理               建筑学                 其他                  
## 18 Levels: 法律 数学 社会科学或心理学 医学或药物学或生物技术 ... 其他
## 
## $对方决定
## [1] 拒绝 接收
## Levels: 拒绝 接收
## 
## $好感得分
##  [1]  7.0  8.0  2.0  6.5  6.0 10.0  9.0  4.0  5.0  3.0  1.0  9.5  7.5  4.5  8.5
## [16]  5.5  0.0  9.7
## 
## $对方评估成功率
##  [1]  4.0  7.0  6.0  1.0  5.0  8.0  2.0 10.0  3.0  9.0  4.5  6.5  5.5  0.0  7.5
## [16]  8.5  9.5  3.5
## 
## $吸引力得分
##  [1]  6.0  7.0  8.0  3.0 10.0  9.0  5.0  4.0  2.0  1.0  0.0  6.5  7.5  8.5  9.5
## [16]  9.9  3.5
## 
## $共同爱好得分
##  [1]  6.0  5.0  8.0  7.0  9.0  4.0 10.0  3.0  2.0  1.0  7.5  6.5  8.5  0.0  5.5
## 
## $幽默得分
##  [1]  8.0  7.0  6.0  5.0  9.0 10.0  3.0  4.0  2.0  1.0  5.5  6.5  9.5  0.0  7.5
## [16]  8.5 11.0
## 
## $真诚得分
##  [1]  8.0  7.0  6.0 10.0  9.0  3.0  5.0  4.0  2.0  1.0  4.5  8.5  0.0  7.5
## 
## $雄心得分
##  [1]  8.0  7.0  9.0  6.0 10.0  5.0  4.0  2.0  3.0  1.0  5.5  9.5  0.0  7.5  8.5
## 
## $智力得分
##  [1]  8.0 10.0  9.0  7.0  5.0  6.0  4.0  3.0  1.0  2.0  6.5  8.5  0.0  7.5  5.5
## 
## $对方年龄
##  [1] 27 22 23 24 25 30 28 21 26 29 32 35 34 18 20 19 33 31 36 42 38 55
## 
## $对方种族
## [1] 欧洲裔 拉丁裔 亚裔   其他   非洲裔
## Levels: 非洲裔 欧洲裔 拉丁裔 亚裔 印第安土著 其他
## 
## $是否同一种族
## [1] 非同一种族 同一种族  
## Levels: 非同一种族 同一种族
## 
## $日常约会频率
## [1] 7 5 3 4 6 1 2

4.5 删除共线性变量

caret包中的findCorrelation()函数会自动找到高度共线性的变量,并给出建议删除的变量

但需要注意,这个函数对输入的数据要求比较高

data %>% 
  select(where(is.numeric)) %>% 
  cor() -> data_cor

data_cor %>% round(1) %>% ggcorrplot::ggcorrplot(lab = TRUE,type = "lower")  + mytheme

data_high_cor <- findCorrelation(data_cor,cutoff = 0.75,names = TRUE)
data_high_cor      # 没有共线性变量
## character(0)

5 划分数据集(无标准化)

set.seed(1234)
data_id <- createDataPartition(y = data$决定,p = 0.7,list = FALSE,times = 1)
data_training <- data[data_id,]
data_testing <- data[-data_id,]
data_training %$% table(决定) %>% prop.table()
data_testing %$% table(决定) %>% prop.table()
data %$% table(决定) %>% prop.table()
## 决定
##    拒绝    接受 
## 0.56002 0.43998 
## 决定
##      拒绝      接受 
## 0.5600233 0.4399767 
## 决定
##     拒绝     接受 
## 0.560021 0.439979

6 训练模型

6.1 网格搜索结果

# model_gbm <- train(决定 ~ .,
#                      data = data_training,
#                      method = "gbm",     # 方法
#                      trControl = fit_control_gbm,
#                      verbose = FALSE,
#                      tuneGrid = grid_gbm,   # 网格搜索结果
#                      metric = "ROC")  # 指标

data_training %>% dim()
## [1] 4007   29
# save(model_gbm,file = "Machine_Learning_and_Causal_Inference/result/model_gbm.RData")
load(file = here::here("Machine_Learning_and_Causal_Inference/result/model_gbm.RData"))
model_gbm
## Stochastic Gradient Boosting 
## 
## 4007 samples
##   28 predictor
##    2 classes: '拒绝', '接受' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 3607, 3607, 3606, 3606, 3606, 3607, ... 
## Resampling results across tuning parameters:
## 
##   interaction.depth  n.trees  ROC        Sens       Spec     
##   1                    50     0.8458812  0.7803234  0.7271540
##   1                   100     0.8547180  0.7852143  0.7430566
##   1                   150     0.8585245  0.7914623  0.7481638
##   1                   200     0.8601837  0.7923552  0.7521347
##   1                   250     0.8608241  0.7896865  0.7481638
##   1                   300     0.8607350  0.7883452  0.7493002
##   1                   350     0.8603262  0.7914702  0.7470275
##   1                   400     0.8607127  0.7874544  0.7470339
##   1                   450     0.8607112  0.7910119  0.7464625
##   1                   500     0.8603422  0.7914643  0.7487256
##   1                   550     0.8599789  0.7928056  0.7510048
##   1                   600     0.8601791  0.7932500  0.7447772
##   1                   650     0.8606016  0.7932460  0.7504430
##   1                   700     0.8601213  0.7945833  0.7527125
##   1                   750     0.8597893  0.7950258  0.7481542
##   1                   800     0.8591014  0.7972560  0.7475892
##   1                   850     0.8583529  0.7941409  0.7481478
##   1                   900     0.8583318  0.7928056  0.7470178
##   1                   950     0.8578740  0.7914702  0.7453197
##   1                  1000     0.8581766  0.7896865  0.7458911
##   5                    50     0.8637739  0.8083948  0.7436120
##   5                   100     0.8679483  0.8043948  0.7493002
##   5                   150     0.8708205  0.8079444  0.7561344
##   5                   200     0.8720001  0.8101746  0.7538617
##   5                   250     0.8730226  0.8141944  0.7555598
##   5                   300     0.8738410  0.8150952  0.7532839
##   5                   350     0.8741765  0.8097480  0.7572515
##   5                   400     0.8738580  0.8084008  0.7634823
##   5                   450     0.8751164  0.8119722  0.7617777
##   5                   500     0.8747352  0.8097440  0.7589529
##   5                   550     0.8743121  0.8026091  0.7527253
##   5                   600     0.8756294  0.8070655  0.7606703
##   5                   650     0.8749445  0.8057262  0.7612160
##   5                   700     0.8753659  0.8070655  0.7612128
##   5                   750     0.8759504  0.8079524  0.7617938
##   5                   800     0.8763178  0.8088512  0.7663328
##   5                   850     0.8756155  0.8084067  0.7617970
##   5                   900     0.8762538  0.8075139  0.7652061
##   5                   950     0.8765378  0.8066230  0.7612160
##   5                  1000     0.8761628  0.8061726  0.7617842
##   9                    50     0.8665025  0.8092976  0.7436441
##   9                   100     0.8716481  0.8159901  0.7544331
##   9                   150     0.8735060  0.8146369  0.7544331
##   9                   200     0.8747870  0.8222143  0.7515858
##   9                   250     0.8768251  0.8240040  0.7538392
##   9                   300     0.8769973  0.8213294  0.7532775
##   9                   350     0.8772060  0.8182103  0.7549692
##   9                   400     0.8767967  0.8155357  0.7509951
##   9                   450     0.8778824  0.8199921  0.7527029
##   9                   500     0.8781857  0.8159861  0.7532743
##   9                   550     0.8785753  0.8231171  0.7515858
##   9                   600     0.8792980  0.8204444  0.7476181
##   9                   650     0.8795248  0.8226627  0.7521475
##   9                   700     0.8793808  0.8213254  0.7538521
##   9                   750     0.8801409  0.8248889  0.7527221
##   9                   800     0.8797453  0.8186508  0.7544171
##   9                   850     0.8799191  0.8217639  0.7544203
##   9                   900     0.8794603  0.8217659  0.7487705
##   9                   950     0.8796739  0.8231032  0.7510112
##   9                  1000     0.8797459  0.8199901  0.7487481
## 
## Tuning parameter 'shrinkage' was held constant at a value of 0.1
## 
## Tuning parameter 'n.minobsinnode' was held constant at a value of 20
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were n.trees = 750, interaction.depth =
##  9, shrinkage = 0.1 and n.minobsinnode = 20.

第五步:模型会自动确定ROC曲线面积最大(即AUC值最高)的参数组合,也就是图3中最高的点对应的参数组合,对应的AUC值为90.14%。

model_gbm %>% ggplot() + mytheme

6.2 网格搜索结果-变量重要性

library(gbm)
## Loaded gbm 2.1.8
model_gbm %>% varImp(scale = FALSE) -> variable_imp
variable_imp %>% ggplot() + mytheme

6.3 随机搜索结果

随机搜索与网格搜索相比,参数的选择没有固定的范围,最终的结果可能好也可能坏。它的实现步骤如下:

  • 第一步:设定随机种子。
  • 第二步:利用trainControl()函数设定模型训练的参数,但是多了一项:search=”random”
fit_control_gbm_random <- trainControl(method = "cv",
                                       number = 10,
                                       classProbs = TRUE,
                                       summaryFunction = twoClassSummary,
                                       search = "random")   # 随机搜索
  • 第三步:超参数在随机搜索中不受约束,没有条条框框的限制。所以无须设置tuneGrid参数,只需要设置参数tuneLength(随机搜索多少组)。
# model_gbm_random <- train(决定~.,data =data_training,
#                             method = "gbm",  # 方法
#                             trControl = fit_control_gbm_random,
#                             verbose = FALSE,
#                             metric = "ROC",  # 指标
#                             tuneLength = 30)

# save(model_gbm_random,file = here::here("Machine_Learning_and_Causal_Inference/result/model_gbm_random.RData"))
load(file = here::here("Machine_Learning_and_Causal_Inference/result/model_gbm_random.RData"))
model_gbm_random
## Stochastic Gradient Boosting 
## 
## 4007 samples
##   28 predictor
##    2 classes: '拒绝', '接受' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 3606, 3607, 3606, 3606, 3605, 3607, ... 
## Resampling results across tuning parameters:
## 
##   shrinkage    interaction.depth  n.minobsinnode  n.trees  ROC        Sens     
##   0.004834566  10                 13              4378     0.8830347  0.8204087
##   0.017773097   6                 18              2940     0.8801644  0.8088194
##   0.030012312  10                  7               939     0.8828093  0.8199623
##   0.065353448   2                 12              3466     0.8705157  0.8038948
##   0.076775003   7                 24              2421     0.8743527  0.8168313
##   0.098133025   9                 13              4482     0.8774781  0.8204087
##   0.120162734   7                  6              3408     0.8794093  0.8163889
##   0.170841822   4                 18              4080     0.8647263  0.8043631
##   0.211689456   2                  5               966     0.8655478  0.7949960
##   0.224400530   5                 21              2407     0.8651855  0.8074742
##   0.281793601   5                  8              3231     0.8693740  0.8123810
##   0.301029938   4                 11                32     0.8619964  0.8088294
##   0.311465749   4                 20              1499     0.8621497  0.8101647
##   0.340869527   2                 20              3095     0.8544660  0.7950317
##   0.341748821   9                 20              1300     0.8642645  0.8101647
##   0.349149748   4                 17              4122     0.8613487  0.8030417
##   0.367269076   3                 14              1158     0.8583283  0.8034603
##   0.380703725   1                  6              1745     0.8504483  0.7838492
##   0.384781755   5                  7               143     0.8551641  0.8003790
##   0.402804276   6                  6              4050     0.8751762  0.8164107
##   0.433879408   8                  7               220     0.8560501  0.7972421
##   0.446740546   5                 21              3244     0.8630270  0.8115119
##   0.463837192   2                 20               207     0.8489684  0.7909742
##   0.466876172   3                 25              3485     0.8535598  0.7981329
##   0.474556008  10                  6              1790     0.8714595  0.8190556
##   0.483829035   7                 22              1094     0.8482503  0.7874286
##   0.516980422   6                 12              4603     0.8675019  0.8088175
##   0.562797687   3                 18              2735     0.8544221  0.8047996
##   0.586171088  10                 18              4890     0.8635503  0.7901091
##   0.591066128   8                 16               547     0.8535618  0.8119365
##   Spec     
##   0.7617264
##   0.7634438
##   0.7594536
##   0.7571905
##   0.7639927
##   0.7611486
##   0.7640055
##   0.7532101
##   0.7537879
##   0.7452844
##   0.7509470
##   0.7430117
##   0.7481061
##   0.7526451
##   0.7509534
##   0.7430342
##   0.7412975
##   0.7373941
##   0.7356317
##   0.7571841
##   0.7390280
##   0.7458430
##   0.7316866
##   0.7418657
##   0.7464079
##   0.7407550
##   0.7452908
##   0.7362513
##   0.7600122
##   0.7288649
## 
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were n.trees = 4378, interaction.depth
##  = 10, shrinkage = 0.004834566 and n.minobsinnode = 13.

6.4 随机搜索结果-变量重要性

model_gbm_random %>% varImp(scale = FALSE) -> variable_imp
variable_imp %>% ggplot() + mytheme

7 预测模型

确定最优参数之后,模型如何进行预测呢?使用predict()函数,只要输入模型及测试集,就可以预测了。然后利用confusionMatrix()函数输入真实的Y与预测的Y就可以得到混淆矩阵(Confusion Matrix)

网格搜索的参数与随机搜索的参数的预测结果有什么区别呢?下面的操作结果可以明显看出两者的区别。

7.1 网格搜索预测结果

model_gbm_pre <- predict(model_gbm,newdata = data_testing)
confusionMatrix(model_gbm_pre,data_testing$决定)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction 拒绝 接受
##       拒绝  807  172
##       接受  154  583
##                                           
##                Accuracy : 0.81            
##                  95% CI : (0.7906, 0.8283)
##     No Information Rate : 0.56            
##     P-Value [Acc > NIR] : <2e-16          
##                                           
##                   Kappa : 0.6135          
##                                           
##  Mcnemar's Test P-Value : 0.3464          
##                                           
##             Sensitivity : 0.8398          
##             Specificity : 0.7722          
##          Pos Pred Value : 0.8243          
##          Neg Pred Value : 0.7910          
##              Prevalence : 0.5600          
##          Detection Rate : 0.4703          
##    Detection Prevalence : 0.5705          
##       Balanced Accuracy : 0.8060          
##                                           
##        'Positive' Class : 拒绝            
## 

7.2 随机搜索预测结果

model_gbm_random_pre <- predict(model_gbm_random,newdata = data_testing)
confusionMatrix(model_gbm_random_pre,data_testing$决定)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction 拒绝 接受
##       拒绝  814  171
##       接受  147  584
##                                           
##                Accuracy : 0.8147          
##                  95% CI : (0.7955, 0.8328)
##     No Information Rate : 0.56            
##     P-Value [Acc > NIR] : <2e-16          
##                                           
##                   Kappa : 0.6227          
##                                           
##  Mcnemar's Test P-Value : 0.1971          
##                                           
##             Sensitivity : 0.8470          
##             Specificity : 0.7735          
##          Pos Pred Value : 0.8264          
##          Neg Pred Value : 0.7989          
##              Prevalence : 0.5600          
##          Detection Rate : 0.4744          
##    Detection Prevalence : 0.5740          
##       Balanced Accuracy : 0.8103          
##                                           
##        'Positive' Class : 拒绝            
## 

8 划分数据集(进行标准化)

标准化处理是指将数据处理为均值为0、标准差为1的数据。那么为什么要进行标准化处理呢?因为在进行实证分析时,有些变量取值很大,有些变量取值很小,这里需要营造一个公平公正的环境,权重的大小不能被自身变量取值的大小所束缚。比如在判断一个女生是否是美女时,会考虑腿长、脸长、脸宽、腰围等因素,这些因素的学名为特征。显然腿长的取值比脸长的取值大得多,这时为了防止腿长的权重过高,就需要将这些特征进行标准化才能学习各个变量真实的权重。

标准化处理时,只能利用训练集的均值与标准差对训练集和测试集进行标准化

pre_process_value <- preProcess(data_training,
                                method = c("center","scale"))

data_training_std <- predict(pre_process_value,
                             data_training)

# 利用训练集的均值和标准差对测试集进行标准化(重要)
data_testing_std <- predict(pre_process_value,
                            data_testing)

9 训练模型

9.1 网格搜索结果

# model_gbm_std <- train(决定 ~ .,
#                      data = data_training_std,   # 数据已经标准化
#                      method = "gbm",   # 方法
#                      trControl = fit_control_gbm,   
#                      verbose = FALSE,
#                      tuneGrid = grid_gbm,   # 网格搜索
#                      metric = "ROC")  # 指标

data_training_std %>% dim()
## [1] 4007   29
# save(model_gbm_std,file = here::here("Machine_Learning_and_Causal_Inference/result/model_gbm_std.RData"))
load(file = here::here("Machine_Learning_and_Causal_Inference/result/model_gbm_std.RData"))
model_gbm_std
## Stochastic Gradient Boosting 
## 
## 4007 samples
##   28 predictor
##    2 classes: '拒绝', '接受' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 3607, 3607, 3606, 3606, 3606, 3607, ... 
## Resampling results across tuning parameters:
## 
##   interaction.depth  n.trees  ROC        Sens       Spec     
##   1                    50     0.8458685  0.7803234  0.7271540
##   1                   100     0.8547154  0.7852143  0.7430566
##   1                   150     0.8584942  0.7910159  0.7481638
##   1                   200     0.8600550  0.7923552  0.7527029
##   1                   250     0.8609050  0.7901310  0.7464625
##   1                   300     0.8606871  0.7896825  0.7498684
##   1                   350     0.8602378  0.7914702  0.7470243
##   1                   400     0.8605659  0.7892361  0.7475989
##   1                   450     0.8604984  0.7927976  0.7464625
##   1                   500     0.8603470  0.7923571  0.7498523
##   1                   550     0.8598781  0.7941409  0.7510015
##   1                   600     0.8600794  0.7941349  0.7493066
##   1                   650     0.8607604  0.7950258  0.7527061
##   1                   700     0.8596862  0.7927996  0.7515761
##   1                   750     0.8595700  0.7950238  0.7470307
##   1                   800     0.8590555  0.7941329  0.7453294
##   1                   850     0.8583105  0.7927956  0.7504269
##   1                   900     0.8584066  0.7932440  0.7487256
##   1                   950     0.8577257  0.7914683  0.7453197
##   1                  1000     0.8580634  0.7914742  0.7470339
##   5                    50     0.8638778  0.8088413  0.7419074
##   5                   100     0.8683209  0.8035099  0.7487352
##   5                   150     0.8709084  0.8119563  0.7601053
##   5                   200     0.8720089  0.8150794  0.7549884
##   5                   250     0.8732528  0.8137440  0.7566801
##   5                   300     0.8737551  0.8177659  0.7600764
##   5                   350     0.8743140  0.8173254  0.7583686
##   5                   400     0.8746825  0.8124206  0.7572355
##   5                   450     0.8758394  0.8173234  0.7532743
##   5                   500     0.8761987  0.8164246  0.7561216
##   5                   550     0.8762954  0.8177659  0.7533032
##   5                   600     0.8771053  0.8168651  0.7601149
##   5                   650     0.8766423  0.8173115  0.7652189
##   5                   700     0.8768278  0.8186488  0.7623844
##   5                   750     0.8771631  0.8195437  0.7606574
##   5                   800     0.8774133  0.8155417  0.7618034
##   5                   850     0.8772425  0.8137440  0.7657646
##   5                   900     0.8779780  0.8195337  0.7663232
##   5                   950     0.8774379  0.8150754  0.7685927
##   5                  1000     0.8781535  0.8195437  0.7680277
##   9                    50     0.8667744  0.8106349  0.7425045
##   9                   100     0.8714460  0.8159782  0.7510304
##   9                   150     0.8730945  0.8141984  0.7533064
##   9                   200     0.8743664  0.8231052  0.7510208
##   9                   250     0.8756663  0.8239921  0.7566834
##   9                   300     0.8756694  0.8195317  0.7561120
##   9                   350     0.8749122  0.8173135  0.7538424
##   9                   400     0.8741921  0.8164206  0.7583686
##   9                   450     0.8743759  0.8155357  0.7623299
##   9                   500     0.8747896  0.8150774  0.7606382
##   9                   550     0.8756587  0.8186548  0.7629109
##   9                   600     0.8761933  0.8190992  0.7578165
##   9                   650     0.8760428  0.8204286  0.7566705
##   9                   700     0.8762266  0.8204286  0.7515761
##   9                   750     0.8766359  0.8213294  0.7549820
##   9                   800     0.8765712  0.8266627  0.7527125
##   9                   850     0.8764037  0.8235476  0.7566801
##   9                   900     0.8766565  0.8253313  0.7538424
##   9                   950     0.8765870  0.8217639  0.7510080
##   9                  1000     0.8765769  0.8204325  0.7510144
## 
## Tuning parameter 'shrinkage' was held constant at a value of 0.1
## 
## Tuning parameter 'n.minobsinnode' was held constant at a value of 20
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were n.trees = 1000, interaction.depth =
##  5, shrinkage = 0.1 and n.minobsinnode = 20.

9.2 网格搜索结果-变量重要性

model_gbm_std %>% varImp(scale = FALSE) -> variable_imp
variable_imp %>% ggplot() + mytheme

9.3 随机搜索结果

fit_control_gbm_random <- trainControl(method = "cv",
                                       number = 10,
                                       classProbs = TRUE,
                                       summaryFunction = twoClassSummary,
                                       search = "random")  # 随机搜索

# model_gbm_random_std <- train(
#   决定 ~ .,
#   data = data_training_std,  # 数据已经标准化
#   method = "gbm",
#   trControl = fit_control_gbm_random,
#   verbose = FALSE,
#   metric = "ROC",
#   tuneLength = 30    # 随机搜索
# )

# save(model_gbm_random_std,file = here::here("Machine_Learning_and_Causal_Inference/result/model_gbm_random_std.RData"))
load(file = here::here("Machine_Learning_and_Causal_Inference/result/model_gbm_random_std.RData"))
model_gbm_random_std
## Stochastic Gradient Boosting 
## 
## 4007 samples
##   28 predictor
##    2 classes: '拒绝', '接受' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 3607, 3606, 3607, 3607, 3606, 3606, ... 
## Resampling results across tuning parameters:
## 
##   shrinkage   interaction.depth  n.minobsinnode  n.trees  ROC        Sens     
##   0.05639147   6                 10               213     0.8730341  0.8177183
##   0.06563599   9                 17              4530     0.8779457  0.8128056
##   0.08328927   5                  5              3468     0.8797575  0.8230575
##   0.09511411   3                  8              4758     0.8762998  0.8190456
##   0.12019083   2                 11               725     0.8703719  0.8105893
##   0.13293341   6                 23              2669     0.8731750  0.8203889
##   0.14990818   3                 15              3684     0.8682060  0.8088016
##   0.15548045  10                 22              2601     0.8710143  0.8154881
##   0.16083543   5                 17              3884     0.8707270  0.8101389
##   0.19761656   3                 10              1590     0.8653865  0.8074762
##   0.19882499   7                 24              1869     0.8691928  0.8029881
##   0.23788788   5                 24              1733     0.8639957  0.8101548
##   0.25513100   3                 18               251     0.8588094  0.8070159
##   0.32942565   4                  9              3004     0.8660796  0.8016825
##   0.33407670  10                 16              3788     0.8720841  0.8074524
##   0.35872628   9                 24              2169     0.8684708  0.8114603
##   0.38207893  10                 23              4830     0.8708166  0.8141488
##   0.39522148  10                 21              4059     0.8665400  0.8016488
##   0.40394650   4                  5              1826     0.8627525  0.8047718
##   0.43589702   5                 25              3366     0.8602892  0.7972024
##   0.45618750   3                 11              2956     0.8581971  0.7981091
##   0.47175568  10                 19               658     0.8573075  0.8061230
##   0.48685329   4                 19              1259     0.8468197  0.7931905
##   0.50218047   1                 19              1828     0.8463778  0.7771647
##   0.53845428   4                 17              3248     0.8562245  0.7976726
##   0.54697262   2                 14              1908     0.8480733  0.7994286
##   0.54935125   2                  7               113     0.8526083  0.7891845
##   0.56671028   9                 24              4737     0.8639321  0.8070099
##   0.58348927   5                 21              1411     0.8518223  0.7959067
##   0.58480647   3                  5              2246     0.8510294  0.7940992
##   Spec     
##   0.7464529
##   0.7600732
##   0.7521090
##   0.7578197
##   0.7549628
##   0.7492809
##   0.7532646
##   0.7436312
##   0.7521058
##   0.7521475
##   0.7549275
##   0.7379622
##   0.7384919
##   0.7555277
##   0.7538457
##   0.7509791
##   0.7504301
##   0.7419395
##   0.7487288
##   0.7408064
##   0.7419138
##   0.7373652
##   0.7254655
##   0.7294171
##   0.7356895
##   0.7317315
##   0.7345243
##   0.7396540
##   0.7453101
##   0.7379526
## 
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were n.trees = 3468, interaction.depth =
##  5, shrinkage = 0.08328927 and n.minobsinnode = 5.

9.4 随机搜索结果-变量重要性

model_gbm_random_std %>% varImp(scale = FALSE) -> variable_imp
variable_imp %>% ggplot() + mytheme

10 预测模型

10.1 网格搜索预测结果

model_gbm_std_pre <- predict(model_gbm_std,newdata = data_testing_std)
confusionMatrix(model_gbm_std_pre,data_testing_std$决定)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction 拒绝 接受
##       拒绝  799  168
##       接受  162  587
##                                           
##                Accuracy : 0.8077          
##                  95% CI : (0.7882, 0.8261)
##     No Information Rate : 0.56            
##     P-Value [Acc > NIR] : <2e-16          
##                                           
##                   Kappa : 0.6094          
##                                           
##  Mcnemar's Test P-Value : 0.7831          
##                                           
##             Sensitivity : 0.8314          
##             Specificity : 0.7775          
##          Pos Pred Value : 0.8263          
##          Neg Pred Value : 0.7837          
##              Prevalence : 0.5600          
##          Detection Rate : 0.4656          
##    Detection Prevalence : 0.5635          
##       Balanced Accuracy : 0.8045          
##                                           
##        'Positive' Class : 拒绝            
## 

10.2 随机搜索预测结果

model_gbm_random_std_pre <- predict(model_gbm_random_std,
                                    newdata = data_testing_std)
confusionMatrix(model_gbm_random_std_pre,data_testing_std$决定)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction 拒绝 接受
##       拒绝  819  168
##       接受  142  587
##                                           
##                Accuracy : 0.8193          
##                  95% CI : (0.8003, 0.8373)
##     No Information Rate : 0.56            
##     P-Value [Acc > NIR] : <2e-16          
##                                           
##                   Kappa : 0.6321          
##                                           
##  Mcnemar's Test P-Value : 0.1556          
##                                           
##             Sensitivity : 0.8522          
##             Specificity : 0.7775          
##          Pos Pred Value : 0.8298          
##          Neg Pred Value : 0.8052          
##              Prevalence : 0.5600          
##          Detection Rate : 0.4773          
##    Detection Prevalence : 0.5752          
##       Balanced Accuracy : 0.8149          
##                                           
##        'Positive' Class : 拒绝            
## 

11 比较结果

result_df_compare <- tibble(
  gbm = confusionMatrix(model_gbm_pre, data_testing$决定)[[3]][[1]],
  gbm_std = confusionMatrix(model_gbm_std_pre, data_testing$决定)[[3]][[1]],  # 利用混淆矩阵评估模型
  gbm_random_std = confusionMatrix(model_gbm_random_std_pre, data_testing$决定)[[3]][[1]],  # 利用混淆矩阵评估模型
  gbm_random = confusionMatrix(model_gbm_random_pre, data_testing$决定)[[3]][[1]], 
)
result_df_compare %>% t() %>% as.data.frame() %>% rownames_to_column(var = "model") %>% 
  rename(value = V1) %>% 
  arrange(value)

12 思考(标准化好坏)

  • 为什么需要进行标准化?
  • 什么时候需要进行标准化?
  • 为什么命名model_gbm_random_std_pre等?model,gbm,random,std,pre分别的含义,如此命名的好处,如果加入特征选择又该如何命名?model_gbm,model_gbm_random,model_gbm_std,model_gbm_random_std.