library(ggplot2)
theme_set(theme_bw(base_family = "STKaiti"))
library(readr)
library(VIM)
## Loading required package: colorspace
## Loading required package: grid
## Loading required package: data.table
## VIM is ready to use.
## Since version 4.0.0 the GUI is in its own package VIMGUI.
##
## Please use the package to use the new (and old) GUI.
## Suggestions and bug-reports can be submitted at: https://github.com/alexkowa/VIM/issues
##
## Attaching package: 'VIM'
## The following object is masked from 'package:datasets':
##
## sleep
library(caret)
## Loading required package: lattice
library(rpart)
library(rpart.plot)
library(Metrics)
##
## Attaching package: 'Metrics'
## The following objects are masked from 'package:caret':
##
## precision, recall
library(ROCR)
## Loading required package: gplots
##
## Attaching package: 'gplots'
## The following object is masked from 'package:stats':
##
## lowess
library(readxl)
library(stringr)
library(ggplot2)
## Analyze the Titanic data set with a decision tree algorithm
## Read the training and test sets
Ttrain <- read_csv("data/chap9/Titanic train.csv")
## Parsed with column specification:
## cols(
## PassengerId = col_double(),
## Survived = col_double(),
## Pclass = col_double(),
## Name = col_character(),
## Sex = col_character(),
## Age = col_double(),
## SibSp = col_double(),
## Parch = col_double(),
## Ticket = col_character(),
## Fare = col_double(),
## Cabin = col_character(),
## Embarked = col_character()
## )
Ttest <- read_csv("data/chap9/Titanic test.csv")
## Parsed with column specification:
## cols(
## PassengerId = col_double(),
## Pclass = col_double(),
## Name = col_character(),
## Sex = col_character(),
## Age = col_double(),
## SibSp = col_double(),
## Parch = col_double(),
## Ticket = col_character(),
## Fare = col_double(),
## Cabin = col_character(),
## Embarked = col_character()
## )
dim(Ttrain)
## [1] 891 12
colnames(Ttrain)
## [1] "PassengerId" "Survived" "Pclass" "Name" "Sex"
## [6] "Age" "SibSp" "Parch" "Ticket" "Fare"
## [11] "Cabin" "Embarked"
dim(Ttest)
## [1] 418 11
## Combine the training and test data
Alldata <- rbind.data.frame(Ttrain[,-2],Ttest)
summary(Ttrain)
## PassengerId Survived Pclass Name
## Min. : 1.0 Min. :0.0000 Min. :1.000 Length:891
## 1st Qu.:223.5 1st Qu.:0.0000 1st Qu.:2.000 Class :character
## Median :446.0 Median :0.0000 Median :3.000 Mode :character
## Mean :446.0 Mean :0.3838 Mean :2.309
## 3rd Qu.:668.5 3rd Qu.:1.0000 3rd Qu.:3.000
## Max. :891.0 Max. :1.0000 Max. :3.000
##
## Sex Age SibSp Parch
## Length:891 Min. : 0.42 Min. :0.000 Min. :0.0000
## Class :character 1st Qu.:20.12 1st Qu.:0.000 1st Qu.:0.0000
## Mode :character Median :28.00 Median :0.000 Median :0.0000
## Mean :29.70 Mean :0.523 Mean :0.3816
## 3rd Qu.:38.00 3rd Qu.:1.000 3rd Qu.:0.0000
## Max. :80.00 Max. :8.000 Max. :6.0000
## NA's :177
## Ticket Fare Cabin Embarked
## Length:891 Min. : 0.00 Length:891 Length:891
## Class :character 1st Qu.: 7.91 Class :character Class :character
## Mode :character Median : 14.45 Mode :character Mode :character
## Mean : 32.20
## 3rd Qu.: 31.00
## Max. :512.33
##
summary(Ttest)
## PassengerId Pclass Name Sex
## Min. : 892.0 Min. :1.000 Length:418 Length:418
## 1st Qu.: 996.2 1st Qu.:1.000 Class :character Class :character
## Median :1100.5 Median :3.000 Mode :character Mode :character
## Mean :1100.5 Mean :2.266
## 3rd Qu.:1204.8 3rd Qu.:3.000
## Max. :1309.0 Max. :3.000
##
## Age SibSp Parch Ticket
## Min. : 0.17 Min. :0.0000 Min. :0.0000 Length:418
## 1st Qu.:21.00 1st Qu.:0.0000 1st Qu.:0.0000 Class :character
## Median :27.00 Median :0.0000 Median :0.0000 Mode :character
## Mean :30.27 Mean :0.4474 Mean :0.3923
## 3rd Qu.:39.00 3rd Qu.:1.0000 3rd Qu.:0.0000
## Max. :76.00 Max. :8.0000 Max. :9.0000
## NA's :86
## Fare Cabin Embarked
## Min. : 0.000 Length:418 Length:418
## 1st Qu.: 7.896 Class :character Class :character
## Median : 14.454 Mode :character Mode :character
## Mean : 35.627
## 3rd Qu.: 31.500
## Max. :512.329
## NA's :1
summary(Alldata)
## PassengerId Pclass Name Sex
## Min. : 1 Min. :1.000 Length:1309 Length:1309
## 1st Qu.: 328 1st Qu.:2.000 Class :character Class :character
## Median : 655 Median :3.000 Mode :character Mode :character
## Mean : 655 Mean :2.295
## 3rd Qu.: 982 3rd Qu.:3.000
## Max. :1309 Max. :3.000
##
## Age SibSp Parch Ticket
## Min. : 0.17 Min. :0.0000 Min. :0.000 Length:1309
## 1st Qu.:21.00 1st Qu.:0.0000 1st Qu.:0.000 Class :character
## Median :28.00 Median :0.0000 Median :0.000 Mode :character
## Mean :29.88 Mean :0.4989 Mean :0.385
## 3rd Qu.:39.00 3rd Qu.:1.0000 3rd Qu.:0.000
## Max. :80.00 Max. :8.0000 Max. :9.000
## NA's :263
## Fare Cabin Embarked
## Min. : 0.000 Length:1309 Length:1309
## 1st Qu.: 7.896 Class :character Class :character
## Median : 14.454 Mode :character Mode :character
## Mean : 33.295
## 3rd Qu.: 31.275
## Max. :512.329
## NA's :1
Survived <- Ttrain$Survived
table(Survived)
## Survived
## 0 1
## 549 342
## Data exploration, visualization, and feature transformation
## Examine the missing values and handle them
aggr(Alldata)
colnames(Alldata)
## [1] "PassengerId" "Pclass" "Name" "Sex" "Age"
## [6] "SibSp" "Parch" "Ticket" "Fare" "Cabin"
## [11] "Embarked"
# The training and test sets must undergo the same operations
## Cabin has too many missing values, so drop it outright
Alldata$Cabin <- NULL
## Ticket and PassengerId are identifiers, so drop them as well
Alldata$PassengerId <- NULL
Alldata$Ticket <- NULL
## Impute missing Age values with the median
Alldata$Age[is.na(Alldata$Age)] <- median(Alldata$Age,na.rm = TRUE)
## Impute missing Fare values with the mean
Alldata$Fare[is.na(Alldata$Fare)] <- mean(Alldata$Fare,na.rm = TRUE)
## Impute missing Embarked values with the mode
Embarkedmod <- names(sort(table(Alldata$Embarked),decreasing = T)[1])
Alldata$Embarked[is.na(Alldata$Embarked)] <- Embarkedmod
## Create a new feature: extract the title from the Name variable
newname <- str_split(Alldata$Name," ")
newname <- sapply(newname, function(x) x[2])
sort(table(newname),decreasing = T)
## newname
## Mr. Miss. Mrs. Master. Dr. Rev.
## 736 256 191 59 8 8
## y Col. Planke, Billiard, Impe, Carlo,
## 8 4 4 3 3 2
## Gordon, Major. Messemaeker, Mlle. Ms. Brito,
## 2 2 2 2 2 1
## Capt. Cruyssen, der Don. Jonkheer. Khalil,
## 1 1 1 1 1 1
## Melkebeke, Mme. Mulder, Palmquist, Pelsmaeker, Shawah,
## 1 1 1 1 1 1
## Steen, the Velde, Walle,
## 1 1 1 1
## Keep the titles Mr., Miss., Mrs., and Master.; replace the rest with "other"
newnamepart <- c("Mr.","Miss.","Mrs.","Master.")
newname[!(newname %in% newnamepart)] <- "other"
Alldata$Name <- as.factor(newname)
Alldata$Sex <- as.factor(Alldata$Sex)
Alldata$Embarked <- as.factor(Alldata$Embarked)
summary(Alldata)
## Pclass Name Sex Age SibSp
## Min. :1.000 Master.: 59 female:466 Min. : 0.17 Min. :0.0000
## 1st Qu.:2.000 Miss. :256 male :843 1st Qu.:22.00 1st Qu.:0.0000
## Median :3.000 Mr. :736 Median :28.00 Median :0.0000
## Mean :2.295 Mrs. :191 Mean :29.50 Mean :0.4989
## 3rd Qu.:3.000 other : 67 3rd Qu.:35.00 3rd Qu.:1.0000
## Max. :3.000 Max. :80.00 Max. :8.0000
## Parch Fare Embarked
## Min. :0.000 Min. : 0.000 C:270
## 1st Qu.:0.000 1st Qu.: 7.896 Q:123
## Median :0.000 Median : 14.454 S:916
## Mean :0.385 Mean : 33.295
## 3rd Qu.:0.000 3rd Qu.: 31.275
## Max. :9.000 Max. :512.329
str(Alldata)
## Classes 'tbl_df', 'tbl' and 'data.frame': 1309 obs. of 8 variables:
## $ Pclass : num 3 1 3 1 3 3 1 3 3 2 ...
## $ Name : Factor w/ 5 levels "Master.","Miss.",..: 3 4 2 4 3 3 3 1 4 4 ...
## $ Sex : Factor w/ 2 levels "female","male": 2 1 1 1 2 2 2 2 1 1 ...
## $ Age : num 22 38 26 35 35 28 54 2 27 14 ...
## $ SibSp : num 1 1 0 1 0 0 0 3 0 1 ...
## $ Parch : num 0 0 0 0 0 0 0 1 2 0 ...
## $ Fare : num 7.25 71.28 7.92 53.1 8.05 ...
## $ Embarked: Factor w/ 3 levels "C","Q","S": 3 1 3 3 3 2 3 3 3 1 ...
## 'SibSp' counts siblings/spouses aboard and 'Parch' counts parents/children aboard; they could be summed into a new feature (left commented out below)
#Alldata$SP <- Alldata$SibSp + Alldata$Parch
#Alldata$SibSp <- NULL
#Alldata$Parch <- NULL
#Alldata$SP <- cut_width(Alldata$SP,3)
#Alldata$Age <- cut_number(Alldata$Age,6)
#Alldata$Fare <- cut_number(Alldata$Fare,5)
#Alldata$Pclass <- as.factor(Alldata$Pclass)
#table(Alldata$Pclass)
## Split the processed data back into training and test sets
Ttrainp <- Alldata[1:nrow(Ttrain),]
Ttrainp$Survived <- Survived
Ttestp <- Alldata[(nrow(Ttrain)+1):nrow(Alldata),]
str(Ttrainp)
## Classes 'tbl_df', 'tbl' and 'data.frame': 891 obs. of 9 variables:
## $ Pclass : num 3 1 3 1 3 3 1 3 3 2 ...
## $ Name : Factor w/ 5 levels "Master.","Miss.",..: 3 4 2 4 3 3 3 1 4 4 ...
## $ Sex : Factor w/ 2 levels "female","male": 2 1 1 1 2 2 2 2 1 1 ...
## $ Age : num 22 38 26 35 35 28 54 2 27 14 ...
## $ SibSp : num 1 1 0 1 0 0 0 3 0 1 ...
## $ Parch : num 0 0 0 0 0 0 0 1 2 0 ...
## $ Fare : num 7.25 71.28 7.92 53.1 8.05 ...
## $ Embarked: Factor w/ 3 levels "C","Q","S": 3 1 3 3 3 2 3 3 3 1 ...
## $ Survived: num 0 1 1 1 0 0 0 0 1 1 ...
## Save the processed training data to a file
write.csv(Ttrainp,"data/chap9/Titanic处理后数据.csv",row.names = F)
## Partition the training data into training and test sets, 80% for training
set.seed(123)
CDP <- createDataPartition(Ttrainp$Survived,p = 0.8)
train_data <- Ttrainp[CDP$Resample1,]
test_data <- Ttrainp[-CDP$Resample1,]
mod1 <- rpart(Survived~.,data = train_data,method="class",cp = 0.000001)
summary(mod1)
## Call:
## rpart(formula = Survived ~ ., data = train_data, method = "class",
## cp = 1e-06)
## n= 713
##
## CP nsplit rel error xerror xstd
## 1 0.450549451 0 1.0000000 1.0000000 0.04754450
## 2 0.051282051 1 0.5494505 0.5604396 0.04015432
## 3 0.014652015 3 0.4468864 0.4615385 0.03730757
## 4 0.010989011 4 0.4322344 0.4798535 0.03787823
## 5 0.007326007 6 0.4102564 0.5091575 0.03874859
## 6 0.003663004 12 0.3589744 0.4835165 0.03798985
## 7 0.000001000 15 0.3479853 0.4798535 0.03787823
##
## Variable importance
## Name Sex Fare Age Parch Pclass SibSp Embarked
## 26 21 15 11 9 9 7 2
##
## Node number 1: 713 observations, complexity param=0.4505495
## predicted class=0 expected loss=0.3828892 P(node) =1
## class counts: 440 273
## probabilities: 0.617 0.383
## left son=2 (444 obs) right son=3 (269 obs)
## Primary splits:
## Name splits as RRLRL, improve=103.27050, (0 missing)
## Sex splits as RL, improve=101.59730, (0 missing)
## Pclass < 2.5 to the right, improve= 31.70004, (0 missing)
## Fare < 50.9875 to the left, improve= 26.05713, (0 missing)
## Embarked splits as RLL, improve= 12.38503, (0 missing)
## Surrogate splits:
## Sex splits as RL, agree=0.938, adj=0.836, (0 split)
## Parch < 0.5 to the left, agree=0.735, adj=0.297, (0 split)
## Age < 15.5 to the right, agree=0.701, adj=0.208, (0 split)
## SibSp < 0.5 to the left, agree=0.673, adj=0.134, (0 split)
## Fare < 77.6229 to the left, agree=0.651, adj=0.074, (0 split)
##
## Node number 2: 444 observations, complexity param=0.007326007
## predicted class=0 expected loss=0.1734234 P(node) =0.6227209
## class counts: 367 77
## probabilities: 0.827 0.173
## left son=4 (339 obs) right son=5 (105 obs)
## Primary splits:
## Pclass < 1.5 to the right, improve=11.845680, (0 missing)
## Fare < 26.26875 to the left, improve=10.478310, (0 missing)
## Sex splits as RL, improve= 7.211408, (0 missing)
## Embarked splits as RLL, improve= 6.085100, (0 missing)
## Name splits as --L-R, improve= 2.929979, (0 missing)
## Surrogate splits:
## Fare < 26.26875 to the left, agree=0.928, adj=0.695, (0 split)
## Age < 44.5 to the left, agree=0.795, adj=0.133, (0 split)
## Embarked splits as RLL, agree=0.768, adj=0.019, (0 split)
##
## Node number 3: 269 observations, complexity param=0.05128205
## predicted class=1 expected loss=0.2713755 P(node) =0.3772791
## class counts: 73 196
## probabilities: 0.271 0.729
## left son=6 (133 obs) right son=7 (136 obs)
## Primary splits:
## Pclass < 2.5 to the right, improve=24.854190, (0 missing)
## SibSp < 2.5 to the right, improve=20.063110, (0 missing)
## Fare < 48.2 to the left, improve= 7.790478, (0 missing)
## Parch < 3.5 to the right, improve= 4.932072, (0 missing)
## Age < 12 to the left, improve= 3.221960, (0 missing)
## Surrogate splits:
## Fare < 25.73335 to the left, agree=0.747, adj=0.489, (0 split)
## Age < 28.5 to the left, agree=0.699, adj=0.391, (0 split)
## Name splits as LL-R-, agree=0.625, adj=0.241, (0 split)
## Embarked splits as RLR, agree=0.606, adj=0.203, (0 split)
## SibSp < 1.5 to the right, agree=0.602, adj=0.195, (0 split)
##
## Node number 4: 339 observations
## predicted class=0 expected loss=0.1091445 P(node) =0.4754558
## class counts: 302 37
## probabilities: 0.891 0.109
##
## Node number 5: 105 observations, complexity param=0.007326007
## predicted class=0 expected loss=0.3809524 P(node) =0.1472651
## class counts: 65 40
## probabilities: 0.619 0.381
## left son=10 (18 obs) right son=11 (87 obs)
## Primary splits:
## Age < 53 to the right, improve=3.1636560, (0 missing)
## Embarked splits as RLL, improve=2.3172550, (0 missing)
## Fare < 25.9271 to the left, improve=2.1768710, (0 missing)
## Name splits as --L-R, improve=0.9803579, (0 missing)
## Parch < 0.5 to the right, improve=0.4625068, (0 missing)
##
## Node number 6: 133 observations, complexity param=0.05128205
## predicted class=1 expected loss=0.4887218 P(node) =0.1865358
## class counts: 65 68
## probabilities: 0.489 0.511
## left son=12 (34 obs) right son=13 (99 obs)
## Primary splits:
## Fare < 24.80835 to the right, improve=16.349110, (0 missing)
## SibSp < 2.5 to the right, improve=11.585210, (0 missing)
## Embarked splits as RRL, improve= 5.596270, (0 missing)
## Age < 38.5 to the right, improve= 3.658035, (0 missing)
## Parch < 1.5 to the right, improve= 3.583283, (0 missing)
## Surrogate splits:
## SibSp < 2.5 to the right, agree=0.910, adj=0.647, (0 split)
## Parch < 1.5 to the right, agree=0.842, adj=0.382, (0 split)
## Name splits as LR-R-, agree=0.774, adj=0.118, (0 split)
## Sex splits as RL, agree=0.774, adj=0.118, (0 split)
## Age < 37.5 to the right, agree=0.767, adj=0.088, (0 split)
##
## Node number 7: 136 observations
## predicted class=1 expected loss=0.05882353 P(node) =0.1907433
## class counts: 8 128
## probabilities: 0.059 0.941
##
## Node number 10: 18 observations
## predicted class=0 expected loss=0.1111111 P(node) =0.02524544
## class counts: 16 2
## probabilities: 0.889 0.111
##
## Node number 11: 87 observations, complexity param=0.007326007
## predicted class=0 expected loss=0.4367816 P(node) =0.1220196
## class counts: 49 38
## probabilities: 0.563 0.437
## left son=22 (7 obs) right son=23 (80 obs)
## Primary splits:
## Fare < 25.9271 to the left, improve=2.9045980, (0 missing)
## Embarked splits as RLL, improve=1.6625330, (0 missing)
## Age < 27.5 to the right, improve=1.3891050, (0 missing)
## Name splits as --L-R, improve=0.9657088, (0 missing)
## Parch < 0.5 to the right, improve=0.2979310, (0 missing)
##
## Node number 12: 34 observations
## predicted class=0 expected loss=0.08823529 P(node) =0.04768583
## class counts: 31 3
## probabilities: 0.912 0.088
##
## Node number 13: 99 observations, complexity param=0.01465201
## predicted class=1 expected loss=0.3434343 P(node) =0.1388499
## class counts: 34 65
## probabilities: 0.343 0.657
## left son=26 (14 obs) right son=27 (85 obs)
## Primary splits:
## Age < 28.5 to the right, improve=2.923776, (0 missing)
## Embarked splits as LRL, improve=2.250789, (0 missing)
## Fare < 8.0396 to the right, improve=1.662338, (0 missing)
## Name splits as RL-L-, improve=1.068687, (0 missing)
## Sex splits as LR, improve=1.068687, (0 missing)
##
## Node number 22: 7 observations
## predicted class=0 expected loss=0 P(node) =0.009817672
## class counts: 7 0
## probabilities: 1.000 0.000
##
## Node number 23: 80 observations, complexity param=0.007326007
## predicted class=0 expected loss=0.475 P(node) =0.112202
## class counts: 42 38
## probabilities: 0.525 0.475
## left son=46 (68 obs) right son=47 (12 obs)
## Primary splits:
## Fare < 27.1354 to the right, improve=3.6254900, (0 missing)
## Name splits as --L-R, improve=0.9562771, (0 missing)
## Age < 27.5 to the right, improve=0.9000000, (0 missing)
## Embarked splits as RLL, improve=0.8309463, (0 missing)
## Parch < 0.5 to the right, improve=0.5666667, (0 missing)
##
## Node number 26: 14 observations
## predicted class=0 expected loss=0.3571429 P(node) =0.01963534
## class counts: 9 5
## probabilities: 0.643 0.357
##
## Node number 27: 85 observations, complexity param=0.01098901
## predicted class=1 expected loss=0.2941176 P(node) =0.1192146
## class counts: 25 60
## probabilities: 0.294 0.706
## left son=54 (52 obs) right son=55 (33 obs)
## Primary splits:
## Fare < 8.0396 to the right, improve=2.1938850, (0 missing)
## Embarked splits as LRL, improve=1.9129700, (0 missing)
## Age < 7 to the right, improve=0.7669547, (0 missing)
## Parch < 1.5 to the left, improve=0.6742346, (0 missing)
## Name splits as RL-L-, improve=0.6742346, (0 missing)
## Surrogate splits:
## Embarked splits as LRL, agree=0.753, adj=0.364, (0 split)
## SibSp < 0.5 to the right, agree=0.729, adj=0.303, (0 split)
## Parch < 0.5 to the right, agree=0.718, adj=0.273, (0 split)
## Name splits as LR-L-, agree=0.647, adj=0.091, (0 split)
##
## Node number 46: 68 observations, complexity param=0.007326007
## predicted class=0 expected loss=0.4117647 P(node) =0.09537167
## class counts: 40 28
## probabilities: 0.588 0.412
## left son=92 (35 obs) right son=93 (33 obs)
## Primary splits:
## Embarked splits as RLL, improve=2.2918260, (0 missing)
## Age < 27.5 to the right, improve=1.9027150, (0 missing)
## Name splits as --L-R, improve=1.8935570, (0 missing)
## Fare < 127.8166 to the right, improve=0.4745098, (0 missing)
## Parch < 0.5 to the right, improve=0.1792717, (0 missing)
## Surrogate splits:
## Age < 27.5 to the right, agree=0.662, adj=0.303, (0 split)
## Fare < 61.2771 to the left, agree=0.647, adj=0.273, (0 split)
## Name splits as --L-R, agree=0.603, adj=0.182, (0 split)
## Parch < 0.5 to the left, agree=0.574, adj=0.121, (0 split)
## Sex splits as RL, agree=0.559, adj=0.091, (0 split)
##
## Node number 47: 12 observations
## predicted class=1 expected loss=0.1666667 P(node) =0.01683029
## class counts: 2 10
## probabilities: 0.167 0.833
##
## Node number 54: 52 observations, complexity param=0.01098901
## predicted class=1 expected loss=0.3846154 P(node) =0.07293128
## class counts: 20 32
## probabilities: 0.385 0.615
## left son=108 (10 obs) right son=109 (42 obs)
## Primary splits:
## Fare < 10.825 to the left, improve=4.272527, (0 missing)
## Parch < 0.5 to the left, improve=2.690347, (0 missing)
## Age < 7 to the right, improve=2.239445, (0 missing)
## Name splits as RL-L-, improve=1.628305, (0 missing)
## Sex splits as LR, improve=1.628305, (0 missing)
##
## Node number 55: 33 observations
## predicted class=1 expected loss=0.1515152 P(node) =0.04628331
## class counts: 5 28
## probabilities: 0.152 0.848
##
## Node number 92: 35 observations
## predicted class=0 expected loss=0.2857143 P(node) =0.04908836
## class counts: 25 10
## probabilities: 0.714 0.286
##
## Node number 93: 33 observations, complexity param=0.007326007
## predicted class=1 expected loss=0.4545455 P(node) =0.04628331
## class counts: 15 18
## probabilities: 0.455 0.545
## left son=186 (11 obs) right son=187 (22 obs)
## Primary splits:
## Fare < 98.7521 to the right, improve=1.0909090, (0 missing)
## Age < 26.5 to the right, improve=1.0909090, (0 missing)
## Parch < 0.5 to the right, improve=0.6136364, (0 missing)
## Name splits as --L-R, improve=0.3636364, (0 missing)
## SibSp < 0.5 to the left, improve=0.2727273, (0 missing)
## Surrogate splits:
## Age < 22.5 to the left, agree=0.788, adj=0.364, (0 split)
## Parch < 0.5 to the right, agree=0.727, adj=0.182, (0 split)
##
## Node number 108: 10 observations
## predicted class=0 expected loss=0.2 P(node) =0.01402525
## class counts: 8 2
## probabilities: 0.800 0.200
##
## Node number 109: 42 observations, complexity param=0.003663004
## predicted class=1 expected loss=0.2857143 P(node) =0.05890603
## class counts: 12 30
## probabilities: 0.286 0.714
## left son=218 (29 obs) right son=219 (13 obs)
## Primary splits:
## Age < 7 to the right, improve=1.6415310, (0 missing)
## Fare < 15.3729 to the left, improve=1.5873020, (0 missing)
## Embarked splits as LRR, improve=0.9075630, (0 missing)
## Name splits as RL-L-, improve=0.6984127, (0 missing)
## Sex splits as LR, improve=0.6984127, (0 missing)
## Surrogate splits:
## SibSp < 1.5 to the left, agree=0.738, adj=0.154, (0 split)
## Name splits as RL-L-, agree=0.714, adj=0.077, (0 split)
## Sex splits as LR, agree=0.714, adj=0.077, (0 split)
## Fare < 11.1875 to the right, agree=0.714, adj=0.077, (0 split)
##
## Node number 186: 11 observations
## predicted class=0 expected loss=0.3636364 P(node) =0.01542777
## class counts: 7 4
## probabilities: 0.636 0.364
##
## Node number 187: 22 observations, complexity param=0.003663004
## predicted class=1 expected loss=0.3636364 P(node) =0.03085554
## class counts: 8 14
## probabilities: 0.364 0.636
## left son=374 (15 obs) right son=375 (7 obs)
## Primary splits:
## Age < 27.5 to the right, improve=2.7151520, (0 missing)
## Fare < 30.8479 to the left, improve=1.7175320, (0 missing)
## SibSp < 0.5 to the left, improve=1.0008660, (0 missing)
## Name splits as --L-R, improve=0.1246753, (0 missing)
## Surrogate splits:
## Sex splits as RL, agree=0.727, adj=0.143, (0 split)
## Fare < 44.5521 to the left, agree=0.727, adj=0.143, (0 split)
##
## Node number 218: 29 observations, complexity param=0.003663004
## predicted class=1 expected loss=0.3793103 P(node) =0.04067321
## class counts: 11 18
## probabilities: 0.379 0.621
## left son=436 (14 obs) right son=437 (15 obs)
## Primary splits:
## Fare < 15.3729 to the left, improve=1.99803000, (0 missing)
## Age < 21 to the left, improve=1.48675100, (0 missing)
## Embarked splits as LRR, improve=1.19363400, (0 missing)
## Name splits as RL-R-, improve=0.59634890, (0 missing)
## Parch < 0.5 to the left, improve=0.02660099, (0 missing)
## Surrogate splits:
## Embarked splits as LRR, agree=0.828, adj=0.643, (0 split)
## Age < 17.5 to the left, agree=0.586, adj=0.143, (0 split)
## SibSp < 0.5 to the left, agree=0.586, adj=0.143, (0 split)
## Parch < 0.5 to the right, agree=0.586, adj=0.143, (0 split)
##
## Node number 219: 13 observations
## predicted class=1 expected loss=0.07692308 P(node) =0.01823282
## class counts: 1 12
## probabilities: 0.077 0.923
##
## Node number 374: 15 observations
## predicted class=0 expected loss=0.4666667 P(node) =0.02103787
## class counts: 8 7
## probabilities: 0.533 0.467
##
## Node number 375: 7 observations
## predicted class=1 expected loss=0 P(node) =0.009817672
## class counts: 0 7
## probabilities: 0.000 1.000
##
## Node number 436: 14 observations
## predicted class=0 expected loss=0.4285714 P(node) =0.01963534
## class counts: 8 6
## probabilities: 0.571 0.429
##
## Node number 437: 15 observations
## predicted class=1 expected loss=0.2 P(node) =0.02103787
## class counts: 3 12
## probabilities: 0.200 0.800
## Inspect variable importance
mod1$variable.importance
## Name Sex Fare Age Parch Pclass SibSp Embarked
## 111.91631 89.02461 61.63420 46.05885 38.32346 36.69987 30.46093 9.64527
# cp is the complexity parameter reported for each split
mod1$cp
## CP nsplit rel error xerror xstd
## 1 0.450549451 0 1.0000000 1.0000000 0.04754450
## 2 0.051282051 1 0.5494505 0.5604396 0.04015432
## 3 0.014652015 3 0.4468864 0.4615385 0.03730757
## 4 0.010989011 4 0.4322344 0.4798535 0.03787823
## 5 0.007326007 6 0.4102564 0.5091575 0.03874859
## 6 0.003663004 12 0.3589744 0.4835165 0.03798985
## 7 0.000001000 15 0.3479853 0.4798535 0.03787823
## plot cross-validation results
plotcp(mod1)
par(family = "STKaiti")
rpart.plot(mod1, type = 2,extra="auto", under=TRUE,
           fallen.leaves = FALSE,cex=0.7, main="Decision Tree")
## Evaluate the model's predictions on the training and test sets
pre_train <- predict(mod1,train_data,type = "prob")
pre_train2<-as.factor(as.vector(ifelse(pre_train[,2]>0.5,1,0)))
pre_test <- predict(mod1,test_data)
pre_test2<-as.factor(as.vector(ifelse(pre_test[,2]>0.5,1,0)))
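A small aside: predict.rpart can also return hard class labels directly, so the 0.5-threshold step above can be folded into the call; a sketch that should be equivalent for this binary outcome:

## Equivalent hard-label prediction (majority class per leaf)
pre_test_cls <- predict(mod1, test_data, type = "class")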
sprintf("决策树模型在训练集精度为:%f",accuracy(train_data$Survived,pre_train2))
## [1] "决策树模型在训练集精度为:0.866760"
sprintf("决策树模型在测试集精度为:%f",accuracy(test_data$Survived,pre_test2))
## [1] "决策树模型在测试集精度为:0.820225"
## Compute the confusion matrix and model accuracy
cfm <- confusionMatrix(pre_test2,as.factor(test_data$Survived))
cfm$table
## Reference
## Prediction 0 1
## 0 97 20
## 1 12 49
cfm
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 97 20
## 1 12 49
##
## Accuracy : 0.8202
## 95% CI : (0.7558, 0.8737)
## No Information Rate : 0.6124
## P-Value [Acc > NIR] : 1.645e-09
##
## Kappa : 0.6131
##
## Mcnemar's Test P-Value : 0.2159
##
## Sensitivity : 0.8899
## Specificity : 0.7101
## Pos Pred Value : 0.8291
## Neg Pred Value : 0.8033
## Prevalence : 0.6124
## Detection Rate : 0.5449
## Detection Prevalence : 0.6573
## Balanced Accuracy : 0.8000
##
## 'Positive' Class : 0
##
bestcp <- mod1$cptable[which.min(mod1$cptable[,"xerror"]),"CP"]
bestcp
## [1] 0.01465201
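Here bestcp is the cp with the smallest cross-validated error. A common, slightly more conservative alternative is the one-standard-error rule: choose the simplest tree whose xerror lies within one standard error of the minimum. A sketch using the cptable above:

cptab <- mod1$cptable
## threshold = minimum xerror plus its standard error
thresh <- min(cptab[,"xerror"]) + cptab[which.min(cptab[,"xerror"]),"xstd"]
## the first (smallest) tree under the threshold
cp1se <- cptab[which(cptab[,"xerror"] <= thresh)[1],"CP"]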
# Step3: Prune the tree using the best cp.
mod1.pruned <- prune(mod1, cp = bestcp)
## plot cross-validation results
plotcp(mod1.pruned)
## Visualize the pruned decision tree
par(family = "STKaiti")
rpart.plot(mod1.pruned, type = 2,extra="auto", under=TRUE,
           fallen.leaves = FALSE,cex=0.7, main="Pruned Decision Tree")
## Evaluate the pruned model's predictions on the training and test sets
pre_train_p <- predict(mod1.pruned,train_data)
pre_train_p2<-as.factor(as.vector(ifelse(pre_train_p[,2]>0.5,1,0)))
pre_test_p <- predict(mod1.pruned,test_data)
pre_test_p2<-as.factor(as.vector(ifelse(pre_test_p[,2]>0.5,1,0)))
sprintf("剪枝后决策树模型在训练集精度为:%f",accuracy(train_data$Survived,pre_train_p2))
## [1] "剪枝后决策树模型在训练集精度为:0.828892"
sprintf("剪枝后决策树模型在测试集精度为:%f",accuracy(test_data$Survived,pre_test_p2))
## [1] "剪枝后决策树模型在测试集精度为:0.808989"
## Compute the confusion matrix and model accuracy
cfm <- confusionMatrix(pre_test_p2,as.factor(test_data$Survived))
cfm$table
## Reference
## Prediction 0 1
## 0 94 19
## 1 15 50
cfm
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 94 19
## 1 15 50
##
## Accuracy : 0.809
## 95% CI : (0.7434, 0.8639)
## No Information Rate : 0.6124
## P-Value [Acc > NIR] : 1.294e-08
##
## Kappa : 0.5933
##
## Mcnemar's Test P-Value : 0.6069
##
## Sensitivity : 0.8624
## Specificity : 0.7246
## Pos Pred Value : 0.8319
## Neg Pred Value : 0.7692
## Prevalence : 0.6124
## Detection Rate : 0.5281
## Detection Prevalence : 0.6348
## Balanced Accuracy : 0.7935
##
## 'Positive' Class : 0
##
## Plot ROC curves on the test set for the tree before and after pruning
## to compare the two models
## Compute ROC coordinates for the unpruned decision tree
pr <- prediction(pre_test[,2], test_data$Survived)
prf <- performance(pr, measure = "tpr", x.measure = "fpr")
prfdf <- data.frame(x = prf@x.values[[1]],
y = prf@y.values[[1]],
model = "rpart")
## Compute ROC coordinates for the pruned decision tree
pr <- prediction(pre_test_p[,2], test_data$Survived)
prf <- performance(pr, measure = "tpr", x.measure = "fpr")
prfdf2 <- data.frame(x = prf@x.values[[1]],
y = prf@y.values[[1]],
model = "rpart.prund")
## Combine the data
prfdf <- rbind.data.frame(prfdf,prfdf2)
## plot ROC
ggplot(prfdf,aes(x= x,y = y,colour = model))+
geom_line(aes(linetype = model),size = 1)+
theme(aspect.ratio=1)+
labs(x = "假正例率",y = "真正例率")
## Compute the AUC values
auc(test_data$Survived,as.vector(pre_test[,2]))
## [1] 0.8704295
auc(test_data$Survived,as.vector(pre_test_p[,2]))
## [1] 0.8336657
library(tidyr)
library(randomForest)
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
##
## margin
library(ggRandomForests)
## Loading required package: randomForestSRC
##
## randomForestSRC 2.9.2
##
## Type rfsrc.news() to see new features, changes, and bug fixes.
##
##
## Attaching package: 'ggRandomForests'
## The following object is masked from 'package:randomForestSRC':
##
## partial.rfsrc
library(caret)
train_data$Survived <- as.factor(train_data$Survived)
rfcla <- randomForest(Survived~.,data = train_data,ntree=200, proximity=TRUE)
summary(rfcla)
## Length Class Mode
## call 5 -none- call
## type 1 -none- character
## predicted 713 factor numeric
## err.rate 600 -none- numeric
## confusion 6 -none- numeric
## votes 1426 matrix numeric
## oob.times 713 -none- numeric
## classes 2 -none- character
## importance 8 -none- numeric
## importanceSD 0 -none- NULL
## localImportance 0 -none- NULL
## proximity 508369 -none- numeric
## ntree 1 -none- numeric
## mtry 1 -none- numeric
## forest 14 -none- list
## y 713 factor numeric
## test 0 -none- NULL
## inbag 0 -none- NULL
## terms 3 terms call
## Visualize random forest training: how the error changes as trees are added
trainerror <- as.data.frame(plot(rfcla,type = "l"))
colnames(trainerror) <- paste("error",colnames(trainerror),sep = "")
trainerror$ntree <- 1:nrow(trainerror)
trainerror <- gather(trainerror,key = "Type",value = "Error",1:3)
ggplot(trainerror,aes(x = ntree,y = Error))+
geom_line(aes(linetype = Type,colour = Type))+
#theme(legend.position = "bottom")+
ggtitle("随机森林分类模型")+
theme(plot.title = element_text(hjust = 0.5))
## The error gradually levels off
## Alternatively, visualize the error with the ggRandomForests package
plot(gg_error(rfcla))
## Plot the scaled coordinates of the proximity matrix from randomForest
MDSplot(rfcla,train_data$Survived,k = 2 , palette=c(1, 2),
pch=20+as.numeric(train_data$Survived))
## Visualize variable importance
importance(rfcla)
## MeanDecreaseGini
## Pclass 22.696178
## Name 55.122124
## Sex 52.632961
## Age 33.436435
## SibSp 13.921757
## Parch 9.361209
## Fare 42.790499
## Embarked 10.316719
varImpPlot(rfcla,pch = 20, main = "Importance of Variables")
## Check the model's accuracy on the test set
rfclapre<- predict(rfcla,test_data)
sprintf("随机森林模型测试集精度为:%f",accuracy(test_data$Survived,rfclapre))
## [1] "随机森林模型测试集精度为:0.837079"
Ttrainp$Survived <- as.factor(Ttrainp$Survived)
rfclanew <- randomForest(Survived~.,data = Ttrainp,ntree=200, proximity=TRUE)
summary(rfclanew)
## Length Class Mode
## call 5 -none- call
## type 1 -none- character
## predicted 891 factor numeric
## err.rate 600 -none- numeric
## confusion 6 -none- numeric
## votes 1782 matrix numeric
## oob.times 891 -none- numeric
## classes 2 -none- character
## importance 8 -none- numeric
## importanceSD 0 -none- NULL
## localImportance 0 -none- NULL
## proximity 793881 -none- numeric
## ntree 1 -none- numeric
## mtry 1 -none- numeric
## forest 14 -none- list
## y 891 factor numeric
## test 0 -none- NULL
## inbag 0 -none- NULL
## terms 3 terms call
rfclanew
##
## Call:
## randomForest(formula = Survived ~ ., data = Ttrainp, ntree = 200, proximity = TRUE)
## Type of random forest: classification
## Number of trees: 200
## No. of variables tried at each split: 2
##
## OOB estimate of error rate: 16.05%
## Confusion matrix:
## 0 1 class.error
## 0 502 47 0.0856102
## 1 96 246 0.2807018
## Predict on the held-out test set
Ttestpre <- predict(rfclanew,Ttestp)
table(Ttestpre)
## Ttestpre
## 0 1
## 277 141
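If these predictions were to be submitted to the Kaggle Titanic competition, they would be paired with the original PassengerId column; a minimal sketch (the output file name is only illustrative):

submission <- data.frame(PassengerId = Ttest$PassengerId, Survived = Ttestpre)
write.csv(submission, "data/chap9/Titanic_submission.csv", row.names = FALSE)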
## Use random forests for regression on the ENB2012 data
ENB <- read_excel("data/chap9/ENB2012.xlsx")
head(ENB)
## # A tibble: 6 x 9
## X1 X2 X3 X4 X5 X6 X7 X8 Y1
## <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 0.98 514. 294 110. 7 2 0 0 15.6
## 2 0.98 514. 294 110. 7 3 0 0 15.6
## 3 0.98 514. 294 110. 7 4 0 0 15.6
## 4 0.98 514. 294 110. 7 5 0 0 15.6
## 5 0.9 564. 318. 122. 7 2 0 0 20.8
## 6 0.9 564. 318. 122. 7 3 0 0 21.5
summary(ENB)
## X1 X2 X3 X4
## Min. :0.6200 Min. :514.5 Min. :245.0 Min. :110.2
## 1st Qu.:0.6825 1st Qu.:606.4 1st Qu.:294.0 1st Qu.:140.9
## Median :0.7500 Median :673.8 Median :318.5 Median :183.8
## Mean :0.7642 Mean :671.7 Mean :318.5 Mean :176.6
## 3rd Qu.:0.8300 3rd Qu.:741.1 3rd Qu.:343.0 3rd Qu.:220.5
## Max. :0.9800 Max. :808.5 Max. :416.5 Max. :220.5
## X5 X6 X7 X8 Y1
## Min. :3.50 Min. :2.00 Min. :0.0000 Min. :0.000 Min. : 6.01
## 1st Qu.:3.50 1st Qu.:2.75 1st Qu.:0.1000 1st Qu.:1.750 1st Qu.:12.99
## Median :5.25 Median :3.50 Median :0.2500 Median :3.000 Median :18.95
## Mean :5.25 Mean :3.50 Mean :0.2344 Mean :2.812 Mean :22.31
## 3rd Qu.:7.00 3rd Qu.:4.25 3rd Qu.:0.4000 3rd Qu.:4.000 3rd Qu.:31.67
## Max. :7.00 Max. :5.00 Max. :0.4000 Max. :5.000 Max. :43.10
str(ENB)
## Classes 'tbl_df', 'tbl' and 'data.frame': 768 obs. of 9 variables:
## $ X1: num 0.98 0.98 0.98 0.98 0.9 0.9 0.9 0.9 0.86 0.86 ...
## $ X2: num 514 514 514 514 564 ...
## $ X3: num 294 294 294 294 318 ...
## $ X4: num 110 110 110 110 122 ...
## $ X5: num 7 7 7 7 7 7 7 7 7 7 ...
## $ X6: num 2 3 4 5 2 3 4 5 2 3 ...
## $ X7: num 0 0 0 0 0 0 0 0 0 0 ...
## $ X8: num 0 0 0 0 0 0 0 0 0 0 ...
## $ Y1: num 15.6 15.6 15.6 15.6 20.8 ...
## Split the data into training (70%) and test sets
set.seed(12)
index <- sample(nrow(ENB),round(nrow(ENB)*0.7))
trainEnb <- ENB[index,]
testENB <- ENB[-index,]
## Fit a random forest regression model
rfreg <- randomForest(Y1~.,data = trainEnb,ntree=500)
summary(rfreg)
## Length Class Mode
## call 4 -none- call
## type 1 -none- character
## predicted 538 -none- numeric
## mse 500 -none- numeric
## rsq 500 -none- numeric
## oob.times 538 -none- numeric
## importance 8 -none- numeric
## importanceSD 0 -none- NULL
## localImportance 0 -none- NULL
## proximity 0 -none- NULL
## ntree 1 -none- numeric
## mtry 1 -none- numeric
## forest 11 -none- list
## coefs 0 -none- NULL
## y 538 -none- numeric
## test 0 -none- NULL
## inbag 0 -none- NULL
## terms 3 terms call
## Visualize how the OOB error changes as trees are added
par(family = "STKaiti")
plot(rfreg,type = "l",col = "red",main = "随机森林回归")
## Visualize the error with the ggRandomForests package
plot(gg_error(rfreg))+labs(title = "Random Forest Regression")
## Visualize variable importance
importance(rfreg)
## IncNodePurity
## X1 14525.01820
## X2 12354.57214
## X3 3224.53478
## X4 8846.38752
## X5 10807.09663
## X6 77.46897
## X7 2353.68826
## X8 945.38532
varImpPlot(rfreg,pch = 20, main = "Importance of Variables")
## Predict on the test set and compute the mean squared error
rfpre <- predict(rfreg,testENB)
sprintf("均方根误差为: %f",mse(testENB$Y1,rfpre))
## [1] "均方根误差为: 1.332674"
## Parameter search: find a suitable mtry to train a better model
## Tune randomForest for the optimal mtry parameter
set.seed(1234)
rftune <- tuneRF(x = trainEnb[,1:8],y = trainEnb$Y1,
stepFactor=1.5,ntreeTry = 500)
## mtry = 2 OOB error = 1.335695
## Searching left ...
## Searching right ...
## mtry = 3 OOB error = 0.6641563
## 0.5027634 0.05
## mtry = 4 OOB error = 0.417198
## 0.3718375 0.05
## mtry = 6 OOB error = 0.3555105
## 0.1478614 0.05
## mtry = 8 OOB error = 0.3629043
## -0.02079754 0.05
print(rftune)
## mtry OOBError
## 2 2 1.3356946
## 3 3 0.6641563
## 4 4 0.4171980
## 6 6 0.3555105
## 8 8 0.3629043
## The mtry with the smallest OOB error is 6
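Rather than hardcoding the winner, the best mtry can be extracted from tuneRF's return value, a matrix with columns mtry and OOBError; a short sketch:

bestmtry <- rftune[which.min(rftune[,"OOBError"]),"mtry"]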
## Fit the tuned random forest regression model
rfregbest <- randomForest(Y1~.,data = trainEnb,ntree=500,mtry = 6)
## Visualize the OOB error of both models as trees are added
rfregerr <- as.data.frame(plot(rfreg))
colnames(rfregerr) <- "rfregerr"
rfregbesterr <- as.data.frame(plot(rfregbest))
colnames(rfregbesterr) <- "rfregbesterr"
plotrfdata <- cbind.data.frame(rfregerr,rfregbesterr)
plotrfdata$ntree <- 1:nrow(plotrfdata)
plotrfdata <- gather(plotrfdata,key = "Type",value = "Error",1:2)
ggplot(plotrfdata,aes(x = ntree,y = Error))+
geom_line(aes(linetype = Type,colour = Type),size = 0.9)+
theme(legend.position = "top")+
ggtitle("随机森林回归模型")+
theme(plot.title = element_text(hjust = 0.5))
## Predict on the test set with the tuned model and compute the mean squared error
rfprebest <- predict(rfregbest,testENB)
sprintf("优化后均方根误差为: %f",mse(testENB$Y1,rfprebest))
## [1] "优化后均方根误差为: 0.421116"
## Prepare data for plotting
index <- order(testENB$Y1)
X <- sort(index)
Y1 <- testENB$Y1[index]
rfpre2 <- rfpre[index]
rfprebest2 <- rfprebest[index]
plotdata <- data.frame(X = X,Y1 = Y1,rfpre =rfpre2,rfprebest = rfprebest2)
plotdata <- gather(plotdata,key="model",value="value",c(-X))
## Visualize the models' prediction errors
ggplot(plotdata,aes(x = X,y = value))+
geom_line(aes(linetype = model,colour = model),size = 0.8)+
theme(legend.position = c(0.1,0.8),
plot.title = element_text(hjust = 0.5))+
ggtitle("随机森林回归模型")
## Random forest regression performs very well here
A gradient boosting model is an ensemble of regression or classification trees. Both variants are forward-learning ensemble methods that obtain predictions through successively refined estimates. Boosting is a flexible nonlinear regression procedure that helps improve the accuracy of tree models: by sequentially applying weak learners to incrementally reweighted versions of the data, it creates a series of decision trees that together form an ensemble of weak prediction models.
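To make the idea concrete, here is a toy sketch of boosting for squared-error regression on simulated data, using the rpart package already loaded above (an illustration of the principle, not h2o's implementation): starting from a constant prediction, each new shallow tree is fit to the current residuals and its prediction is added after shrinking by a learning rate.

set.seed(1)
x <- runif(200)
y <- sin(2*pi*x) + rnorm(200, sd = 0.2)
toy <- data.frame(x = x)
pred <- rep(mean(y), 200)                ## start from the constant model
for (i in 1:100) {
  toy$resid <- y - pred                  ## residuals = negative gradient of squared loss
  stump <- rpart(resid ~ x, data = toy, maxdepth = 2, cp = 0)
  pred <- pred + 0.1*predict(stump, toy) ## shrink by the learning rate and add
}
mean((y - pred)^2)                       ## training error falls as trees accumulate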
library(h2o)
##
## ----------------------------------------------------------------------
##
## Your next step is to start H2O:
## > h2o.init()
##
## For H2O package documentation, ask for help:
## > ??h2o
##
## After starting H2O, you can use the Web UI at http://localhost:54321
## For more information visit http://docs.h2o.ai
##
## ----------------------------------------------------------------------
##
## Attaching package: 'h2o'
## The following objects are masked from 'package:data.table':
##
## hour, month, week, year
## The following objects are masked from 'package:stats':
##
## cor, sd, var
## The following objects are masked from 'package:base':
##
## &&, %*%, %in%, ||, apply, as.factor, as.numeric, colnames,
## colnames<-, ifelse, is.character, is.factor, is.numeric, log,
## log10, log1p, log2, round, signif, trunc
library(Metrics)
library(dplyr)
##
## Attaching package: 'dplyr'
## The following object is masked from 'package:randomForest':
##
## combine
## The following objects are masked from 'package:data.table':
##
## between, first, last
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
####################################
##Note: starting the h2o instance below may require updating your Java version. The jdk-8u221-macosx_64 installer is provided in the Program\Other folder; install it and rerun.
###################################
## Start and initialize an h2o instance, configured to use 2 cores
h2o.init(nthreads=2,max_mem_size='4G')
##
## H2O is not running yet, starting it now...
##
## Note: In case of errors look at the following log files:
## /var/folders/dv/qd8nc5bx6xx3pxq_xx_j47240000gn/T//RtmphvRb3k/h2o_sunlanxin_started_from_r.out
## /var/folders/dv/qd8nc5bx6xx3pxq_xx_j47240000gn/T//RtmphvRb3k/h2o_sunlanxin_started_from_r.err
##
##
## Starting H2O JVM and connecting: . Connection successful!
##
## R is connected to the H2O cluster:
## H2O cluster uptime: 1 seconds 645 milliseconds
## H2O cluster timezone: Asia/Shanghai
## H2O data parsing timezone: UTC
## H2O cluster version: 3.26.0.2
## H2O cluster version age: 4 months and 28 days !!!
## H2O cluster name: H2O_started_from_R_sunlanxin_aui415
## H2O cluster total nodes: 1
## H2O cluster total memory: 3.56 GB
## H2O cluster total cores: 8
## H2O cluster allowed cores: 2
## H2O cluster healthy: TRUE
## H2O Connection ip: localhost
## H2O Connection port: 54321
## H2O Connection proxy: NA
## H2O Internal Security: FALSE
## H2O API Extensions: Amazon S3, XGBoost, Algos, AutoML, Core V3, Core V4
## R Version: R version 3.6.2 (2019-12-12)
## Warning in h2o.clusterInfo():
## Your H2O cluster version is too old (4 months and 28 days)!
## Please download and install the latest version from http://h2o.ai/download/
## Read the data
train<-h2o.uploadFile("data/chap9/Titanic处理后数据.csv",
destination_frame = "train.hex")
str(train)
## Class 'H2OFrame' <environment: 0x7fc58d67b318>
## - attr(*, "op")= chr "Parse"
## - attr(*, "id")= chr "train.hex"
## - attr(*, "eval")= logi FALSE
## - attr(*, "nrow")= int 891
## - attr(*, "ncol")= int 9
## - attr(*, "types")=List of 9
## ..$ : chr "int"
## ..$ : chr "enum"
## ..$ : chr "enum"
## ..$ : chr "real"
## ..$ : chr "int"
## ..$ : chr "int"
## ..$ : chr "real"
## ..$ : chr "enum"
## ..$ : chr "int"
## - attr(*, "data")='data.frame': 10 obs. of 9 variables:
## ..$ Pclass : num 3 1 3 1 3 3 1 3 3 2
## ..$ Name : Factor w/ 5 levels "Master.","Miss.",..: 3 4 2 4 3 3 3 1 4 4
## ..$ Sex : Factor w/ 2 levels "female","male": 2 1 1 1 2 2 2 2 1 1
## ..$ Age : num 22 38 26 35 35 28 54 2 27 14
## ..$ SibSp : num 1 1 0 1 0 0 0 3 0 1
## ..$ Parch : num 0 0 0 0 0 0 0 1 2 0
## ..$ Fare : num 7.25 71.28 7.92 53.1 8.05 ...
## ..$ Embarked: Factor w/ 3 levels "C","Q","S": 3 1 3 3 3 2 3 3 3 1
## ..$ Survived: num 0 1 1 1 0 0 0 0 1 1
train$Survived <- as.factor(train$Survived)
colnames(train)
## [1] "Pclass" "Name" "Sex" "Age" "SibSp" "Parch" "Fare"
## [8] "Embarked" "Survived"
head(train)
## Pclass Name Sex Age SibSp Parch Fare Embarked Survived
## 1 3 Mr. male 22 1 0 7.2500 S 0
## 2 1 Mrs. female 38 1 0 71.2833 C 1
## 3 3 Miss. female 26 0 0 7.9250 S 1
## 4 1 Mrs. female 35 1 0 53.1000 S 1
## 5 3 Mr. male 35 0 0 8.0500 S 0
## 6 3 Mr. male 28 0 0 8.4583 Q 0
## Split the data into training (70%) and test (30%) sets
splits <- h2o.splitFrame(data = train, ratios = 0.7,seed = 1234)
train_data <- splits[[1]]
test_data <- splits[[2]]
dim(train_data)
## [1] 634 9
head(train_data)
## Pclass Name Sex Age SibSp Parch Fare Embarked Survived
## 1 1 Mrs. female 38 1 0 71.2833 C 1
## 2 3 Miss. female 26 0 0 7.9250 S 1
## 3 3 Mr. male 35 0 0 8.0500 S 0
## 4 3 Mr. male 28 0 0 8.4583 Q 0
## 5 1 Mr. male 54 0 0 51.8625 S 0
## 6 3 Master. male 2 3 1 21.0750 S 0
dim(test_data)
## [1] 257 9
## GBM model
name1 <- colnames(train)
predictors <- name1[1:8]
target <- "Survived"
gbm <- h2o.gbm(x = predictors, y = target,
training_frame = train_data,
distribution="bernoulli", ## 二分类模型
ntrees = 100, ## 模型使用数的树量
learn_rate=0.01, ## 学习率
sample_rate = 0.8,## 每棵树使用80%的样本
col_sample_rate = 0.6,## 每次拆分使用80%的特征
seed = 1234)
summary(gbm)
## Model Details:
## ==============
##
## H2OBinomialModel: gbm
## Model Key: GBM_model_R_1577260091654_1
## Model Summary:
## number_of_trees number_of_internal_trees model_size_in_bytes min_depth
## 1 100 100 30728 5
## max_depth mean_depth min_leaves max_leaves mean_leaves
## 1 5 5.00000 14 26 19.79000
##
## H2OBinomialMetrics: gbm
## ** Reported on training data. **
##
## MSE: 0.1264054
## RMSE: 0.3555354
## LogLoss: 0.4216086
## Mean Per-Class Error: 0.1451068
## AUC: 0.9286271
## pr_auc: 0.8944934
## Gini: 0.8572543
## R^2: 0.4571642
##
## Confusion Matrix (vertical: actual; across: predicted) for F1-optimal threshold:
## 0 1 Error Rate
## 0 354 46 0.115000 =46/400
## 1 41 193 0.175214 =41/234
## Totals 395 239 0.137224 =87/634
##
## Maximum Metrics: Maximum metrics at their respective thresholds
## metric threshold value idx
## 1 max f1 0.401407 0.816068 182
## 2 max f2 0.200353 0.850669 313
## 3 max f0point5 0.516888 0.863029 122
## 4 max accuracy 0.477965 0.873817 144
## 5 max precision 0.768800 1.000000 0
## 6 max recall 0.190450 1.000000 345
## 7 max specificity 0.768800 1.000000 0
## 8 max absolute_mcc 0.477965 0.726296 144
## 9 max min_per_class_accuracy 0.354077 0.841880 200
## 10 max mean_per_class_accuracy 0.401407 0.854893 182
##
## Gains/Lift Table: Extract with `h2o.gainsLift(<model>, <data>)` or `h2o.gainsLift(<model>, valid=<T/F>, xval=<T/F>)`
##
##
##
## Scoring History:
## timestamp duration number_of_trees training_rmse training_logloss
## 1 2019-12-25 15:48:14 0.012 sec 0 0.48256 0.65847
## 2 2019-12-25 15:48:14 0.149 sec 1 0.48026 0.65371
## 3 2019-12-25 15:48:14 0.168 sec 2 0.47773 0.64854
## 4 2019-12-25 15:48:14 0.183 sec 3 0.47553 0.64407
## 5 2019-12-25 15:48:14 0.194 sec 4 0.47346 0.63987
## training_auc training_pr_auc training_lift training_classification_error
## 1 0.50000 0.00000 1.00000 0.63091
## 2 0.88092 0.62533 2.65828 0.17666
## 3 0.90779 0.86243 2.60519 0.14038
## 4 0.90978 0.85113 2.70940 0.14669
## 5 0.91073 0.87802 2.70940 0.14511
##
## ---
## timestamp duration number_of_trees training_rmse
## 96 2019-12-25 15:48:15 0.817 sec 95 0.35867
## 97 2019-12-25 15:48:15 0.823 sec 96 0.35801
## 98 2019-12-25 15:48:15 0.829 sec 97 0.35741
## 99 2019-12-25 15:48:15 0.835 sec 98 0.35681
## 100 2019-12-25 15:48:15 0.840 sec 99 0.35616
## 101 2019-12-25 15:48:15 0.845 sec 100 0.35554
## training_logloss training_auc training_pr_auc training_lift
## 96 0.42750 0.92811 0.89364 2.70940
## 97 0.42625 0.92816 0.89372 2.70940
## 98 0.42512 0.92819 0.89369 2.70940
## 99 0.42400 0.92821 0.89369 2.70940
## 100 0.42276 0.92825 0.89385 2.70940
## 101 0.42161 0.92863 0.89449 2.70940
## training_classification_error
## 96 0.13880
## 97 0.13722
## 98 0.13722
## 99 0.13722
## 100 0.13722
## 101 0.13722
##
## Variable Importances: (Extract with `h2o.varimp`)
## =================================================
##
## Variable Importances:
## variable relative_importance scaled_importance percentage
## 1 Name 1028.667114 1.000000 0.324098
## 2 Sex 848.877869 0.825221 0.267452
## 3 Fare 411.877594 0.400399 0.129769
## 4 Pclass 381.371033 0.370743 0.120157
## 5 Age 280.832764 0.273006 0.088481
## 6 SibSp 153.149597 0.148882 0.048252
## 7 Embarked 38.354553 0.037286 0.012084
## 8 Parch 30.808376 0.029950 0.009707
## Visualize variable importance in the model
h2o.varimp_plot(gbm)
## Compute the model's predictions and performance on the test set
gbmpre <- as.data.frame(h2o.predict(gbm, newdata = test_data))
head(gbmpre)
## predict p0 p1
## 1 0 0.7895975 0.2104025
## 2 1 0.2456006 0.7543994
## 3 1 0.2885600 0.7114400
## 4 1 0.4053528 0.5946472
## 5 0 0.8299843 0.1700157
## 6 0 0.6812376 0.3187624
acc <- accuracy(as.vector(test_data$Survived),gbmpre$predict)
auc <- h2o.auc(h2o.performance(gbm, newdata = test_data))
sprintf("GBM model acc: %f",acc)
## [1] "GBM model acc: 0.832685"
sprintf("GBM model AUC: %f",auc)
## [1] "GBM model AUC: 0.869749"
## Use a grid search over parameters to find a better model
ntrees_opt <- c(20,50,100,200,500) ## number of trees
maxdepth_opt <- c(2,4,6,8,10) ## maximum tree depth
balance_opt <- c(TRUE,FALSE) ## whether to balance the classes
hyper_par <- list(ntrees=ntrees_opt, max_depth=maxdepth_opt,
balance_classes= balance_opt)
## Run the hyperparameter search with GBM models
grid <- h2o.grid("gbm", hyper_params = hyper_par,grid_id = "gbm_grid_mol.hex",
x = predictors, y = target, distribution="bernoulli",
training_frame =train_data,learn_rate=0.01)
## Inspect the grid output
sortedGrid <- h2o.getGrid("gbm_grid_mol.hex", sort_by=c("accuracy"),
decreasing = TRUE)
sortedGrid@summary_table%>%head()
## Hyper-Parameter Search Summary: ordered by decreasing accuracy
## balance_classes max_depth ntrees model_ids accuracy
## 1 true 10 500 gbm_grid_mol.hex_model_49 0.9501246882793017
## 2 true 8 500 gbm_grid_mol.hex_model_47 0.947109471094711
## 3 false 10 500 gbm_grid_mol.hex_model_50 0.9416403785488959
## 4 false 8 500 gbm_grid_mol.hex_model_48 0.9369085173501577
## 5 false 6 500 gbm_grid_mol.hex_model_46 0.9242902208201893
## 6 true 6 500 gbm_grid_mol.hex_model_45 0.9104665825977302
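The top-ranked model can also be retrieved straight from the sorted grid rather than refit; a one-line sketch:

best_gbm <- h2o.getModel(sortedGrid@model_ids[[1]])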
## Apply every model from the search to the test set and check its accuracy
grid_models <- lapply(grid@model_ids,
function(model_id) {model = h2o.getModel(model_id) })
acc <- vector()
modelid <- vector()
for (i in 1:length(grid_models)) {
gbmpre <- as.data.frame(h2o.predict(grid_models[[i]], newdata = test_data))
acc[i] <- accuracy(as.vector(test_data$Survived),gbmpre$predict)
modelid[i] <- grid_models[[i]]@model_id
}
data.frame(modelid = modelid,acc = acc) %>%
inner_join(sortedGrid@summary_table,by = c("modelid"="model_ids"))%>%
group_by(modelid)%>%
arrange(desc(acc))%>%
head()
## Warning: Column `modelid`/`model_ids` joining factor and character vector,
## coercing into character vector
## # A tibble: 6 x 6
## # Groups: modelid [6]
## modelid acc balance_classes max_depth ntrees accuracy
## <chr> <dbl> <chr> <chr> <chr> <chr>
## 1 gbm_grid_mol.hex_mode… 0.844 false 6 500 0.9242902208201…
## 2 gbm_grid_mol.hex_mode… 0.840 false 4 500 0.9006309148264…
## 3 gbm_grid_mol.hex_mode… 0.840 false 8 200 0.8974763406940…
## 4 gbm_grid_mol.hex_mode… 0.840 true 8 200 0.9028642590286…
## 5 gbm_grid_mol.hex_mode… 0.837 false 10 500 0.9416403785488…
## 6 gbm_grid_mol.hex_mode… 0.837 false 10 200 0.8974763406940…
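Note the gap between the two accuracy columns: the deep 500-tree models that top the training-accuracy ranking do not generalize best, and the highest test-set accuracy (about 0.844) comes from a shallower max_depth = 6 model, which suggests the largest models are overfitting the training data.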
We again use the ENB2012 building energy-efficiency data set.
library(h2o)
library(Metrics)
library(dplyr)
library(readxl)
library(tidyr)
library(ggplot2)
## Start and initialize an h2o instance, configured to use 2 cores
h2o.init(nthreads=2,max_mem_size='4G')
## Connection successful!
##
## R is connected to the H2O cluster:
## H2O cluster uptime: 1 minutes 8 seconds
## H2O cluster timezone: Asia/Shanghai
## H2O data parsing timezone: UTC
## H2O cluster version: 3.26.0.2
## H2O cluster version age: 4 months and 28 days !!!
## H2O cluster name: H2O_started_from_R_sunlanxin_aui415
## H2O cluster total nodes: 1
## H2O cluster total memory: 3.45 GB
## H2O cluster total cores: 8
## H2O cluster allowed cores: 2
## H2O cluster healthy: TRUE
## H2O Connection ip: localhost
## H2O Connection port: 54321
## H2O Connection proxy: NA
## H2O Internal Security: FALSE
## H2O API Extensions: Amazon S3, XGBoost, Algos, AutoML, Core V3, Core V4
## R Version: R version 3.6.2 (2019-12-12)
## Warning in h2o.clusterInfo():
## Your H2O cluster version is too old (4 months and 28 days)!
## Please download and install the latest version from http://h2o.ai/download/
## Use GBM for regression analysis on the ENB2012 data
ENB <- read_excel("data/chap9/ENB2012.xlsx")
head(ENB)
## # A tibble: 6 x 9
## X1 X2 X3 X4 X5 X6 X7 X8 Y1
## <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 0.98 514. 294 110. 7 2 0 0 15.6
## 2 0.98 514. 294 110. 7 3 0 0 15.6
## 3 0.98 514. 294 110. 7 4 0 0 15.6
## 4 0.98 514. 294 110. 7 5 0 0 15.6
## 5 0.9 564. 318. 122. 7 2 0 0 20.8
## 6 0.9 564. 318. 122. 7 3 0 0 21.5
summary(ENB)
## X1 X2 X3 X4
## Min. :0.6200 Min. :514.5 Min. :245.0 Min. :110.2
## 1st Qu.:0.6825 1st Qu.:606.4 1st Qu.:294.0 1st Qu.:140.9
## Median :0.7500 Median :673.8 Median :318.5 Median :183.8
## Mean :0.7642 Mean :671.7 Mean :318.5 Mean :176.6
## 3rd Qu.:0.8300 3rd Qu.:741.1 3rd Qu.:343.0 3rd Qu.:220.5
## Max. :0.9800 Max. :808.5 Max. :416.5 Max. :220.5
## X5 X6 X7 X8 Y1
## Min. :3.50 Min. :2.00 Min. :0.0000 Min. :0.000 Min. : 6.01
## 1st Qu.:3.50 1st Qu.:2.75 1st Qu.:0.1000 1st Qu.:1.750 1st Qu.:12.99
## Median :5.25 Median :3.50 Median :0.2500 Median :3.000 Median :18.95
## Mean :5.25 Mean :3.50 Mean :0.2344 Mean :2.812 Mean :22.31
## 3rd Qu.:7.00 3rd Qu.:4.25 3rd Qu.:0.4000 3rd Qu.:4.000 3rd Qu.:31.67
## Max. :7.00 Max. :5.00 Max. :0.4000 Max. :5.000 Max. :43.10
str(ENB)
## Classes 'tbl_df', 'tbl' and 'data.frame': 768 obs. of 9 variables:
## $ X1: num 0.98 0.98 0.98 0.98 0.9 0.9 0.9 0.9 0.86 0.86 ...
## $ X2: num 514 514 514 514 564 ...
## $ X3: num 294 294 294 294 318 ...
## $ X4: num 110 110 110 110 122 ...
## $ X5: num 7 7 7 7 7 7 7 7 7 7 ...
## $ X6: num 2 3 4 5 2 3 4 5 2 3 ...
## $ X7: num 0 0 0 0 0 0 0 0 0 0 ...
## $ X8: num 0 0 0 0 0 0 0 0 0 0 ...
## $ Y1: num 15.6 15.6 15.6 15.6 20.8 ...
## Split the data into training (70%) and test sets
set.seed(12)
index <- sample(nrow(ENB),round(nrow(ENB)*0.7))
trainEnb <- as.h2o(ENB[index,])
testENB <- as.h2o(ENB[-index,])
## GBM regression model
name1 <- colnames(trainEnb)
predictors <- name1[1:8]
target <- "Y1"
## Train a baseline GBM regression model on the training set
gbmreg <- h2o.gbm(x = predictors, y = target,
training_frame = trainEnb,
distribution="AUTO", ## 回归模型
ntrees = 100,seed = 1234)
## Evaluate the model on the test set
h2o.performance(gbmreg,testENB)
## H2ORegressionMetrics: gbm
##
## MSE: 0.233291
## RMSE: 0.4830021
## MAE: 0.347336
## RMSLE: 0.02502702
## Mean Residual Deviance : 0.233291
## R^2 : 0.9977173
## Use a grid search over parameters to find a better model
ntrees_opt <- c(50,100,200,500) ## number of trees
maxdepth_opt <- c(2,4,6,8,10) ## maximum tree depth
hyper_par <- list(ntrees=ntrees_opt, max_depth=maxdepth_opt)
## Run the hyperparameter search with GBM models
gbm_grid_reg <- h2o.grid(algorithm="gbm", x = predictors,
grid_id ="gbm_grid_reg",
y = target,distribution="AUTO",
training_frame = trainEnb,hyper_params = hyper_par)
## Inspect the grid output
sortedGrid <- h2o.getGrid("gbm_grid_reg", sort_by="mse",
decreasing = FALSE)
sortedGrid@summary_table%>%head()
## Hyper-Parameter Search Summary: ordered by increasing mse
## max_depth ntrees model_ids mse
## 1 10 500 gbm_grid_reg_model_20 0.015489566841374501
## 2 8 500 gbm_grid_reg_model_19 0.02014229884380941
## 3 6 500 gbm_grid_reg_model_18 0.03779650631098121
## 4 10 200 gbm_grid_reg_model_15 0.05339006823934593
## 5 8 200 gbm_grid_reg_model_14 0.059855247801575646
## 6 4 500 gbm_grid_reg_model_17 0.07499614778959614
## The best parameters are ntrees = 500 and max_depth = 10
## Retrain the model with the new parameters
gbmreg <- h2o.gbm(x = predictors, y = target,
training_frame = trainEnb,
distribution="AUTO", ## 回归模型
ntrees = 500, ## 模型使用数的树量
max_depth = 10,seed = 1234)
## Evaluate predictive performance on the test set
h2o.performance(gbmreg,newdata = testENB)
## H2ORegressionMetrics: gbm
##
## MSE: 0.1476192
## RMSE: 0.3842124
## MAE: 0.281197
## RMSLE: 0.01928503
## Mean Residual Deviance : 0.1476192
## R^2 : 0.9985556
## The error on the test set is MSE = 0.148, an improvement over the baseline model
## Visualize the model's predictions
gbmpre <- as.data.frame(h2o.predict(gbmreg,testENB))
testENBdf <- as.data.frame(testENB)
index <- order(testENBdf$Y1)
X <- sort(index)
Y1 <- testENBdf$Y1[index]
gbmprebest2 <- gbmpre$predict[index]
plotdata <- data.frame(X = X,Y1 = Y1,gbmprebest = gbmprebest2)
plotdata <- gather(plotdata,key="model",value="value",c(-X))
## Visualize the model's prediction error
ggplot(plotdata,aes(x = X,y = value))+
geom_line(aes(linetype = model,colour = model),size = 0.8)+
theme(legend.position = c(0.1,0.8),
plot.title = element_text(hjust = 0.5))+
ggtitle("GBM回归模型")