Set a unified plotting style for ggplot2

library(ggplot2)
theme_set(theme_bw(base_family = "STKaiti"))

9.1: Decision Tree Models

A decision tree example

library(readr)
library(VIM)
## Loading required package: colorspace
## Loading required package: grid
## Loading required package: data.table
## VIM is ready to use. 
##  Since version 4.0.0 the GUI is in its own package VIMGUI.
## 
##           Please use the package to use the new (and old) GUI.
## Suggestions and bug-reports can be submitted at: https://github.com/alexkowa/VIM/issues
## 
## Attaching package: 'VIM'
## The following object is masked from 'package:datasets':
## 
##     sleep
library(caret)
## Loading required package: lattice
library(rpart)
library(rpart.plot)
library(Metrics)
## 
## Attaching package: 'Metrics'
## The following objects are masked from 'package:caret':
## 
##     precision, recall
library(ROCR)
## Loading required package: gplots
## 
## Attaching package: 'gplots'
## The following object is masked from 'package:stats':
## 
##     lowess
library(readxl)
library(stringr)
library(ggplot2)

## Use a decision tree to analyze the Titanic dataset
## Read the training and test sets

Ttrain <- read_csv("data/chap9/Titanic train.csv")
## Parsed with column specification:
## cols(
##   PassengerId = col_double(),
##   Survived = col_double(),
##   Pclass = col_double(),
##   Name = col_character(),
##   Sex = col_character(),
##   Age = col_double(),
##   SibSp = col_double(),
##   Parch = col_double(),
##   Ticket = col_character(),
##   Fare = col_double(),
##   Cabin = col_character(),
##   Embarked = col_character()
## )
Ttest <- read_csv("data/chap9/Titanic test.csv")
## Parsed with column specification:
## cols(
##   PassengerId = col_double(),
##   Pclass = col_double(),
##   Name = col_character(),
##   Sex = col_character(),
##   Age = col_double(),
##   SibSp = col_double(),
##   Parch = col_double(),
##   Ticket = col_character(),
##   Fare = col_double(),
##   Cabin = col_character(),
##   Embarked = col_character()
## )
dim(Ttrain)
## [1] 891  12
colnames(Ttrain)
##  [1] "PassengerId" "Survived"    "Pclass"      "Name"        "Sex"        
##  [6] "Age"         "SibSp"       "Parch"       "Ticket"      "Fare"       
## [11] "Cabin"       "Embarked"
dim(Ttest)
## [1] 418  11
## Combine the data
Alldata <- rbind.data.frame(Ttrain[,-2],Ttest)
summary(Ttrain)
##   PassengerId       Survived          Pclass          Name          
##  Min.   :  1.0   Min.   :0.0000   Min.   :1.000   Length:891        
##  1st Qu.:223.5   1st Qu.:0.0000   1st Qu.:2.000   Class :character  
##  Median :446.0   Median :0.0000   Median :3.000   Mode  :character  
##  Mean   :446.0   Mean   :0.3838   Mean   :2.309                     
##  3rd Qu.:668.5   3rd Qu.:1.0000   3rd Qu.:3.000                     
##  Max.   :891.0   Max.   :1.0000   Max.   :3.000                     
##                                                                     
##      Sex                 Age            SibSp           Parch       
##  Length:891         Min.   : 0.42   Min.   :0.000   Min.   :0.0000  
##  Class :character   1st Qu.:20.12   1st Qu.:0.000   1st Qu.:0.0000  
##  Mode  :character   Median :28.00   Median :0.000   Median :0.0000  
##                     Mean   :29.70   Mean   :0.523   Mean   :0.3816  
##                     3rd Qu.:38.00   3rd Qu.:1.000   3rd Qu.:0.0000  
##                     Max.   :80.00   Max.   :8.000   Max.   :6.0000  
##                     NA's   :177                                     
##     Ticket               Fare           Cabin             Embarked        
##  Length:891         Min.   :  0.00   Length:891         Length:891        
##  Class :character   1st Qu.:  7.91   Class :character   Class :character  
##  Mode  :character   Median : 14.45   Mode  :character   Mode  :character  
##                     Mean   : 32.20                                        
##                     3rd Qu.: 31.00                                        
##                     Max.   :512.33                                        
## 
summary(Ttest)
##   PassengerId         Pclass          Name               Sex           
##  Min.   : 892.0   Min.   :1.000   Length:418         Length:418        
##  1st Qu.: 996.2   1st Qu.:1.000   Class :character   Class :character  
##  Median :1100.5   Median :3.000   Mode  :character   Mode  :character  
##  Mean   :1100.5   Mean   :2.266                                        
##  3rd Qu.:1204.8   3rd Qu.:3.000                                        
##  Max.   :1309.0   Max.   :3.000                                        
##                                                                        
##       Age            SibSp            Parch           Ticket         
##  Min.   : 0.17   Min.   :0.0000   Min.   :0.0000   Length:418        
##  1st Qu.:21.00   1st Qu.:0.0000   1st Qu.:0.0000   Class :character  
##  Median :27.00   Median :0.0000   Median :0.0000   Mode  :character  
##  Mean   :30.27   Mean   :0.4474   Mean   :0.3923                     
##  3rd Qu.:39.00   3rd Qu.:1.0000   3rd Qu.:0.0000                     
##  Max.   :76.00   Max.   :8.0000   Max.   :9.0000                     
##  NA's   :86                                                          
##       Fare            Cabin             Embarked        
##  Min.   :  0.000   Length:418         Length:418        
##  1st Qu.:  7.896   Class :character   Class :character  
##  Median : 14.454   Mode  :character   Mode  :character  
##  Mean   : 35.627                                        
##  3rd Qu.: 31.500                                        
##  Max.   :512.329                                        
##  NA's   :1
summary(Alldata)
##   PassengerId       Pclass          Name               Sex           
##  Min.   :   1   Min.   :1.000   Length:1309        Length:1309       
##  1st Qu.: 328   1st Qu.:2.000   Class :character   Class :character  
##  Median : 655   Median :3.000   Mode  :character   Mode  :character  
##  Mean   : 655   Mean   :2.295                                        
##  3rd Qu.: 982   3rd Qu.:3.000                                        
##  Max.   :1309   Max.   :3.000                                        
##                                                                      
##       Age            SibSp            Parch          Ticket         
##  Min.   : 0.17   Min.   :0.0000   Min.   :0.000   Length:1309       
##  1st Qu.:21.00   1st Qu.:0.0000   1st Qu.:0.000   Class :character  
##  Median :28.00   Median :0.0000   Median :0.000   Mode  :character  
##  Mean   :29.88   Mean   :0.4989   Mean   :0.385                     
##  3rd Qu.:39.00   3rd Qu.:1.0000   3rd Qu.:0.000                     
##  Max.   :80.00   Max.   :8.0000   Max.   :9.000                     
##  NA's   :263                                                        
##       Fare            Cabin             Embarked        
##  Min.   :  0.000   Length:1309        Length:1309       
##  1st Qu.:  7.896   Class :character   Class :character  
##  Median : 14.454   Mode  :character   Mode  :character  
##  Mean   : 33.295                                        
##  3rd Qu.: 31.275                                        
##  Max.   :512.329                                        
##  NA's   :1
Survived <- Ttrain$Survived
table(Survived)
## Survived
##   0   1 
## 549 342
## Data exploration, visualization, and feature transformation
## Examine the missing values and handle them
aggr(Alldata)
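The aggr() plot shows the missingness pattern graphically; the same counts can be checked numerically with one line (an equivalent quick check):

colSums(is.na(Alldata))  # number of NAs per column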

colnames(Alldata)
##  [1] "PassengerId" "Pclass"      "Name"        "Sex"         "Age"        
##  [6] "SibSp"       "Parch"       "Ticket"      "Fare"        "Cabin"      
## [11] "Embarked"
# Apply identical operations to the training and test sets
## Cabin has too many missing values, so drop it outright
Alldata$Cabin <- NULL
## Ticket and PassengerId are identifiers, so drop them as well
Alldata$PassengerId <- NULL
Alldata$Ticket <- NULL

## Impute missing Age values with the median
Alldata$Age[is.na(Alldata$Age)] <- median(Alldata$Age,na.rm = TRUE)
## Impute the missing Fare value with the mean
Alldata$Fare[is.na(Alldata$Fare)] <- mean(Alldata$Fare,na.rm = TRUE)
## Impute missing Embarked values with the mode
Embarkedmod <- names(sort(table(Alldata$Embarked),decreasing = T)[1])
Alldata$Embarked[is.na(Alldata$Embarked)] <- Embarkedmod
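The mode lookup above is terse: table() tabulates the values, sort(..., decreasing = TRUE) puts the most frequent first, and names(...)[1] returns its label. On a toy vector (hypothetical, not from the dataset):

v <- c("S", "C", "S", "Q", "S")
sort(table(v), decreasing = TRUE)            # S appears three times, C and Q once
names(sort(table(v), decreasing = TRUE))[1]  # "S", the mode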

## Engineer a new feature: extract the title from the Name variable
newname <- str_split(Alldata$Name," ")
newname <- sapply(newname, function(x) x[2])

sort(table(newname),decreasing = T)
## newname
##          Mr.        Miss.         Mrs.      Master.          Dr.         Rev. 
##          736          256          191           59            8            8 
##            y         Col.      Planke,    Billiard,        Impe,       Carlo, 
##            8            4            4            3            3            2 
##      Gordon,       Major. Messemaeker,        Mlle.          Ms.       Brito, 
##            2            2            2            2            2            1 
##        Capt.    Cruyssen,          der         Don.    Jonkheer.      Khalil, 
##            1            1            1            1            1            1 
##   Melkebeke,         Mme.      Mulder,   Palmquist,  Pelsmaeker,      Shawah, 
##            1            1            1            1            1            1 
##       Steen,          the       Velde,       Walle, 
##            1            1            1            1
## Keep the titles Mr., Miss., Mrs., and Master.; replace all others with "other"
newnamepart <- c("Mr.","Miss.","Mrs.","Master.")
newname[!(newname %in% newnamepart)] <- "other"
Alldata$Name <- as.factor(newname)
Alldata$Sex <- as.factor(Alldata$Sex)
Alldata$Embarked <- as.factor(Alldata$Embarked)
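Splitting on spaces assumes the title is always the second word; the stray entries in the table above ("y", "Planke,", "Billiard," and so on) come from compound surnames where it is not. A regex that keys on the trailing period is more robust (a sketch only; title2 is a hypothetical alternative that is not used below):

title2 <- str_extract(c(Ttrain$Name, Ttest$Name), "\\w+\\.")  # first token ending in "."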
summary(Alldata)
##      Pclass           Name         Sex           Age            SibSp       
##  Min.   :1.000   Master.: 59   female:466   Min.   : 0.17   Min.   :0.0000  
##  1st Qu.:2.000   Miss.  :256   male  :843   1st Qu.:22.00   1st Qu.:0.0000  
##  Median :3.000   Mr.    :736                Median :28.00   Median :0.0000  
##  Mean   :2.295   Mrs.   :191                Mean   :29.50   Mean   :0.4989  
##  3rd Qu.:3.000   other  : 67                3rd Qu.:35.00   3rd Qu.:1.0000  
##  Max.   :3.000                              Max.   :80.00   Max.   :8.0000  
##      Parch            Fare         Embarked
##  Min.   :0.000   Min.   :  0.000   C:270   
##  1st Qu.:0.000   1st Qu.:  7.896   Q:123   
##  Median :0.000   Median : 14.454   S:916   
##  Mean   :0.385   Mean   : 33.295           
##  3rd Qu.:0.000   3rd Qu.: 31.275           
##  Max.   :9.000   Max.   :512.329
str(Alldata)
## Classes 'tbl_df', 'tbl' and 'data.frame':    1309 obs. of  8 variables:
##  $ Pclass  : num  3 1 3 1 3 3 1 3 3 2 ...
##  $ Name    : Factor w/ 5 levels "Master.","Miss.",..: 3 4 2 4 3 3 3 1 4 4 ...
##  $ Sex     : Factor w/ 2 levels "female","male": 2 1 1 1 2 2 2 2 1 1 ...
##  $ Age     : num  22 38 26 35 35 28 54 2 27 14 ...
##  $ SibSp   : num  1 1 0 1 0 0 0 3 0 1 ...
##  $ Parch   : num  0 0 0 0 0 0 0 1 2 0 ...
##  $ Fare    : num  7.25 71.28 7.92 53.1 8.05 ...
##  $ Embarked: Factor w/ 3 levels "C","Q","S": 3 1 3 3 3 2 3 3 3 1 ...
## 'SibSp' counts siblings/spouses aboard and 'Parch' counts parents/children aboard; they could be summed into a single family-size feature (left commented out below)
#Alldata$SP <- Alldata$SibSp + Alldata$Parch
#Alldata$SibSp <- NULL
#Alldata$Parch <- NULL
#Alldata$SP <- cut_width(Alldata$SP,3)

#Alldata$Age <- cut_number(Alldata$Age,6)
#Alldata$Fare <- cut_number(Alldata$Fare,5)
#Alldata$Pclass <- as.factor(Alldata$Pclass)

#table(Alldata$Pclass)
## Split the processed data back into training and test sets
Ttrainp <- Alldata[1:nrow(Ttrain),]
Ttrainp$Survived <- Survived
Ttestp <- Alldata[(nrow(Ttrain)+1):nrow(Alldata),]
str(Ttrainp)
## Classes 'tbl_df', 'tbl' and 'data.frame':    891 obs. of  9 variables:
##  $ Pclass  : num  3 1 3 1 3 3 1 3 3 2 ...
##  $ Name    : Factor w/ 5 levels "Master.","Miss.",..: 3 4 2 4 3 3 3 1 4 4 ...
##  $ Sex     : Factor w/ 2 levels "female","male": 2 1 1 1 2 2 2 2 1 1 ...
##  $ Age     : num  22 38 26 35 35 28 54 2 27 14 ...
##  $ SibSp   : num  1 1 0 1 0 0 0 3 0 1 ...
##  $ Parch   : num  0 0 0 0 0 0 0 1 2 0 ...
##  $ Fare    : num  7.25 71.28 7.92 53.1 8.05 ...
##  $ Embarked: Factor w/ 3 levels "C","Q","S": 3 1 3 3 3 2 3 3 3 1 ...
##  $ Survived: num  0 1 1 1 0 0 0 0 1 1 ...
## Save the processed training data to a file

write.csv(Ttrainp,"data/chap9/Titanic处理后数据.csv",row.names = F)

Splitting the data and building the model

## Split the training data into a training set and a test set, 80% for training
set.seed(123)
CDP <- createDataPartition(Ttrainp$Survived,p = 0.8)
train_data <- Ttrainp[CDP$Resample1,]
test_data <- Ttrainp[-CDP$Resample1,]

mod1 <- rpart(Survived~.,data = train_data,method="class",cp = 0.000001)
summary(mod1)
## Call:
## rpart(formula = Survived ~ ., data = train_data, method = "class", 
##     cp = 1e-06)
##   n= 713 
## 
##            CP nsplit rel error    xerror       xstd
## 1 0.450549451      0 1.0000000 1.0000000 0.04754450
## 2 0.051282051      1 0.5494505 0.5604396 0.04015432
## 3 0.014652015      3 0.4468864 0.4615385 0.03730757
## 4 0.010989011      4 0.4322344 0.4798535 0.03787823
## 5 0.007326007      6 0.4102564 0.5091575 0.03874859
## 6 0.003663004     12 0.3589744 0.4835165 0.03798985
## 7 0.000001000     15 0.3479853 0.4798535 0.03787823
## 
## Variable importance
##     Name      Sex     Fare      Age    Parch   Pclass    SibSp Embarked 
##       26       21       15       11        9        9        7        2 
## 
## Node number 1: 713 observations,    complexity param=0.4505495
##   predicted class=0  expected loss=0.3828892  P(node) =1
##     class counts:   440   273
##    probabilities: 0.617 0.383 
##   left son=2 (444 obs) right son=3 (269 obs)
##   Primary splits:
##       Name     splits as  RRLRL,        improve=103.27050, (0 missing)
##       Sex      splits as  RL,           improve=101.59730, (0 missing)
##       Pclass   < 2.5      to the right, improve= 31.70004, (0 missing)
##       Fare     < 50.9875  to the left,  improve= 26.05713, (0 missing)
##       Embarked splits as  RLL,          improve= 12.38503, (0 missing)
##   Surrogate splits:
##       Sex   splits as  RL,           agree=0.938, adj=0.836, (0 split)
##       Parch < 0.5      to the left,  agree=0.735, adj=0.297, (0 split)
##       Age   < 15.5     to the right, agree=0.701, adj=0.208, (0 split)
##       SibSp < 0.5      to the left,  agree=0.673, adj=0.134, (0 split)
##       Fare  < 77.6229  to the left,  agree=0.651, adj=0.074, (0 split)
## 
## Node number 2: 444 observations,    complexity param=0.007326007
##   predicted class=0  expected loss=0.1734234  P(node) =0.6227209
##     class counts:   367    77
##    probabilities: 0.827 0.173 
##   left son=4 (339 obs) right son=5 (105 obs)
##   Primary splits:
##       Pclass   < 1.5      to the right, improve=11.845680, (0 missing)
##       Fare     < 26.26875 to the left,  improve=10.478310, (0 missing)
##       Sex      splits as  RL,           improve= 7.211408, (0 missing)
##       Embarked splits as  RLL,          improve= 6.085100, (0 missing)
##       Name     splits as  --L-R,        improve= 2.929979, (0 missing)
##   Surrogate splits:
##       Fare     < 26.26875 to the left,  agree=0.928, adj=0.695, (0 split)
##       Age      < 44.5     to the left,  agree=0.795, adj=0.133, (0 split)
##       Embarked splits as  RLL,          agree=0.768, adj=0.019, (0 split)
## 
## Node number 3: 269 observations,    complexity param=0.05128205
##   predicted class=1  expected loss=0.2713755  P(node) =0.3772791
##     class counts:    73   196
##    probabilities: 0.271 0.729 
##   left son=6 (133 obs) right son=7 (136 obs)
##   Primary splits:
##       Pclass < 2.5      to the right, improve=24.854190, (0 missing)
##       SibSp  < 2.5      to the right, improve=20.063110, (0 missing)
##       Fare   < 48.2     to the left,  improve= 7.790478, (0 missing)
##       Parch  < 3.5      to the right, improve= 4.932072, (0 missing)
##       Age    < 12       to the left,  improve= 3.221960, (0 missing)
##   Surrogate splits:
##       Fare     < 25.73335 to the left,  agree=0.747, adj=0.489, (0 split)
##       Age      < 28.5     to the left,  agree=0.699, adj=0.391, (0 split)
##       Name     splits as  LL-R-,        agree=0.625, adj=0.241, (0 split)
##       Embarked splits as  RLR,          agree=0.606, adj=0.203, (0 split)
##       SibSp    < 1.5      to the right, agree=0.602, adj=0.195, (0 split)
## 
## Node number 4: 339 observations
##   predicted class=0  expected loss=0.1091445  P(node) =0.4754558
##     class counts:   302    37
##    probabilities: 0.891 0.109 
## 
## Node number 5: 105 observations,    complexity param=0.007326007
##   predicted class=0  expected loss=0.3809524  P(node) =0.1472651
##     class counts:    65    40
##    probabilities: 0.619 0.381 
##   left son=10 (18 obs) right son=11 (87 obs)
##   Primary splits:
##       Age      < 53       to the right, improve=3.1636560, (0 missing)
##       Embarked splits as  RLL,          improve=2.3172550, (0 missing)
##       Fare     < 25.9271  to the left,  improve=2.1768710, (0 missing)
##       Name     splits as  --L-R,        improve=0.9803579, (0 missing)
##       Parch    < 0.5      to the right, improve=0.4625068, (0 missing)
## 
## Node number 6: 133 observations,    complexity param=0.05128205
##   predicted class=1  expected loss=0.4887218  P(node) =0.1865358
##     class counts:    65    68
##    probabilities: 0.489 0.511 
##   left son=12 (34 obs) right son=13 (99 obs)
##   Primary splits:
##       Fare     < 24.80835 to the right, improve=16.349110, (0 missing)
##       SibSp    < 2.5      to the right, improve=11.585210, (0 missing)
##       Embarked splits as  RRL,          improve= 5.596270, (0 missing)
##       Age      < 38.5     to the right, improve= 3.658035, (0 missing)
##       Parch    < 1.5      to the right, improve= 3.583283, (0 missing)
##   Surrogate splits:
##       SibSp < 2.5      to the right, agree=0.910, adj=0.647, (0 split)
##       Parch < 1.5      to the right, agree=0.842, adj=0.382, (0 split)
##       Name  splits as  LR-R-,        agree=0.774, adj=0.118, (0 split)
##       Sex   splits as  RL,           agree=0.774, adj=0.118, (0 split)
##       Age   < 37.5     to the right, agree=0.767, adj=0.088, (0 split)
## 
## Node number 7: 136 observations
##   predicted class=1  expected loss=0.05882353  P(node) =0.1907433
##     class counts:     8   128
##    probabilities: 0.059 0.941 
## 
## Node number 10: 18 observations
##   predicted class=0  expected loss=0.1111111  P(node) =0.02524544
##     class counts:    16     2
##    probabilities: 0.889 0.111 
## 
## Node number 11: 87 observations,    complexity param=0.007326007
##   predicted class=0  expected loss=0.4367816  P(node) =0.1220196
##     class counts:    49    38
##    probabilities: 0.563 0.437 
##   left son=22 (7 obs) right son=23 (80 obs)
##   Primary splits:
##       Fare     < 25.9271  to the left,  improve=2.9045980, (0 missing)
##       Embarked splits as  RLL,          improve=1.6625330, (0 missing)
##       Age      < 27.5     to the right, improve=1.3891050, (0 missing)
##       Name     splits as  --L-R,        improve=0.9657088, (0 missing)
##       Parch    < 0.5      to the right, improve=0.2979310, (0 missing)
## 
## Node number 12: 34 observations
##   predicted class=0  expected loss=0.08823529  P(node) =0.04768583
##     class counts:    31     3
##    probabilities: 0.912 0.088 
## 
## Node number 13: 99 observations,    complexity param=0.01465201
##   predicted class=1  expected loss=0.3434343  P(node) =0.1388499
##     class counts:    34    65
##    probabilities: 0.343 0.657 
##   left son=26 (14 obs) right son=27 (85 obs)
##   Primary splits:
##       Age      < 28.5     to the right, improve=2.923776, (0 missing)
##       Embarked splits as  LRL,          improve=2.250789, (0 missing)
##       Fare     < 8.0396   to the right, improve=1.662338, (0 missing)
##       Name     splits as  RL-L-,        improve=1.068687, (0 missing)
##       Sex      splits as  LR,           improve=1.068687, (0 missing)
## 
## Node number 22: 7 observations
##   predicted class=0  expected loss=0  P(node) =0.009817672
##     class counts:     7     0
##    probabilities: 1.000 0.000 
## 
## Node number 23: 80 observations,    complexity param=0.007326007
##   predicted class=0  expected loss=0.475  P(node) =0.112202
##     class counts:    42    38
##    probabilities: 0.525 0.475 
##   left son=46 (68 obs) right son=47 (12 obs)
##   Primary splits:
##       Fare     < 27.1354  to the right, improve=3.6254900, (0 missing)
##       Name     splits as  --L-R,        improve=0.9562771, (0 missing)
##       Age      < 27.5     to the right, improve=0.9000000, (0 missing)
##       Embarked splits as  RLL,          improve=0.8309463, (0 missing)
##       Parch    < 0.5      to the right, improve=0.5666667, (0 missing)
## 
## Node number 26: 14 observations
##   predicted class=0  expected loss=0.3571429  P(node) =0.01963534
##     class counts:     9     5
##    probabilities: 0.643 0.357 
## 
## Node number 27: 85 observations,    complexity param=0.01098901
##   predicted class=1  expected loss=0.2941176  P(node) =0.1192146
##     class counts:    25    60
##    probabilities: 0.294 0.706 
##   left son=54 (52 obs) right son=55 (33 obs)
##   Primary splits:
##       Fare     < 8.0396   to the right, improve=2.1938850, (0 missing)
##       Embarked splits as  LRL,          improve=1.9129700, (0 missing)
##       Age      < 7        to the right, improve=0.7669547, (0 missing)
##       Parch    < 1.5      to the left,  improve=0.6742346, (0 missing)
##       Name     splits as  RL-L-,        improve=0.6742346, (0 missing)
##   Surrogate splits:
##       Embarked splits as  LRL,          agree=0.753, adj=0.364, (0 split)
##       SibSp    < 0.5      to the right, agree=0.729, adj=0.303, (0 split)
##       Parch    < 0.5      to the right, agree=0.718, adj=0.273, (0 split)
##       Name     splits as  LR-L-,        agree=0.647, adj=0.091, (0 split)
## 
## Node number 46: 68 observations,    complexity param=0.007326007
##   predicted class=0  expected loss=0.4117647  P(node) =0.09537167
##     class counts:    40    28
##    probabilities: 0.588 0.412 
##   left son=92 (35 obs) right son=93 (33 obs)
##   Primary splits:
##       Embarked splits as  RLL,          improve=2.2918260, (0 missing)
##       Age      < 27.5     to the right, improve=1.9027150, (0 missing)
##       Name     splits as  --L-R,        improve=1.8935570, (0 missing)
##       Fare     < 127.8166 to the right, improve=0.4745098, (0 missing)
##       Parch    < 0.5      to the right, improve=0.1792717, (0 missing)
##   Surrogate splits:
##       Age   < 27.5     to the right, agree=0.662, adj=0.303, (0 split)
##       Fare  < 61.2771  to the left,  agree=0.647, adj=0.273, (0 split)
##       Name  splits as  --L-R,        agree=0.603, adj=0.182, (0 split)
##       Parch < 0.5      to the left,  agree=0.574, adj=0.121, (0 split)
##       Sex   splits as  RL,           agree=0.559, adj=0.091, (0 split)
## 
## Node number 47: 12 observations
##   predicted class=1  expected loss=0.1666667  P(node) =0.01683029
##     class counts:     2    10
##    probabilities: 0.167 0.833 
## 
## Node number 54: 52 observations,    complexity param=0.01098901
##   predicted class=1  expected loss=0.3846154  P(node) =0.07293128
##     class counts:    20    32
##    probabilities: 0.385 0.615 
##   left son=108 (10 obs) right son=109 (42 obs)
##   Primary splits:
##       Fare  < 10.825   to the left,  improve=4.272527, (0 missing)
##       Parch < 0.5      to the left,  improve=2.690347, (0 missing)
##       Age   < 7        to the right, improve=2.239445, (0 missing)
##       Name  splits as  RL-L-,        improve=1.628305, (0 missing)
##       Sex   splits as  LR,           improve=1.628305, (0 missing)
## 
## Node number 55: 33 observations
##   predicted class=1  expected loss=0.1515152  P(node) =0.04628331
##     class counts:     5    28
##    probabilities: 0.152 0.848 
## 
## Node number 92: 35 observations
##   predicted class=0  expected loss=0.2857143  P(node) =0.04908836
##     class counts:    25    10
##    probabilities: 0.714 0.286 
## 
## Node number 93: 33 observations,    complexity param=0.007326007
##   predicted class=1  expected loss=0.4545455  P(node) =0.04628331
##     class counts:    15    18
##    probabilities: 0.455 0.545 
##   left son=186 (11 obs) right son=187 (22 obs)
##   Primary splits:
##       Fare  < 98.7521  to the right, improve=1.0909090, (0 missing)
##       Age   < 26.5     to the right, improve=1.0909090, (0 missing)
##       Parch < 0.5      to the right, improve=0.6136364, (0 missing)
##       Name  splits as  --L-R,        improve=0.3636364, (0 missing)
##       SibSp < 0.5      to the left,  improve=0.2727273, (0 missing)
##   Surrogate splits:
##       Age   < 22.5     to the left,  agree=0.788, adj=0.364, (0 split)
##       Parch < 0.5      to the right, agree=0.727, adj=0.182, (0 split)
## 
## Node number 108: 10 observations
##   predicted class=0  expected loss=0.2  P(node) =0.01402525
##     class counts:     8     2
##    probabilities: 0.800 0.200 
## 
## Node number 109: 42 observations,    complexity param=0.003663004
##   predicted class=1  expected loss=0.2857143  P(node) =0.05890603
##     class counts:    12    30
##    probabilities: 0.286 0.714 
##   left son=218 (29 obs) right son=219 (13 obs)
##   Primary splits:
##       Age      < 7        to the right, improve=1.6415310, (0 missing)
##       Fare     < 15.3729  to the left,  improve=1.5873020, (0 missing)
##       Embarked splits as  LRR,          improve=0.9075630, (0 missing)
##       Name     splits as  RL-L-,        improve=0.6984127, (0 missing)
##       Sex      splits as  LR,           improve=0.6984127, (0 missing)
##   Surrogate splits:
##       SibSp < 1.5      to the left,  agree=0.738, adj=0.154, (0 split)
##       Name  splits as  RL-L-,        agree=0.714, adj=0.077, (0 split)
##       Sex   splits as  LR,           agree=0.714, adj=0.077, (0 split)
##       Fare  < 11.1875  to the right, agree=0.714, adj=0.077, (0 split)
## 
## Node number 186: 11 observations
##   predicted class=0  expected loss=0.3636364  P(node) =0.01542777
##     class counts:     7     4
##    probabilities: 0.636 0.364 
## 
## Node number 187: 22 observations,    complexity param=0.003663004
##   predicted class=1  expected loss=0.3636364  P(node) =0.03085554
##     class counts:     8    14
##    probabilities: 0.364 0.636 
##   left son=374 (15 obs) right son=375 (7 obs)
##   Primary splits:
##       Age   < 27.5     to the right, improve=2.7151520, (0 missing)
##       Fare  < 30.8479  to the left,  improve=1.7175320, (0 missing)
##       SibSp < 0.5      to the left,  improve=1.0008660, (0 missing)
##       Name  splits as  --L-R,        improve=0.1246753, (0 missing)
##   Surrogate splits:
##       Sex  splits as  RL,           agree=0.727, adj=0.143, (0 split)
##       Fare < 44.5521  to the left,  agree=0.727, adj=0.143, (0 split)
## 
## Node number 218: 29 observations,    complexity param=0.003663004
##   predicted class=1  expected loss=0.3793103  P(node) =0.04067321
##     class counts:    11    18
##    probabilities: 0.379 0.621 
##   left son=436 (14 obs) right son=437 (15 obs)
##   Primary splits:
##       Fare     < 15.3729  to the left,  improve=1.99803000, (0 missing)
##       Age      < 21       to the left,  improve=1.48675100, (0 missing)
##       Embarked splits as  LRR,          improve=1.19363400, (0 missing)
##       Name     splits as  RL-R-,        improve=0.59634890, (0 missing)
##       Parch    < 0.5      to the left,  improve=0.02660099, (0 missing)
##   Surrogate splits:
##       Embarked splits as  LRR,          agree=0.828, adj=0.643, (0 split)
##       Age      < 17.5     to the left,  agree=0.586, adj=0.143, (0 split)
##       SibSp    < 0.5      to the left,  agree=0.586, adj=0.143, (0 split)
##       Parch    < 0.5      to the right, agree=0.586, adj=0.143, (0 split)
## 
## Node number 219: 13 observations
##   predicted class=1  expected loss=0.07692308  P(node) =0.01823282
##     class counts:     1    12
##    probabilities: 0.077 0.923 
## 
## Node number 374: 15 observations
##   predicted class=0  expected loss=0.4666667  P(node) =0.02103787
##     class counts:     8     7
##    probabilities: 0.533 0.467 
## 
## Node number 375: 7 observations
##   predicted class=1  expected loss=0  P(node) =0.009817672
##     class counts:     0     7
##    probabilities: 0.000 1.000 
## 
## Node number 436: 14 observations
##   predicted class=0  expected loss=0.4285714  P(node) =0.01963534
##     class counts:     8     6
##    probabilities: 0.571 0.429 
## 
## Node number 437: 15 observations
##   predicted class=1  expected loss=0.2  P(node) =0.02103787
##     class counts:     3    12
##    probabilities: 0.200 0.800
## Inspect variable importance
mod1$variable.importance
##      Name       Sex      Fare       Age     Parch    Pclass     SibSp  Embarked 
## 111.91631  89.02461  61.63420  46.05885  38.32346  36.69987  30.46093   9.64527
# cp is the complexity parameter associated with each split
mod1$cp
##            CP nsplit rel error    xerror       xstd
## 1 0.450549451      0 1.0000000 1.0000000 0.04754450
## 2 0.051282051      1 0.5494505 0.5604396 0.04015432
## 3 0.014652015      3 0.4468864 0.4615385 0.03730757
## 4 0.010989011      4 0.4322344 0.4798535 0.03787823
## 5 0.007326007      6 0.4102564 0.5091575 0.03874859
## 6 0.003663004     12 0.3589744 0.4835165 0.03798985
## 7 0.000001000     15 0.3479853 0.4798535 0.03787823
## plot cross-validation results
plotcp(mod1)

par(family = "STKaiti")
rpart.plot(mod1, type = 2,extra="auto", under=TRUE, 
           fallen.leaves = FALSE,cex=0.7, main="Decision Tree")

## Evaluate the model's predictions on the training and test sets
pre_train <- predict(mod1,train_data,type = "prob")
pre_train2<-as.factor(as.vector(ifelse(pre_train[,2]>0.5,1,0)))
pre_test <- predict(mod1,test_data)
pre_test2<-as.factor(as.vector(ifelse(pre_test[,2]>0.5,1,0)))
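rpart can also return hard class labels directly, which for a binary outcome is essentially equivalent to thresholding the class-1 probability at 0.5 as done above (a one-line alternative):

pre_test_class <- predict(mod1, test_data, type = "class")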
sprintf("决策树模型在训练集精度为:%f",accuracy(train_data$Survived,pre_train2))
## [1] "决策树模型在训练集精度为:0.866760"
sprintf("决策树模型在测试集精度为:%f",accuracy(test_data$Survived,pre_test2))
## [1] "决策树模型在测试集精度为:0.820225"
## Compute the confusion matrix and model accuracy
cfm <- confusionMatrix(pre_test2,as.factor(test_data$Survived))
cfm$table
##           Reference
## Prediction  0  1
##          0 97 20
##          1 12 49
cfm
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  0  1
##          0 97 20
##          1 12 49
##                                           
##                Accuracy : 0.8202          
##                  95% CI : (0.7558, 0.8737)
##     No Information Rate : 0.6124          
##     P-Value [Acc > NIR] : 1.645e-09       
##                                           
##                   Kappa : 0.6131          
##                                           
##  Mcnemar's Test P-Value : 0.2159          
##                                           
##             Sensitivity : 0.8899          
##             Specificity : 0.7101          
##          Pos Pred Value : 0.8291          
##          Neg Pred Value : 0.8033          
##              Prevalence : 0.6124          
##          Detection Rate : 0.5449          
##    Detection Prevalence : 0.6573          
##       Balanced Accuracy : 0.8000          
##                                           
##        'Positive' Class : 0               
## 

Optimizing the decision tree (pruning and related steps to improve accuracy)

bestcp <- mod1$cptable[which.min(mod1$cptable[,"xerror"]),"CP"]
bestcp
## [1] 0.01465201
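bestcp above simply minimizes the cross-validated error xerror. A common alternative is the one-standard-error rule: take the largest cp whose xerror is within one xstd of the minimum, which favors a smaller tree at essentially the same error (a sketch, not used below):

cpt <- mod1$cptable
minrow <- which.min(cpt[, "xerror"])
threshold <- cpt[minrow, "xerror"] + cpt[minrow, "xstd"]
cp1se <- cpt[which(cpt[, "xerror"] <= threshold)[1], "CP"]  # rows are ordered by decreasing cp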
# Step 3: Prune the tree using the best cp.
mod1.pruned <- prune(mod1, cp = bestcp)

## plot cross-validation results
plotcp(mod1.pruned)

## Visualize the pruned decision tree
par(family = "STKaiti")
rpart.plot(mod1.pruned, type = 2,extra="auto", under=TRUE, 
           fallen.leaves = FALSE,cex=0.7, main="Pruned Decision Tree")

## Evaluate the pruned model on the training and test sets
pre_train_p <- predict(mod1.pruned,train_data)
pre_train_p2<-as.factor(as.vector(ifelse(pre_train_p[,2]>0.5,1,0)))

pre_test_p <- predict(mod1.pruned,test_data)
pre_test_p2<-as.factor(as.vector(ifelse(pre_test_p[,2]>0.5,1,0)))

sprintf("剪枝后决策树模型在训练集精度为:%f",accuracy(train_data$Survived,pre_train_p2))
## [1] "剪枝后决策树模型在训练集精度为:0.828892"
sprintf("剪枝后决策树模型在测试集精度为:%f",accuracy(test_data$Survived,pre_test_p2))
## [1] "剪枝后决策树模型在测试集精度为:0.808989"
## Compute the confusion matrix and model accuracy
cfm <- confusionMatrix(pre_test_p2,as.factor(test_data$Survived))
cfm$table
##           Reference
## Prediction  0  1
##          0 94 19
##          1 15 50
cfm
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  0  1
##          0 94 19
##          1 15 50
##                                           
##                Accuracy : 0.809           
##                  95% CI : (0.7434, 0.8639)
##     No Information Rate : 0.6124          
##     P-Value [Acc > NIR] : 1.294e-08       
##                                           
##                   Kappa : 0.5933          
##                                           
##  Mcnemar's Test P-Value : 0.6069          
##                                           
##             Sensitivity : 0.8624          
##             Specificity : 0.7246          
##          Pos Pred Value : 0.8319          
##          Neg Pred Value : 0.7692          
##              Prevalence : 0.6124          
##          Detection Rate : 0.5281          
##    Detection Prevalence : 0.6348          
##       Balanced Accuracy : 0.7935          
##                                           
##        'Positive' Class : 0               
## 
## Plot test-set ROC curves for the tree before and after pruning
## The ROC curves let us compare the two models
## Compute ROC coordinates for the unpruned tree
pr <- prediction(pre_test[,2], test_data$Survived)
prf <- performance(pr, measure = "tpr", x.measure = "fpr")
prfdf <- data.frame(x = prf@x.values[[1]],
                    y = prf@y.values[[1]],
                    model = "rpart")
## Compute ROC coordinates for the pruned tree
pr <- prediction(pre_test_p[,2], test_data$Survived)
prf <- performance(pr, measure = "tpr", x.measure = "fpr")
prfdf2 <- data.frame(x = prf@x.values[[1]],
                    y = prf@y.values[[1]],
                    model = "rpart.prund")

## Combine the data
prfdf <- rbind.data.frame(prfdf,prfdf2)
## plot ROC
ggplot(prfdf,aes(x= x,y = y,colour = model))+
  geom_line(aes(linetype = model),size = 1)+
  theme(aspect.ratio=1)+
  labs(x = "假正例率",y = "真正例率")

## Compute the AUC values
auc(test_data$Survived,as.vector(pre_test[,2]))
## [1] 0.8704295
auc(test_data$Survived,as.vector(pre_test_p[,2]))
## [1] 0.8336657
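Metrics::auc computes the areas directly from the predicted probabilities. Since ROCR prediction objects were already built above, the same value can be recovered from them as a cross-check (pr currently holds the pruned model's predictions, so this should reproduce 0.8336657):

as.numeric(performance(pr, measure = "auc")@y.values)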

9.2: Random Forest Models

A real dataset as an example: classification and regression

Random forest classification

library(tidyr)
library(randomForest)
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
## 
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
## 
##     margin
library(ggRandomForests)
## Loading required package: randomForestSRC
## 
##  randomForestSRC 2.9.2 
##  
##  Type rfsrc.news() to see new features, changes, and bug fixes. 
## 
## 
## Attaching package: 'ggRandomForests'
## The following object is masked from 'package:randomForestSRC':
## 
##     partial.rfsrc
library(caret)

train_data$Survived <- as.factor(train_data$Survived)
rfcla <- randomForest(Survived~.,data = train_data,ntree=200, proximity=TRUE)
summary(rfcla)
##                 Length Class  Mode     
## call                 5 -none- call     
## type                 1 -none- character
## predicted          713 factor numeric  
## err.rate           600 -none- numeric  
## confusion            6 -none- numeric  
## votes             1426 matrix numeric  
## oob.times          713 -none- numeric  
## classes              2 -none- character
## importance           8 -none- numeric  
## importanceSD         0 -none- NULL     
## localImportance      0 -none- NULL     
## proximity       508369 -none- numeric  
## ntree                1 -none- numeric  
## mtry                 1 -none- numeric  
## forest              14 -none- list     
## y                  713 factor numeric  
## test                 0 -none- NULL     
## inbag                0 -none- NULL     
## terms                3 terms  call
## Visualize training: how the error changes as trees are added
trainerror <- as.data.frame(plot(rfcla,type = "l"))

colnames(trainerror) <- paste("error",colnames(trainerror),sep = "")
trainerror$ntree <- 1:nrow(trainerror)
trainerror <- gather(trainerror,key = "Type",value = "Error",1:3)
ggplot(trainerror,aes(x = ntree,y = Error))+
  geom_line(aes(linetype = Type,colour = Type))+
  #theme(legend.position = "bottom")+
  ggtitle("随机森林分类模型")+
  theme(plot.title = element_text(hjust = 0.5))

## The model's performance gradually stabilizes

## Alternatively, visualize the error with the ggRandomForests package
plot(gg_error(rfcla))

## Plot the scaled coordinates of the proximity matrix from randomForest
MDSplot(rfcla,train_data$Survived,k = 2 , palette=c(1, 2),
        pch=20+as.numeric(train_data$Survived))

## Visualize variable importance
importance(rfcla)
##          MeanDecreaseGini
## Pclass          22.696178
## Name            55.122124
## Sex             52.632961
## Age             33.436435
## SibSp           13.921757
## Parch            9.361209
## Fare            42.790499
## Embarked        10.316719
varImpPlot(rfcla,pch = 20, main = "Importance of Variables")

## Check the model's accuracy on the test set
rfclapre<- predict(rfcla,test_data)
sprintf("随机森林模型测试集精度为:%f",accuracy(test_data$Survived,rfclapre))
## [1] "随机森林模型测试集精度为:0.837079"

Train a random forest on the full training set and predict the test set

Ttrainp$Survived <- as.factor(Ttrainp$Survived)
rfclanew <- randomForest(Survived~.,data = Ttrainp,ntree=200, proximity=TRUE)
summary(rfclanew)
##                 Length Class  Mode     
## call                 5 -none- call     
## type                 1 -none- character
## predicted          891 factor numeric  
## err.rate           600 -none- numeric  
## confusion            6 -none- numeric  
## votes             1782 matrix numeric  
## oob.times          891 -none- numeric  
## classes              2 -none- character
## importance           8 -none- numeric  
## importanceSD         0 -none- NULL     
## localImportance      0 -none- NULL     
## proximity       793881 -none- numeric  
## ntree                1 -none- numeric  
## mtry                 1 -none- numeric  
## forest              14 -none- list     
## y                  891 factor numeric  
## test                 0 -none- NULL     
## inbag                0 -none- NULL     
## terms                3 terms  call
rfclanew
## 
## Call:
##  randomForest(formula = Survived ~ ., data = Ttrainp, ntree = 200,      proximity = TRUE) 
##                Type of random forest: classification
##                      Number of trees: 200
## No. of variables tried at each split: 2
## 
##         OOB estimate of  error rate: 16.05%
## Confusion matrix:
##     0   1 class.error
## 0 502  47   0.0856102
## 1  96 246   0.2807018
## Predict the test set
Ttestpre <- predict(rfclanew,Ttestp)
table(Ttestpre)
## Ttestpre
##   0   1 
## 277 141
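To turn these predictions into a per-passenger file (for example a submission to Kaggle's Titanic competition), pair them with PassengerId from the raw test set. A minimal sketch; the output path is illustrative:

submission <- data.frame(PassengerId = Ttest$PassengerId,
                         Survived = as.integer(as.character(Ttestpre)))
write.csv(submission, "data/chap9/Titanic_submission.csv", row.names = FALSE)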

Random forest regression

## Use random forests to run a regression analysis on the ENB2012 data
ENB <- read_excel("data/chap9/ENB2012.xlsx")
head(ENB)
## # A tibble: 6 x 9
##      X1    X2    X3    X4    X5    X6    X7    X8    Y1
##   <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1  0.98  514.  294   110.     7     2     0     0  15.6
## 2  0.98  514.  294   110.     7     3     0     0  15.6
## 3  0.98  514.  294   110.     7     4     0     0  15.6
## 4  0.98  514.  294   110.     7     5     0     0  15.6
## 5  0.9   564.  318.  122.     7     2     0     0  20.8
## 6  0.9   564.  318.  122.     7     3     0     0  21.5
summary(ENB)
##        X1               X2              X3              X4       
##  Min.   :0.6200   Min.   :514.5   Min.   :245.0   Min.   :110.2  
##  1st Qu.:0.6825   1st Qu.:606.4   1st Qu.:294.0   1st Qu.:140.9  
##  Median :0.7500   Median :673.8   Median :318.5   Median :183.8  
##  Mean   :0.7642   Mean   :671.7   Mean   :318.5   Mean   :176.6  
##  3rd Qu.:0.8300   3rd Qu.:741.1   3rd Qu.:343.0   3rd Qu.:220.5  
##  Max.   :0.9800   Max.   :808.5   Max.   :416.5   Max.   :220.5  
##        X5             X6             X7               X8              Y1       
##  Min.   :3.50   Min.   :2.00   Min.   :0.0000   Min.   :0.000   Min.   : 6.01  
##  1st Qu.:3.50   1st Qu.:2.75   1st Qu.:0.1000   1st Qu.:1.750   1st Qu.:12.99  
##  Median :5.25   Median :3.50   Median :0.2500   Median :3.000   Median :18.95  
##  Mean   :5.25   Mean   :3.50   Mean   :0.2344   Mean   :2.812   Mean   :22.31  
##  3rd Qu.:7.00   3rd Qu.:4.25   3rd Qu.:0.4000   3rd Qu.:4.000   3rd Qu.:31.67  
##  Max.   :7.00   Max.   :5.00   Max.   :0.4000   Max.   :5.000   Max.   :43.10
str(ENB)
## Classes 'tbl_df', 'tbl' and 'data.frame':    768 obs. of  9 variables:
##  $ X1: num  0.98 0.98 0.98 0.98 0.9 0.9 0.9 0.9 0.86 0.86 ...
##  $ X2: num  514 514 514 514 564 ...
##  $ X3: num  294 294 294 294 318 ...
##  $ X4: num  110 110 110 110 122 ...
##  $ X5: num  7 7 7 7 7 7 7 7 7 7 ...
##  $ X6: num  2 3 4 5 2 3 4 5 2 3 ...
##  $ X7: num  0 0 0 0 0 0 0 0 0 0 ...
##  $ X8: num  0 0 0 0 0 0 0 0 0 0 ...
##  $ Y1: num  15.6 15.6 15.6 15.6 20.8 ...
## Split the data into training (70%) and test sets
set.seed(12)
index <- sample(nrow(ENB),round(nrow(ENB)*0.7))
trainEnb <- ENB[index,]
testENB <- ENB[-index,]
## Build the random forest regression model
rfreg <- randomForest(Y1~.,data = trainEnb,ntree=500)
summary(rfreg)
##                 Length Class  Mode     
## call              4    -none- call     
## type              1    -none- character
## predicted       538    -none- numeric  
## mse             500    -none- numeric  
## rsq             500    -none- numeric  
## oob.times       538    -none- numeric  
## importance        8    -none- numeric  
## importanceSD      0    -none- NULL     
## localImportance   0    -none- NULL     
## proximity         0    -none- NULL     
## ntree             1    -none- numeric  
## mtry              1    -none- numeric  
## forest           11    -none- list     
## coefs             0    -none- NULL     
## y               538    -none- numeric  
## test              0    -none- NULL     
## inbag             0    -none- NULL     
## terms             3    terms  call
## Visualize how the OOB error changes as trees are added
par(family = "STKaiti")
plot(rfreg,type = "l",col = "red",main = "Random Forest Regression")

## Visualize the error with the ggRandomForests package
plot(gg_error(rfreg))+labs(title = "Random Forest Regression")

## Visualize variable importance
importance(rfreg)
##    IncNodePurity
## X1   14525.01820
## X2   12354.57214
## X3    3224.53478
## X4    8846.38752
## X5   10807.09663
## X6      77.46897
## X7    2353.68826
## X8     945.38532
varImpPlot(rfreg,pch = 20, main = "Importance of Variables")

## Predict the test set and compute the mean squared error
rfpre <- predict(rfreg,testENB)
sprintf("Mean squared error: %f",mse(testENB$Y1,rfpre))
## [1] "Mean squared error: 1.332674"
## Parameter search: find a suitable mtry to train a better model
## Tune randomForest for the optimal mtry parameter
set.seed(1234)
rftune <- tuneRF(x = trainEnb[,1:8],y = trainEnb$Y1,
                 stepFactor=1.5,ntreeTry = 500)
## mtry = 2  OOB error = 1.335695 
## Searching left ...
## Searching right ...
## mtry = 3     OOB error = 0.6641563 
## 0.5027634 0.05 
## mtry = 4     OOB error = 0.417198 
## 0.3718375 0.05 
## mtry = 6     OOB error = 0.3555105 
## 0.1478614 0.05 
## mtry = 8     OOB error = 0.3629043 
## -0.02079754 0.05

print(rftune)
##   mtry  OOBError
## 2    2 1.3356946
## 3    3 0.6641563
## 4    4 0.4171980
## 6    6 0.3555105
## 8    8 0.3629043
## mtry = 6 gives the smallest OOB error

## Build the tuned random forest regression model
rfregbest <- randomForest(Y1~.,data = trainEnb,ntree=500,mtry = 6)

## Visualize OOB error vs. number of trees for the two models
rfregerr <- as.data.frame(plot(rfreg))

colnames(rfregerr) <- "rfregerr"
rfregbesterr <- as.data.frame(plot(rfregbest))

colnames(rfregbesterr) <- "rfregbesterr"
plotrfdata <- cbind.data.frame(rfregerr,rfregbesterr)
plotrfdata$ntree <- 1:nrow(plotrfdata)
plotrfdata <- gather(plotrfdata,key = "Type",value = "Error",1:2)
ggplot(plotrfdata,aes(x = ntree,y = Error))+
  geom_line(aes(linetype = Type,colour = Type),size = 0.9)+
  theme(legend.position = "top")+
  ggtitle("随机森林回归模型")+
  theme(plot.title = element_text(hjust = 0.5))

## Predict the test set with the tuned model and compute the mean squared error
rfprebest <- predict(rfregbest,testENB)
sprintf("优化后均方根误差为: %f",mse(testENB$Y1,rfprebest))
## [1] "优化后均方根误差为: 0.421116"
## Prepare the data for plotting
index <- order(testENB$Y1)
X <- sort(index)
Y1 <- testENB$Y1[index]
rfpre2 <- rfpre[index]
rfprebest2 <- rfprebest[index]

plotdata <- data.frame(X = X,Y1 = Y1,rfpre =rfpre2,rfprebest = rfprebest2)
plotdata <- gather(plotdata,key="model",value="value",c(-X))

## Visualize the models' prediction errors
ggplot(plotdata,aes(x = X,y = value))+
  geom_line(aes(linetype = model,colour = model),size = 0.8)+
  theme(legend.position = c(0.1,0.8),
        plot.title = element_text(hjust = 0.5))+
  ggtitle("随机森林回归模型")

## The random forest regression performs very well

9.3: Gradient Boosting Machines

A gradient boosting model is an ensemble of regression or classification trees. Both are forward, stagewise ensemble methods that arrive at a prediction through progressively improved estimates. Boosting is a flexible nonlinear regression procedure that helps improve the accuracy of tree models: by sequentially applying weak learners to incrementally adjusted versions of the data, it builds a series of decision trees that together form an ensemble of weak prediction models.
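To make "sequentially applying weak learners to incrementally adjusted data" concrete, here is a minimal hand-rolled boosting loop for a regression target, using rpart stumps on simulated data (a toy sketch with made-up data, separate from the h2o modeling below):

## toy gradient boosting for regression: each stump fits the current residuals
set.seed(1)
toy <- data.frame(x = runif(200, 0, 10))
toy$y <- sin(toy$x) + rnorm(200, sd = 0.2)

pred <- rep(mean(toy$y), nrow(toy))  # start from the constant model
eta <- 0.1                           # learning rate (shrinkage)
for (m in 1:100) {
  toy$res <- toy$y - pred                        # current residuals
  stump <- rpart(res ~ x, data = toy,
                 control = rpart.control(maxdepth = 1))
  pred <- pred + eta * predict(stump, toy)       # stagewise update
}
mean((toy$y - pred)^2)  # training MSE decreases as trees are added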

Modeling the Titanic dataset with the GBM algorithm

library(h2o)
## 
## ----------------------------------------------------------------------
## 
## Your next step is to start H2O:
##     > h2o.init()
## 
## For H2O package documentation, ask for help:
##     > ??h2o
## 
## After starting H2O, you can use the Web UI at http://localhost:54321
## For more information visit http://docs.h2o.ai
## 
## ----------------------------------------------------------------------
## 
## Attaching package: 'h2o'
## The following objects are masked from 'package:data.table':
## 
##     hour, month, week, year
## The following objects are masked from 'package:stats':
## 
##     cor, sd, var
## The following objects are masked from 'package:base':
## 
##     &&, %*%, %in%, ||, apply, as.factor, as.numeric, colnames,
##     colnames<-, ifelse, is.character, is.factor, is.numeric, log,
##     log10, log1p, log2, round, signif, trunc
library(Metrics)
library(dplyr)
## 
## Attaching package: 'dplyr'
## The following object is masked from 'package:randomForest':
## 
##     combine
## The following objects are masked from 'package:data.table':
## 
##     between, first, last
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
####################################
## Note: starting the h2o instance below may require a newer Java version; the jdk-8u221-macosx_64 installer is provided in the Program\Other folder. Install it and rerun.
###################################
## Start and initialize an h2o instance, restricted to 2 cores
h2o.init(nthreads=2,max_mem_size='4G')  
## 
## H2O is not running yet, starting it now...
## 
## Note:  In case of errors look at the following log files:
##     /var/folders/dv/qd8nc5bx6xx3pxq_xx_j47240000gn/T//RtmphvRb3k/h2o_sunlanxin_started_from_r.out
##     /var/folders/dv/qd8nc5bx6xx3pxq_xx_j47240000gn/T//RtmphvRb3k/h2o_sunlanxin_started_from_r.err
## 
## 
## Starting H2O JVM and connecting: . Connection successful!
## 
## R is connected to the H2O cluster: 
##     H2O cluster uptime:         1 seconds 645 milliseconds 
##     H2O cluster timezone:       Asia/Shanghai 
##     H2O data parsing timezone:  UTC 
##     H2O cluster version:        3.26.0.2 
##     H2O cluster version age:    4 months and 28 days !!! 
##     H2O cluster name:           H2O_started_from_R_sunlanxin_aui415 
##     H2O cluster total nodes:    1 
##     H2O cluster total memory:   3.56 GB 
##     H2O cluster total cores:    8 
##     H2O cluster allowed cores:  2 
##     H2O cluster healthy:        TRUE 
##     H2O Connection ip:          localhost 
##     H2O Connection port:        54321 
##     H2O Connection proxy:       NA 
##     H2O Internal Security:      FALSE 
##     H2O API Extensions:         Amazon S3, XGBoost, Algos, AutoML, Core V3, Core V4 
##     R Version:                  R version 3.6.2 (2019-12-12)
## Warning in h2o.clusterInfo(): 
## Your H2O cluster version is too old (4 months and 28 days)!
## Please download and install the latest version from http://h2o.ai/download/
## Read the data
train<-h2o.uploadFile("data/chap9/Titanic处理后数据.csv",
                         destination_frame = "train.hex")
str(train)
## Class 'H2OFrame' <environment: 0x7fc58d67b318> 
##  - attr(*, "op")= chr "Parse"
##  - attr(*, "id")= chr "train.hex"
##  - attr(*, "eval")= logi FALSE
##  - attr(*, "nrow")= int 891
##  - attr(*, "ncol")= int 9
##  - attr(*, "types")=List of 9
##   ..$ : chr "int"
##   ..$ : chr "enum"
##   ..$ : chr "enum"
##   ..$ : chr "real"
##   ..$ : chr "int"
##   ..$ : chr "int"
##   ..$ : chr "real"
##   ..$ : chr "enum"
##   ..$ : chr "int"
##  - attr(*, "data")='data.frame': 10 obs. of  9 variables:
##   ..$ Pclass  : num  3 1 3 1 3 3 1 3 3 2
##   ..$ Name    : Factor w/ 5 levels "Master.","Miss.",..: 3 4 2 4 3 3 3 1 4 4
##   ..$ Sex     : Factor w/ 2 levels "female","male": 2 1 1 1 2 2 2 2 1 1
##   ..$ Age     : num  22 38 26 35 35 28 54 2 27 14
##   ..$ SibSp   : num  1 1 0 1 0 0 0 3 0 1
##   ..$ Parch   : num  0 0 0 0 0 0 0 1 2 0
##   ..$ Fare    : num  7.25 71.28 7.92 53.1 8.05 ...
##   ..$ Embarked: Factor w/ 3 levels "C","Q","S": 3 1 3 3 3 2 3 3 3 1
##   ..$ Survived: num  0 1 1 1 0 0 0 0 1 1
train$Survived <- as.factor(train$Survived)
colnames(train)
## [1] "Pclass"   "Name"     "Sex"      "Age"      "SibSp"    "Parch"    "Fare"    
## [8] "Embarked" "Survived"
head(train)
##   Pclass  Name    Sex Age SibSp Parch    Fare Embarked Survived
## 1      3   Mr.   male  22     1     0  7.2500        S        0
## 2      1  Mrs. female  38     1     0 71.2833        C        1
## 3      3 Miss. female  26     0     0  7.9250        S        1
## 4      1  Mrs. female  35     1     0 53.1000        S        1
## 5      3   Mr.   male  35     0     0  8.0500        S        0
## 6      3   Mr.   male  28     0     0  8.4583        Q        0
## Split into training (70%) and test (30%) sets
splits <- h2o.splitFrame(data = train, ratios = 0.7,seed = 1234)
train_data <- splits[[1]]
test_data <- splits[[2]]

dim(train_data)
## [1] 634   9
head(train_data)
##   Pclass    Name    Sex Age SibSp Parch    Fare Embarked Survived
## 1      1    Mrs. female  38     1     0 71.2833        C        1
## 2      3   Miss. female  26     0     0  7.9250        S        1
## 3      3     Mr.   male  35     0     0  8.0500        S        0
## 4      3     Mr.   male  28     0     0  8.4583        Q        0
## 5      1     Mr.   male  54     0     0 51.8625        S        0
## 6      3 Master.   male   2     3     1 21.0750        S        0
dim(test_data)
## [1] 257   9
## GBM model
name1 <- colnames(train)
predictors <- name1[1:8]
target <- "Survived"
gbm <- h2o.gbm(x = predictors, y = target,
               training_frame = train_data,
               distribution="bernoulli", ## binary classification model
               ntrees = 100,  ## number of trees used by the model
               learn_rate=0.01, ## learning rate
               sample_rate = 0.8,## each tree samples 80% of the rows
               col_sample_rate = 0.6,## each split samples 60% of the columns
               seed = 1234)
summary(gbm)
## Model Details:
## ==============
## 
## H2OBinomialModel: gbm
## Model Key:  GBM_model_R_1577260091654_1 
## Model Summary: 
##   number_of_trees number_of_internal_trees model_size_in_bytes min_depth
## 1             100                      100               30728         5
##   max_depth mean_depth min_leaves max_leaves mean_leaves
## 1         5    5.00000         14         26    19.79000
## 
## H2OBinomialMetrics: gbm
## ** Reported on training data. **
## 
## MSE:  0.1264054
## RMSE:  0.3555354
## LogLoss:  0.4216086
## Mean Per-Class Error:  0.1451068
## AUC:  0.9286271
## pr_auc:  0.8944934
## Gini:  0.8572543
## R^2:  0.4571642
## 
## Confusion Matrix (vertical: actual; across: predicted) for F1-optimal threshold:
##          0   1    Error     Rate
## 0      354  46 0.115000  =46/400
## 1       41 193 0.175214  =41/234
## Totals 395 239 0.137224  =87/634
## 
## Maximum Metrics: Maximum metrics at their respective thresholds
##                         metric threshold    value idx
## 1                       max f1  0.401407 0.816068 182
## 2                       max f2  0.200353 0.850669 313
## 3                 max f0point5  0.516888 0.863029 122
## 4                 max accuracy  0.477965 0.873817 144
## 5                max precision  0.768800 1.000000   0
## 6                   max recall  0.190450 1.000000 345
## 7              max specificity  0.768800 1.000000   0
## 8             max absolute_mcc  0.477965 0.726296 144
## 9   max min_per_class_accuracy  0.354077 0.841880 200
## 10 max mean_per_class_accuracy  0.401407 0.854893 182
## 
## Gains/Lift Table: Extract with `h2o.gainsLift(<model>, <data>)` or `h2o.gainsLift(<model>, valid=<T/F>, xval=<T/F>)`
## 
## 
## 
## Scoring History: 
##             timestamp   duration number_of_trees training_rmse training_logloss
## 1 2019-12-25 15:48:14  0.012 sec               0       0.48256          0.65847
## 2 2019-12-25 15:48:14  0.149 sec               1       0.48026          0.65371
## 3 2019-12-25 15:48:14  0.168 sec               2       0.47773          0.64854
## 4 2019-12-25 15:48:14  0.183 sec               3       0.47553          0.64407
## 5 2019-12-25 15:48:14  0.194 sec               4       0.47346          0.63987
##   training_auc training_pr_auc training_lift training_classification_error
## 1      0.50000         0.00000       1.00000                       0.63091
## 2      0.88092         0.62533       2.65828                       0.17666
## 3      0.90779         0.86243       2.60519                       0.14038
## 4      0.90978         0.85113       2.70940                       0.14669
## 5      0.91073         0.87802       2.70940                       0.14511
## 
## ---
##               timestamp   duration number_of_trees training_rmse
## 96  2019-12-25 15:48:15  0.817 sec              95       0.35867
## 97  2019-12-25 15:48:15  0.823 sec              96       0.35801
## 98  2019-12-25 15:48:15  0.829 sec              97       0.35741
## 99  2019-12-25 15:48:15  0.835 sec              98       0.35681
## 100 2019-12-25 15:48:15  0.840 sec              99       0.35616
## 101 2019-12-25 15:48:15  0.845 sec             100       0.35554
##     training_logloss training_auc training_pr_auc training_lift
## 96           0.42750      0.92811         0.89364       2.70940
## 97           0.42625      0.92816         0.89372       2.70940
## 98           0.42512      0.92819         0.89369       2.70940
## 99           0.42400      0.92821         0.89369       2.70940
## 100          0.42276      0.92825         0.89385       2.70940
## 101          0.42161      0.92863         0.89449       2.70940
##     training_classification_error
## 96                        0.13880
## 97                        0.13722
## 98                        0.13722
## 99                        0.13722
## 100                       0.13722
## 101                       0.13722
## 
## Variable Importances: (Extract with `h2o.varimp`) 
## =================================================
## 
## Variable Importances: 
##   variable relative_importance scaled_importance percentage
## 1     Name         1028.667114          1.000000   0.324098
## 2      Sex          848.877869          0.825221   0.267452
## 3     Fare          411.877594          0.400399   0.129769
## 4   Pclass          381.371033          0.370743   0.120157
## 5      Age          280.832764          0.273006   0.088481
## 6    SibSp          153.149597          0.148882   0.048252
## 7 Embarked           38.354553          0.037286   0.012084
## 8    Parch           30.808376          0.029950   0.009707
## Visualize variable importance in the model
h2o.varimp_plot(gbm)

## Compute the model's predictions and performance on the test set
gbmpre <- as.data.frame(h2o.predict(gbm, newdata = test_data))
head(gbmpre)
##   predict        p0        p1
## 1       0 0.7895975 0.2104025
## 2       1 0.2456006 0.7543994
## 3       1 0.2885600 0.7114400
## 4       1 0.4053528 0.5946472
## 5       0 0.8299843 0.1700157
## 6       0 0.6812376 0.3187624
acc <- accuracy(as.vector(test_data$Survived),gbmpre$predict)
auc <- h2o.auc(h2o.performance(gbm, newdata = test_data)) 
sprintf("GBM model acc: %f",acc)
## [1] "GBM model acc: 0.832685"
sprintf("GBM model AUC: %f",auc)
## [1] "GBM model AUC: 0.869749"
## Grid-search the hyperparameters to find a better model
ntrees_opt <- c(20,50,100,200,500) ## number of trees
maxdepth_opt <- c(2,4,6,8,10) ## maximum tree depth
balance_opt <- c(TRUE,FALSE) ## whether to balance the classes
hyper_par <- list(ntrees=ntrees_opt, max_depth=maxdepth_opt,
                  balance_classes= balance_opt)
## Run the hyperparameter search with the GBM model
grid <- h2o.grid("gbm", hyper_params = hyper_par,grid_id = "gbm_grid_mol.hex",
                 x = predictors, y = target, distribution="bernoulli",
                 training_frame =train_data,learn_rate=0.01)
## Inspect the grid-search results
sortedGrid <- h2o.getGrid("gbm_grid_mol.hex", sort_by=c("accuracy"),
                          decreasing = TRUE)    
sortedGrid@summary_table%>%head()
## Hyper-Parameter Search Summary: ordered by decreasing accuracy
##   balance_classes max_depth ntrees                 model_ids           accuracy
## 1            true        10    500 gbm_grid_mol.hex_model_49 0.9501246882793017
## 2            true         8    500 gbm_grid_mol.hex_model_47  0.947109471094711
## 3           false        10    500 gbm_grid_mol.hex_model_50 0.9416403785488959
## 4           false         8    500 gbm_grid_mol.hex_model_48 0.9369085173501577
## 5           false         6    500 gbm_grid_mol.hex_model_46 0.9242902208201893
## 6            true         6    500 gbm_grid_mol.hex_model_45 0.9104665825977302
## Apply each model from the search to the test set and check its accuracy
grid_models <- lapply(grid@model_ids, h2o.getModel)

acc <- vector()
modelid <- vector()
for (i in 1:length(grid_models)) {
  gbmpre <- as.data.frame(h2o.predict(grid_models[[i]], newdata = test_data))
  acc[i] <- accuracy(as.vector(test_data$Survived),gbmpre$predict)
  modelid[i] <- grid_models[[i]]@model_id
}
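
Each h2o.predict() call in a loop like this prints its own console progress bar; h2o provides a pair of helpers to silence and restore them. A sketch:

h2o.no_progress()   ## suppress progress bars during batch prediction
## ... run the prediction loop here ...
h2o.show_progress() ## restore progress bars afterwards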
data.frame(modelid = modelid,acc = acc) %>%
  inner_join(sortedGrid@summary_table,by = c("modelid"="model_ids"))%>%
  group_by(modelid)%>%
  arrange(desc(acc))%>%
  head()
## Warning: Column `modelid`/`model_ids` joining factor and character vector,
## coercing into character vector
## # A tibble: 6 x 6
## # Groups:   modelid [6]
##   modelid                  acc balance_classes max_depth ntrees accuracy        
##   <chr>                  <dbl> <chr>           <chr>     <chr>  <chr>           
## 1 gbm_grid_mol.hex_mode… 0.844 false           6         500    0.9242902208201…
## 2 gbm_grid_mol.hex_mode… 0.840 false           4         500    0.9006309148264…
## 3 gbm_grid_mol.hex_mode… 0.840 false           8         200    0.8974763406940…
## 4 gbm_grid_mol.hex_mode… 0.840 true            8         200    0.9028642590286…
## 5 gbm_grid_mol.hex_mode… 0.837 false           10        500    0.9416403785488…
## 6 gbm_grid_mol.hex_mode… 0.837 false           10        200    0.8974763406940…
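
To keep working with the model that generalizes best, it can be fetched back by its id. A sketch based on the acc and modelid vectors computed above (the names best_id and bestgbm are ours):

best_id <- modelid[which.max(acc)]  ## id of the model with the highest test accuracy
bestgbm <- h2o.getModel(best_id)    ## retrieve it from the cluster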

Regression with a GBM model

The building energy-consumption dataset is used again.

library(h2o)
library(Metrics)
library(dplyr)
library(readxl)
library(tidyr)
library(ggplot2)
## Start and initialize an H2O instance, limited to 2 cores
h2o.init(nthreads=2,max_mem_size='4G')  
##  Connection successful!
## 
## R is connected to the H2O cluster: 
##     H2O cluster uptime:         1 minutes 8 seconds 
##     H2O cluster timezone:       Asia/Shanghai 
##     H2O data parsing timezone:  UTC 
##     H2O cluster version:        3.26.0.2 
##     H2O cluster version age:    4 months and 28 days !!! 
##     H2O cluster name:           H2O_started_from_R_sunlanxin_aui415 
##     H2O cluster total nodes:    1 
##     H2O cluster total memory:   3.45 GB 
##     H2O cluster total cores:    8 
##     H2O cluster allowed cores:  2 
##     H2O cluster healthy:        TRUE 
##     H2O Connection ip:          localhost 
##     H2O Connection port:        54321 
##     H2O Connection proxy:       NA 
##     H2O Internal Security:      FALSE 
##     H2O API Extensions:         Amazon S3, XGBoost, Algos, AutoML, Core V3, Core V4 
##     R Version:                  R version 3.6.2 (2019-12-12)
## Warning in h2o.clusterInfo(): 
## Your H2O cluster version is too old (4 months and 28 days)!
## Please download and install the latest version from http://h2o.ai/download/
## Use GBM for a regression analysis of the ENB2012 data
ENB <- read_excel("data/chap9/ENB2012.xlsx")
head(ENB)
## # A tibble: 6 x 9
##      X1    X2    X3    X4    X5    X6    X7    X8    Y1
##   <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1  0.98  514.  294   110.     7     2     0     0  15.6
## 2  0.98  514.  294   110.     7     3     0     0  15.6
## 3  0.98  514.  294   110.     7     4     0     0  15.6
## 4  0.98  514.  294   110.     7     5     0     0  15.6
## 5  0.9   564.  318.  122.     7     2     0     0  20.8
## 6  0.9   564.  318.  122.     7     3     0     0  21.5
summary(ENB)
##        X1               X2              X3              X4       
##  Min.   :0.6200   Min.   :514.5   Min.   :245.0   Min.   :110.2  
##  1st Qu.:0.6825   1st Qu.:606.4   1st Qu.:294.0   1st Qu.:140.9  
##  Median :0.7500   Median :673.8   Median :318.5   Median :183.8  
##  Mean   :0.7642   Mean   :671.7   Mean   :318.5   Mean   :176.6  
##  3rd Qu.:0.8300   3rd Qu.:741.1   3rd Qu.:343.0   3rd Qu.:220.5  
##  Max.   :0.9800   Max.   :808.5   Max.   :416.5   Max.   :220.5  
##        X5             X6             X7               X8              Y1       
##  Min.   :3.50   Min.   :2.00   Min.   :0.0000   Min.   :0.000   Min.   : 6.01  
##  1st Qu.:3.50   1st Qu.:2.75   1st Qu.:0.1000   1st Qu.:1.750   1st Qu.:12.99  
##  Median :5.25   Median :3.50   Median :0.2500   Median :3.000   Median :18.95  
##  Mean   :5.25   Mean   :3.50   Mean   :0.2344   Mean   :2.812   Mean   :22.31  
##  3rd Qu.:7.00   3rd Qu.:4.25   3rd Qu.:0.4000   3rd Qu.:4.000   3rd Qu.:31.67  
##  Max.   :7.00   Max.   :5.00   Max.   :0.4000   Max.   :5.000   Max.   :43.10
str(ENB)
## Classes 'tbl_df', 'tbl' and 'data.frame':    768 obs. of  9 variables:
##  $ X1: num  0.98 0.98 0.98 0.98 0.9 0.9 0.9 0.9 0.86 0.86 ...
##  $ X2: num  514 514 514 514 564 ...
##  $ X3: num  294 294 294 294 318 ...
##  $ X4: num  110 110 110 110 122 ...
##  $ X5: num  7 7 7 7 7 7 7 7 7 7 ...
##  $ X6: num  2 3 4 5 2 3 4 5 2 3 ...
##  $ X7: num  0 0 0 0 0 0 0 0 0 0 ...
##  $ X8: num  0 0 0 0 0 0 0 0 0 0 ...
##  $ Y1: num  15.6 15.6 15.6 15.6 20.8 ...
## Split the data into training (70%) and test sets
set.seed(12)
index <- sample(nrow(ENB),round(nrow(ENB)*0.7))
trainEnb <- as.h2o(ENB[index,])
testENB <- as.h2o(ENB[-index,])
## GBM regression model
name1 <- colnames(trainEnb)
predictors <- name1[1:8]
target <- "Y1"
## Train a baseline GBM regression model on the training set
gbmreg <- h2o.gbm(x = predictors, y = target,
                  training_frame = trainEnb,
                  distribution="AUTO", ## regression model
                  ntrees = 100,seed = 1234)
## Evaluate the model's predictive performance on the test set
h2o.performance(gbmreg,testENB)
## H2ORegressionMetrics: gbm
## 
## MSE:  0.233291
## RMSE:  0.4830021
## MAE:  0.347336
## RMSLE:  0.02502702
## Mean Residual Deviance :  0.233291
## R^2 :  0.9977173
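
These metrics can be cross-checked from explicit predictions. A sketch using rmse() from the Metrics package (the name prereg is ours):

prereg <- as.data.frame(h2o.predict(gbmreg, testENB))  ## test-set predictions
rmse(as.vector(testENB$Y1), prereg$predict)            ## should match the RMSE above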
## Use a hyperparameter grid search to find a better model
ntrees_opt <- c(50,100,200,500) ## number of trees
maxdepth_opt <- c(2,4,6,8,10) ## maximum tree depth
hyper_par <- list(ntrees=ntrees_opt, max_depth=maxdepth_opt)
## Run the hyperparameter search with the GBM model
gbm_grid_reg <- h2o.grid(algorithm="gbm", x = predictors,
                         grid_id ="gbm_grid_reg",
                         y = target,distribution="AUTO",
                         training_frame = trainEnb,hyper_params = hyper_par)
## Inspect the grid-search results
sortedGrid <- h2o.getGrid("gbm_grid_reg", sort_by="mse", 
                          decreasing = FALSE)
sortedGrid@summary_table%>%head()
## Hyper-Parameter Search Summary: ordered by increasing mse
##   max_depth ntrees             model_ids                  mse
## 1        10    500 gbm_grid_reg_model_20 0.015489566841374501
## 2         8    500 gbm_grid_reg_model_19  0.02014229884380941
## 3         6    500 gbm_grid_reg_model_18  0.03779650631098121
## 4        10    200 gbm_grid_reg_model_15  0.05339006823934593
## 5         8    200 gbm_grid_reg_model_14 0.059855247801575646
## 6         4    500 gbm_grid_reg_model_17  0.07499614778959614
## The best configuration found is ntrees = 500, max_depth = 10
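
Instead of retraining from scratch, the top model can also be taken straight from the sorted grid. A sketch (the name bestreg is ours):

## The first id in the sorted grid belongs to the lowest-MSE model
bestreg <- h2o.getModel(sortedGrid@model_ids[[1]])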

## Retrain the model with the new parameters
gbmreg <- h2o.gbm(x = predictors, y = target,
                  training_frame = trainEnb,
                  distribution="AUTO", ## regression model
                  ntrees = 500,  ## number of trees in the model
                  max_depth = 10,seed = 1234)
## Check the prediction performance on the test set
h2o.performance(gbmreg,newdata = testENB)
## H2ORegressionMetrics: gbm
## 
## MSE:  0.1476192
## RMSE:  0.3842124
## MAE:  0.281197
## RMSLE:  0.01928503
## Mean Residual Deviance :  0.1476192
## R^2 :  0.9985556
## The test-set error is now MSE ≈ 0.148
## Visualize the model's predictions
gbmpre <- as.data.frame(h2o.predict(gbmreg,testENB))
testENBdf <- as.data.frame(testENB)
index <- order(testENBdf$Y1)          ## order test samples by the true response
X <- seq_along(index)                 ## x-axis positions after sorting
Y1 <- testENBdf$Y1[index]             ## true values in ascending order
gbmprebest2 <- gbmpre$predict[index]  ## predictions reordered to match

plotdata <- data.frame(X = X,Y1 = Y1,gbmprebest = gbmprebest2)
plotdata <- gather(plotdata,key="model",value="value",c(-X))

## Visualize the prediction error of the model
ggplot(plotdata,aes(x = X,y = value))+
  geom_line(aes(linetype = model,colour = model),size = 0.8)+
  theme(legend.position = c(0.1,0.8),
        plot.title = element_text(hjust = 0.5))+
  ggtitle("GBM regression model")
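
Once the analysis is finished, the local H2O cluster can be shut down to release its memory; note that this discards all H2O frames and models. A sketch, left commented out so an active session is not killed accidentally:

## h2o.shutdown(prompt = FALSE)  ## stop the local H2O instance without a prompt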