Machine Learning in R with mlr3: Modeling, Training, Prediction and Evaluation (Random Forest, Logistic Regression)

This article uses the mlr3 package to train different machine learning models on the German credit dataset.

1. Set up the R environment

# Install packages from CRAN; most commonly used packages are available there
# install.packages("tidyverse")
# install.packages("bruceR")
# 
# # Install packages hosted on GitHub
# # devtools (or remotes) must be installed first
# 
# install.packages("devtools") 
# devtools::install_github("tidyverse/tidyverse")
# remotes::install_github("tidyverse/tidyverse")

# Load packages
library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.0 ──
## ✓ ggplot2 3.3.3     ✓ purrr   0.3.4
## ✓ tibble  3.1.1     ✓ dplyr   1.0.5
## ✓ tidyr   1.1.3     ✓ stringr 1.4.0
## ✓ readr   1.4.0     ✓ forcats 0.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()
library(data.table)
## 
## Attaching package: 'data.table'
## The following objects are masked from 'package:dplyr':
## 
##     between, first, last
## The following object is masked from 'package:purrr':
## 
##     transpose
library(mlr3)
library(mlr3learners)
library(mlr3viz)
library(ggplot2)

2. Data description

German credit data

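The German credit data is available from the rchallenge package. The goal is to use 20 explanatory variables to predict the response, credit risk (good/bad).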

2.1 Import the data

# install.packages("rchallenge")
data("german", package = "rchallenge") 

# Inspect the data
glimpse(german) # column types
## Rows: 1,000
## Columns: 21
## $ status                  <fct> no checking account, no checking account, ... …
## $ duration                <int> 18, 9, 12, 12, 12, 10, 8, 6, 18, 24, 11, 30, 6…
## $ credit_history          <fct> all credits at this bank paid back duly, all c…
## $ purpose                 <fct> car (used), others, retraining, others, others…
## $ amount                  <int> 1049, 2799, 841, 2122, 2171, 2241, 3398, 1361,…
## $ savings                 <fct> unknown/no savings account, unknown/no savings…
## $ employment_duration     <fct> < 1 yr, 1 <= ... < 4 yrs, 4 <= ... < 7 yrs, 1 …
## $ installment_rate        <ord> < 20, 25 <= ... < 35, 25 <= ... < 35, 20 <= ..…
## $ personal_status_sex     <fct> female : non-single or male : single, male : m…
## $ other_debtors           <fct> none, none, none, none, none, none, none, none…
## $ present_residence       <ord> >= 7 yrs, 1 <= ... < 4 yrs, >= 7 yrs, 1 <= ...…
## $ property                <fct> car or other, unknown / no property, unknown /…
## $ age                     <int> 21, 36, 23, 39, 38, 48, 39, 40, 65, 23, 36, 24…
## $ other_installment_plans <fct> none, none, none, none, bank, none, none, none…
## $ housing                 <fct> for free, for free, for free, for free, rent, …
## $ number_credits          <ord> 1, 2-3, 1, 2-3, 2-3, 2-3, 2-3, 1, 2-3, 1, 2-3,…
## $ job                     <fct> skilled employee/official, skilled employee/of…
## $ people_liable           <fct> 0 to 2, 3 or more, 0 to 2, 3 or more, 0 to 2, …
## $ telephone               <fct> no, no, no, no, no, no, no, no, no, no, no, no…
## $ foreign_worker          <fct> no, no, no, yes, yes, yes, yes, yes, no, no, n…
## $ credit_risk             <fct> good, good, good, good, good, good, good, good…
dim(german) # dimensions
## [1] 1000   21

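The output shows that the dataset has 1,000 observations and 21 attributes (columns). The response to predict is credit_risk (good or bad). Of the 20 explanatory variables, duration, age and amount are numeric and the rest are factors. Three of these factors are ordered: installment_rate, present_residence and number_credits.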

The skimr package gives a more detailed overview of each variable.

# install.packages("skimr")

skimr::skim(german)

Table: Data summary

Name                   german
Number of rows         1000
Number of columns      21
Column type frequency:
  factor               18
  numeric              3
Group variables        None

Variable type: factor

skim_variable            n_missing  complete_rate  ordered  n_unique  top_counts
status                   0          1              FALSE    4         …: 394, no : 274, …: 269, 0<=: 63
credit_history           0          1              FALSE    5         no : 530, all: 293, exi: 88, cri: 49
purpose                  0          1              FALSE    10        fur: 280, oth: 234, car: 181, car: 103
savings                  0          1              FALSE    5         unk: 603, …: 183, …: 103, 100: 63
employment_duration      0          1              FALSE    5         1 <: 339, >= : 253, 4 <: 174, < 1: 172
installment_rate         0          1              TRUE     4         < 2: 476, 25 : 231, 20 : 157, >= : 136
personal_status_sex      0          1              FALSE    4         mal: 548, fem: 310, fem: 92, mal: 50
other_debtors            0          1              FALSE    3         non: 907, gua: 52, co-: 41
present_residence        0          1              TRUE     4         >= : 413, 1 <: 308, 4 <: 149, < 1: 130
property                 0          1              FALSE    4         bui: 332, unk: 282, car: 232, rea: 154
other_installment_plans  0          1              FALSE    3         non: 814, ban: 139, sto: 47
housing                  0          1              FALSE    3         ren: 714, for: 179, own: 107
number_credits           0          1              TRUE     4         1: 633, 2-3: 333, 4-5: 28, >= : 6
job                      0          1              FALSE    4         ski: 630, uns: 200, man: 148, une: 22
people_liable            0          1              FALSE    2         0 t: 845, 3 o: 155
telephone                0          1              FALSE    2         no: 596, yes: 404
foreign_worker           0          1              FALSE    2         no: 963, yes: 37
credit_risk              0          1              FALSE    2         goo: 700, bad: 300

Variable type: numeric

skim_variable  n_missing  complete_rate  mean     sd       p0   p25     p50     p75      p100   hist
duration       0          1              20.90    12.06    4    12.0    18.0    24.00    72     ▇▇▂▁▁
amount         0          1              3271.25  2822.75  250  1365.5  2319.5  3972.25  18424  ▇▂▁▁▁
age            0          1              35.54    11.35    19   27.0    33.0    42.00    75     ▇▆▃▁▁

3. Modeling

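We use mlr3 to solve the credit risk classification problem. Building a machine learning workflow typically raises these questions:

  • What problem are we trying to solve?
  • Which learning algorithms are appropriate?
  • How do we evaluate "good" performance?

More systematically, mlr3 represents them through five components:

  1. Task definition (Task)
  2. Learner definition (Learner)
  3. Training
  4. Prediction
  5. Evaluation with one or more measures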

3.1 Task Definition

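First, we have to decide what to model. Most supervised machine learning problems are regression or classification problems. mlr3 distinguishes between them through tasks. A classification problem is defined as a classification task, TaskClassif. A regression problem is defined as a regression task, TaskRegr.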

In our example, the goal is to model, or predict, the binary factor variable credit_risk, so we define a TaskClassif:

# "germancredit" is the task id (any label works), german is the data set, target is the response variable
task = TaskClassif$new("germancredit", german, target = "credit_risk")

3.2 Learner Definition

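Once the modeling goal is fixed, we have to decide how to model it, that is, which learning algorithms (Learners) are appropriate. Prior knowledge narrows this down to one or more suitable learners. Examples include knowing that this is a classification task, or assuming that the classes are linearly separable.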

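Many learners are available in the mlr3learners package. Many more are provided by the mlr3extralearners package on GitHub. Together, these two sources cover a large share of the standard learning algorithms.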

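All available learners are collected in the dictionary mlr_learners. This covers every learner you have installed from mlr3, mlr3learners or mlr3extralearners, plus any you have written yourself: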

mlr_learners
## <DictionaryLearner> with 29 stored values
## Keys: classif.cv_glmnet, classif.debug, classif.featureless,
##   classif.glmnet, classif.kknn, classif.lda, classif.log_reg,
##   classif.multinom, classif.naive_bayes, classif.nnet, classif.qda,
##   classif.ranger, classif.rpart, classif.svm, classif.xgboost,
##   regr.cv_glmnet, regr.featureless, regr.glmnet, regr.kknn, regr.km,
##   regr.lm, regr.ranger, regr.rpart, regr.svm, regr.xgboost,
##   surv.cv_glmnet, surv.glmnet, surv.ranger, surv.xgboost

For our problem, suitable learners include logistic regression, CART and random forests, among others.

A learner is initialized with the lrn() function and the learner's key, e.g. lrn("classif.xxx"). Use ?mlr_learners_<key> to open a learner's help page, e.g. ?mlr_learners_classif.log_reg.

For example, logistic regression is initialized as follows. It uses R's glm() function and is provided by the mlr3learners package:

library("mlr3learners")
learner_logreg = lrn("classif.log_reg")
print(learner_logreg)
## <LearnerClassifLogReg:classif.log_reg>
## * Model: -
## * Parameters: list()
## * Packages: stats
## * Predict Type: response
## * Feature types: logical, integer, numeric, character, factor, ordered
## * Properties: twoclass, weights
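
The sketch below shows in base R what this learner does under the hood. The Call printed for the fitted model in 3.3 shows that classif.log_reg wraps stats::glm() with family = "binomial". The data frame toy and its columns x1, x2 and y are simulated and hypothetical. For a factor response, glm() treats the first level as failure and models the probability of the second level. With levels c("bad", "good"), the fitted probabilities are therefore P(good).

```r
# Minimal sketch: logistic regression with base R's glm(), the function that
# mlr3's classif.log_reg wraps. The data are simulated (hypothetical).
set.seed(42)
n   <- 200
toy <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
# Simulate a binary outcome whose log-odds are linear in x1 and x2
p_good <- plogis(0.5 + 1.2 * toy$x1 - 0.8 * toy$x2)
toy$y  <- factor(ifelse(runif(n) < p_good, "good", "bad"),
                 levels = c("bad", "good"))

# First level ("bad") = failure, second level ("good") = success,
# so the model describes the log-odds of "good"
fit <- glm(y ~ x1 + x2, family = binomial, data = toy)
print(coef(fit))

# type = "response" returns probabilities P(good) instead of log-odds
prob_good <- predict(fit, newdata = toy, type = "response")
print(head(prob_good))
```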

3.3 Training

Training is the process of fitting a model to the (training) data.

  • Logistic regression

We start with logistic regression. The same procedure carries over directly to any other learner.

An initialized learner is trained on data with $train():

learner_logreg$train(task)

In machine learning we usually do not train on all available data, but on a subset called the training data. A simple way to split the data:

# Tip: call set.seed() first if the split should be reproducible
train_set = sample(task$row_ids, 0.8 * task$nrow)
test_set = setdiff(task$row_ids, train_set)

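80% of the data is used for training, and the remaining 20% is held back for evaluation. train_set is an integer vector of the selected rows of the original dataset: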

head(train_set)
## [1] 410 864 543 236 958 851
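
Here is a base-R sketch of the same split on plain row ids 1 to 1000, which stand in for task$row_ids. The set.seed() call is an addition that makes the split reproducible. The last lines check that the two sets do not overlap.

```r
# Minimal base-R sketch of the 80/20 split, on plain row ids 1..1000
set.seed(1)                                          # reproducible split
row_ids   <- 1:1000                                  # stands in for task$row_ids
train_set <- sample(row_ids, 0.8 * length(row_ids))  # 800 ids, drawn without replacement
test_set  <- setdiff(row_ids, train_set)             # the remaining 200 ids

length(train_set)                     # 800
length(test_set)                      # 200
length(intersect(train_set, test_set))  # 0: no row is in both sets
```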

In mlr3, training on a subset is requested with the additional argument row_ids = train_set:

learner_logreg$train(task, row_ids = train_set)

The fitted model can be displayed with:

learner_logreg$model
## 
## Call:  stats::glm(formula = task$formula(), family = "binomial", data = task$data(), 
##     model = FALSE)
## 
## Coefficients:
##                                               (Intercept)  
##                                                -0.1819216  
##                                                       age  
##                                                 0.0056873  
##                                                    amount  
##                                                -0.0001196  
##    credit_historycritical account/other credits elsewhere  
##                                                -1.0951994  
## credit_historyno credits taken/all credits paid back duly  
##                                                 0.3816992  
##    credit_historyexisting credits paid back duly till now  
##                                                 0.9330591  
##     credit_historyall credits at this bank paid back duly  
##                                                 1.3556494  
##                                                  duration  
##                                                -0.0271785  
##                                 employment_duration< 1 yr  
##                                                -0.0150296  
##                       employment_duration1 <= ... < 4 yrs  
##                                                 0.2004790  
##                       employment_duration4 <= ... < 7 yrs  
##                                                 0.9713337  
##                               employment_duration>= 7 yrs  
##                                                 0.3789241  
##                                          foreign_workerno  
##                                                -1.2704600  
##                                               housingrent  
##                                                 0.6250064  
##                                                housingown  
##                                                 0.6444397  
##                                        installment_rate.L  
##                                                -0.5924806  
##                                        installment_rate.Q  
##                                                 0.0909648  
##                                        installment_rate.C  
##                                                 0.0636166  
##                                   jobunskilled - resident  
##                                                -0.8209089  
##                              jobskilled employee/official  
##                                                -0.7988798  
##             jobmanager/self-empl./highly qualif. employee  
##                                                -0.9088915  
##                                          number_credits.L  
##                                                -0.4671141  
##                                          number_credits.Q  
##                                                 0.0976312  
##                                          number_credits.C  
##                                                 0.0062673  
##                                 other_debtorsco-applicant  
##                                                -0.9178934  
##                                    other_debtorsguarantor  
##                                                 1.3397823  
##                             other_installment_plansstores  
##                                                 0.1427722  
##                               other_installment_plansnone  
##                                                 0.4974245  
##                                       people_liable0 to 2  
##                                                 0.2534176  
##   personal_status_sexfemale : non-single or male : single  
##                                                -0.0183188  
##                 personal_status_sexmale : married/widowed  
##                                                 0.6102816  
##                        personal_status_sexfemale : single  
##                                                 0.0759193  
##                                       present_residence.L  
##                                                -0.1602614  
##                                       present_residence.Q  
##                                                 0.4513743  
##                                       present_residence.C  
##                                                -0.3567466  
##                                      propertycar or other  
##                                                -0.2797497  
##         propertybuilding soc. savings agr./life insurance  
##                                                -0.1006801  
##                                       propertyreal estate  
##                                                -0.7330205  
##                                          purposecar (new)  
##                                                 1.6559118  
##                                         purposecar (used)  
##                                                 0.8993030  
##                                purposefurniture/equipment  
##                                                 0.8574892  
##                                   purposeradio/television  
##                                                -0.0496272  
##                                purposedomestic appliances  
##                                                -0.0426126  
##                                            purposerepairs  
##                                                 0.0285772  
##                                           purposevacation  
##                                                 0.7196447  
##                                         purposeretraining  
##                                                 0.7088115  
##                                           purposebusiness  
##                                                 2.3256145  
##                                      savings... <  100 DM  
##                                                 0.2495854  
##                               savings100 <= ... <  500 DM  
##                                                 0.5232586  
##                               savings500 <= ... < 1000 DM  
##                                                 1.3157498  
##                                     savings... >= 1000 DM  
##                                                 0.9884852  
##                                          status... < 0 DM  
##                                                 0.1314611  
##                                    status0<= ... < 200 DM  
##                                                 0.8973969  
##          status... >= 200 DM / salary for at least 1 year  
##                                                 1.6226985  
##                        telephoneyes (under customer name)  
##                                                 0.3142853  
## 
## Degrees of Freedom: 799 Total (i.e. Null);  745 Residual
## Null Deviance:	    982.4 
## Residual Deviance: 700.6 	AIC: 810.6

The class and the summary of the fitted logistic regression model:

class(learner_logreg$model)
## [1] "glm" "lm"
summary(learner_logreg$model)
## 
## Call:
## stats::glm(formula = task$formula(), family = "binomial", data = task$data(), 
##     model = FALSE)
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -2.7481  -0.6573   0.3599   0.6823   2.0764  
## 
## Coefficients:
##                                                             Estimate Std. Error
## (Intercept)                                               -1.819e-01  1.313e+00
## age                                                        5.687e-03  1.045e-02
## amount                                                    -1.196e-04  5.297e-05
## credit_historycritical account/other credits elsewhere    -1.095e+00  6.830e-01
## credit_historyno credits taken/all credits paid back duly  3.817e-01  4.971e-01
## credit_historyexisting credits paid back duly till now     9.331e-01  5.441e-01
## credit_historyall credits at this bank paid back duly      1.356e+00  4.897e-01
## duration                                                  -2.718e-02  1.083e-02
## employment_duration< 1 yr                                 -1.503e-02  4.935e-01
## employment_duration1 <= ... < 4 yrs                        2.005e-01  4.693e-01
## employment_duration4 <= ... < 7 yrs                        9.713e-01  5.181e-01
## employment_duration>= 7 yrs                                3.789e-01  4.733e-01
## foreign_workerno                                          -1.270e+00  7.304e-01
## housingrent                                                6.250e-01  2.761e-01
## housingown                                                 6.444e-01  5.408e-01
## installment_rate.L                                        -5.925e-01  2.489e-01
## installment_rate.Q                                         9.096e-02  2.255e-01
## installment_rate.C                                         6.362e-02  2.311e-01
## jobunskilled - resident                                   -8.209e-01  7.516e-01
## jobskilled employee/official                              -7.989e-01  7.274e-01
## jobmanager/self-empl./highly qualif. employee             -9.089e-01  7.380e-01
## number_credits.L                                          -4.671e-01  8.489e-01
## number_credits.Q                                           9.763e-02  6.951e-01
## number_credits.C                                           6.267e-03  5.218e-01
## other_debtorsco-applicant                                 -9.179e-01  4.757e-01
## other_debtorsguarantor                                     1.340e+00  4.751e-01
## other_installment_plansstores                              1.428e-01  5.116e-01
## other_installment_plansnone                                4.974e-01  2.944e-01
## people_liable0 to 2                                        2.534e-01  2.831e-01
## personal_status_sexfemale : non-single or male : single   -1.832e-02  4.396e-01
## personal_status_sexmale : married/widowed                  6.103e-01  4.300e-01
## personal_status_sexfemale : single                         7.592e-02  5.179e-01
## present_residence.L                                       -1.603e-01  2.457e-01
## present_residence.Q                                        4.514e-01  2.304e-01
## present_residence.C                                       -3.567e-01  2.293e-01
## propertycar or other                                      -2.797e-01  2.881e-01
## propertybuilding soc. savings agr./life insurance         -1.007e-01  2.790e-01
## propertyreal estate                                       -7.330e-01  4.750e-01
## purposecar (new)                                           1.656e+00  4.260e-01
## purposecar (used)                                          8.993e-01  3.057e-01
## purposefurniture/equipment                                 8.575e-01  2.807e-01
## purposeradio/television                                   -4.963e-02  9.327e-01
## purposedomestic appliances                                -4.261e-02  6.641e-01
## purposerepairs                                             2.858e-02  4.360e-01
## purposevacation                                            7.196e-01  1.287e+00
## purposeretraining                                          7.088e-01  3.815e-01
## purposebusiness                                            2.326e+00  9.776e-01
## savings... <  100 DM                                       2.496e-01  3.377e-01
## savings100 <= ... <  500 DM                                5.233e-01  4.443e-01
## savings500 <= ... < 1000 DM                                1.316e+00  5.692e-01
## savings... >= 1000 DM                                      9.885e-01  2.983e-01
## status... < 0 DM                                           1.315e-01  2.558e-01
## status0<= ... < 200 DM                                     8.974e-01  4.427e-01
## status... >= 200 DM / salary for at least 1 year           1.623e+00  2.681e-01
## telephoneyes (under customer name)                         3.143e-01  2.305e-01
##                                                           z value Pr(>|z|)    
## (Intercept)                                                -0.139 0.889817    
## age                                                         0.544 0.586361    
## amount                                                     -2.259 0.023910 *  
## credit_historycritical account/other credits elsewhere     -1.604 0.108806    
## credit_historyno credits taken/all credits paid back duly   0.768 0.442612    
## credit_historyexisting credits paid back duly till now      1.715 0.086353 .  
## credit_historyall credits at this bank paid back duly       2.768 0.005636 ** 
## duration                                                   -2.511 0.012052 *  
## employment_duration< 1 yr                                  -0.030 0.975704    
## employment_duration1 <= ... < 4 yrs                         0.427 0.669230    
## employment_duration4 <= ... < 7 yrs                         1.875 0.060842 .  
## employment_duration>= 7 yrs                                 0.801 0.423408    
## foreign_workerno                                           -1.739 0.081956 .  
## housingrent                                                 2.264 0.023571 *  
## housingown                                                  1.192 0.233383    
## installment_rate.L                                         -2.380 0.017307 *  
## installment_rate.Q                                          0.403 0.686685    
## installment_rate.C                                          0.275 0.783095    
## jobunskilled - resident                                    -1.092 0.274757    
## jobskilled employee/official                               -1.098 0.272063    
## jobmanager/self-empl./highly qualif. employee              -1.231 0.218137    
## number_credits.L                                           -0.550 0.582157    
## number_credits.Q                                            0.140 0.888294    
## number_credits.C                                            0.012 0.990417    
## other_debtorsco-applicant                                  -1.930 0.053659 .  
## other_debtorsguarantor                                      2.820 0.004806 ** 
## other_installment_plansstores                               0.279 0.780181    
## other_installment_plansnone                                 1.689 0.091136 .  
## people_liable0 to 2                                         0.895 0.370704    
## personal_status_sexfemale : non-single or male : single    -0.042 0.966764    
## personal_status_sexmale : married/widowed                   1.419 0.155847    
## personal_status_sexfemale : single                          0.147 0.883447    
## present_residence.L                                        -0.652 0.514230    
## present_residence.Q                                         1.959 0.050116 .  
## present_residence.C                                        -1.556 0.119724    
## propertycar or other                                       -0.971 0.331602    
## propertybuilding soc. savings agr./life insurance          -0.361 0.718219    
## propertyreal estate                                        -1.543 0.122757    
## purposecar (new)                                            3.887 0.000101 ***
## purposecar (used)                                           2.942 0.003265 ** 
## purposefurniture/equipment                                  3.055 0.002251 ** 
## purposeradio/television                                    -0.053 0.957566    
## purposedomestic appliances                                 -0.064 0.948839    
## purposerepairs                                              0.066 0.947743    
## purposevacation                                             0.559 0.576176    
## purposeretraining                                           1.858 0.063142 .  
## purposebusiness                                             2.379 0.017359 *  
## savings... <  100 DM                                        0.739 0.459905    
## savings100 <= ... <  500 DM                                 1.178 0.238878    
## savings500 <= ... < 1000 DM                                 2.312 0.020794 *  
## savings... >= 1000 DM                                       3.313 0.000922 ***
## status... < 0 DM                                            0.514 0.607376    
## status0<= ... < 200 DM                                      2.027 0.042667 *  
## status... >= 200 DM / salary for at least 1 year            6.052 1.43e-09 ***
## telephoneyes (under customer name)                          1.363 0.172751    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 982.41  on 799  degrees of freedom
## Residual deviance: 700.57  on 745  degrees of freedom
## AIC: 810.57
## 
## Number of Fisher Scoring iterations: 5
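
The coefficients above are on the log-odds scale. Exponentiating one gives an odds ratio for a one-unit change, holding the other features fixed. The sketch below uses two coefficients copied from the output and assumes that the model describes the odds of good. That is what stats::glm() does when the levels are ordered bad, good. Newer mlr3learners releases may reorder the class levels before fitting (see ?mlr_learners_classif.log_reg), which would flip the signs. The names in coefs are labels chosen for this sketch.

```r
# Log-odds coefficients copied from the glm output above
coefs <- c(
  duration      = -0.0271785,  # per additional month
  status_ge_200 =  1.6226985   # level "... >= 200 DM / salary for at least 1 year"
)
# Assumption: the model describes the odds of "good" (second factor level)
odds_ratio <- exp(coefs)
print(round(odds_ratio, 3))
```

Under that assumption, each additional month of duration multiplies the odds of good credit by about 0.973, i.e. about 2.7% lower odds. The status level "... >= 200 DM / salary for at least 1 year" has about 5 times the odds of the reference level no checking account, which is the level without its own coefficient.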

  • Random forest

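We can train a random forest in the same way as the logistic regression, here with the fast implementation from the ranger package. Again, we first define the learner and then train it.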

This time we also pass the importance parameter (importance = "permutation"). This overrides the default and makes the learner compute permutation feature importance:

learner_rf = lrn("classif.ranger", importance = "permutation")
learner_rf$train(task, row_ids = train_set)

The $importance() method returns the importance of each feature:

learner_rf$importance()
##                  status                duration                  amount 
##            0.0330947539            0.0175370797            0.0134572307 
##          credit_history                 savings                     age 
##            0.0129659380            0.0095783381            0.0065733821 
##                property     employment_duration                 purpose 
##            0.0053766886            0.0053485974            0.0047822849 
##           other_debtors        installment_rate     personal_status_sex 
##            0.0043989633            0.0036503334            0.0029137105 
##       present_residence          number_credits                 housing 
##            0.0022437675            0.0017202412            0.0013506399 
##               telephone           people_liable                     job 
##            0.0012456826            0.0007195306            0.0006561488 
## other_installment_plans          foreign_worker 
##            0.0003107618            0.0001042939

To plot the importance values, we convert them to a data.table and plot them with ggplot2:

importance = as.data.table(learner_rf$importance(), keep.rownames = TRUE)
# Rename the columns
colnames(importance) = c("Feature", "Importance")

# Bar chart of the feature importance with ggplot2

ggplot(data=importance,
       aes(x = reorder(Feature, Importance), y = Importance)) + 
  geom_col() + coord_flip() + xlab("")


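The first seven variables (status, duration, amount, credit_history, savings, age and property) contribute most to predicting the response.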
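
Permutation importance measures how much a model's error grows when one feature's values are shuffled, which breaks that feature's link to the target. ranger computes it on each tree's out-of-bag observations. The sketch below is a simplified base-R version of the same idea. It fits a single glm on simulated data and permutes columns of one held-out test set. The features x_strong, x_weak and x_noise are hypothetical.

```r
# Simplified permutation importance: one glm, one held-out test set (not ranger's OOB scheme)
set.seed(7)
n   <- 1000
sim <- data.frame(x_strong = rnorm(n), x_weak = rnorm(n), x_noise = rnorm(n))
sim$y <- factor(ifelse(runif(n) < plogis(2 * sim$x_strong + 0.5 * sim$x_weak), "good", "bad"),
                levels = c("bad", "good"))

idx   <- sample(n, 0.8 * n)
train <- sim[idx, ]
test  <- sim[-idx, ]
fit   <- glm(y ~ ., family = binomial, data = train)

# Classification error of the model on a data set (0.5 threshold on P(good))
class_error <- function(model, data) {
  pred <- ifelse(predict(model, newdata = data, type = "response") > 0.5, "good", "bad")
  mean(pred != data$y)
}

baseline <- class_error(fit, test)
features <- c("x_strong", "x_weak", "x_noise")
importance <- sapply(features, function(f) {
  permuted <- test
  permuted[[f]] <- sample(permuted[[f]])  # shuffle feature f only
  class_error(fit, permuted) - baseline   # increase in error = importance
})
print(sort(importance, decreasing = TRUE))
```

The strongly informative feature should show the largest increase in error. The pure-noise feature should show an increase close to zero, which can even be slightly negative by chance.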

3.4 Prediction

Next, we use the trained models to make predictions. Prediction is usually the main purpose of a machine learning model.

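In our case, the model can classify new credit applicants by credit risk (good vs. bad) based on their features. Machine learning models usually predict numeric values. For regression this is natural. For classification, most models predict scores or probabilities, and class predictions are derived from these values.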

  • Predict classes

First, we predict the classes directly:

pred_logreg = learner_logreg$predict(task, row_ids = test_set)
pred_rf = learner_rf$predict(task, row_ids = test_set)

pred_logreg
## <PredictionClassif> for 200 observations:
##     row_ids truth response
##           2  good      bad
##           3  good     good
##           6  good     good
## ---                       
##         986   bad     good
##         998   bad     good
##        1000   bad     good
pred_rf
## <PredictionClassif> for 200 observations:
##     row_ids truth response
##           2  good     good
##           3  good     good
##           6  good     good
## ---                       
##         986   bad     good
##         998   bad     good
##        1000   bad     good

The $predict() method returns a Prediction object. For further use, it can be converted to a data.table with as.data.table().

The predictions can also be shown as a confusion matrix:

pred_logreg$confusion
##         truth
## response bad good
##     bad   28   26
##     good  29  117
pred_rf$confusion
##         truth
## response bad good
##     bad   22   15
##     good  35  128
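
The classification error can be read directly off these confusion matrices: it is the share of off-diagonal cells. mlr3 uses this measure, classif.ce, as its default. The sketch below uses the counts copied from the output above:

```r
# Confusion matrices copied from the output above (rows = response, columns = truth);
# matrix() fills column by column
cm_logreg <- matrix(c(28, 29, 26, 117), nrow = 2,
                    dimnames = list(response = c("bad", "good"), truth = c("bad", "good")))
cm_rf     <- matrix(c(22, 35, 15, 128), nrow = 2,
                    dimnames = list(response = c("bad", "good"), truth = c("bad", "good")))

# Classification error = 1 - (correct predictions on the diagonal / all predictions)
ce <- function(cm) 1 - sum(diag(cm)) / sum(cm)
ce(cm_logreg)  # (26 + 29) / 200
ce(cm_rf)      # (15 + 35) / 200
```

On this single 80/20 split the random forest makes fewer errors (0.25 vs. 0.275). However, one split with 200 test rows is noisy. The resampling strategies in the evaluation section give more reliable comparisons.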

  • Predict probabilities

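Most learners can predict not only the class (the "response") but also their degree of "confidence" or "uncertainty" in it. This is usually done by setting the learner's $predict_type to "prob". For some learners, this has to be done before training. Alternatively, the learner can be created with this option directly: lrn("classif.log_reg", predict_type = "prob")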

learner_logreg$predict_type = "prob"
learner_logreg$predict(task, row_ids = test_set)
## <PredictionClassif> for 200 observations:
##     row_ids truth response  prob.bad prob.good
##           2  good      bad 0.5502737 0.4497263
##           3  good     good 0.2432334 0.7567666
##           6  good     good 0.1617924 0.8382076
## ---                                           
##         986   bad     good 0.1088596 0.8911404
##         998   bad     good 0.1524203 0.8475797
##        1000   bad     good 0.3172837 0.6827163
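
With two classes, choosing the class with the higher probability is the same as predicting bad whenever prob.bad exceeds 0.5. The sketch below uses the six printed prob.bad values and reproduces the response column above. The threshold of 0.3 is an arbitrary value chosen only for illustration.

```r
# prob.bad values copied from the six rows printed above
prob_bad <- c(0.5502737, 0.2432334, 0.1617924, 0.1088596, 0.1524203, 0.3172837)

# Default rule for two classes: predict "bad" when P(bad) > 0.5
response <- ifelse(prob_bad > 0.5, "bad", "good")
print(response)

# Illustrative stricter rule (threshold 0.3 is arbitrary): flags more applicants as "bad"
response_03 <- ifelse(prob_bad > 0.3, "bad", "good")
print(response_03)
```

Lowering the threshold produces fewer false negatives at the cost of more false positives. Here, it also catches the last row, whose truth is bad. mlr3's PredictionClassif provides a set_threshold() method for this; check the documentation of your version.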

3.5 Performance Evaluation

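To measure how a learner performs on new data, we usually simulate unseen data by splitting the data into a training set and a test set. The training set is used to train the learner. The test set is used only to predict with the trained learner and evaluate its performance. Many resampling methods, such as cross-validation and bootstrap, repeat this splitting in different ways.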

In mlr3, the resampling strategy is specified with the rsmp() function:

resampling = rsmp("holdout", ratio = 2/3)
print(resampling)
## <ResamplingHoldout> with 1 iterations
## * Instantiated: FALSE
## * Parameters: ratio=0.6667

Here we use "holdout", a simple train-test split with a single iteration. The resampling itself is run with the resample() function:

res = resample(task, learner = learner_logreg, resampling = resampling)
## INFO  [16:08:51.897] [mlr3]  Applying learner 'classif.log_reg' on task 'germancredit' (iter 1/1)
res
## <ResampleResult> of 1 iterations
## * Task: germancredit
## * Learner: classif.log_reg
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations

$aggregate() returns the score for the default measure:

res$aggregate()
## classif.ce 
##  0.2612613

Here the default measure is the classification error; lower is better.

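We can run other resampling strategies, such as repeated holdout ("subsampling") or cross-validation. Most methods repeat the train/predict cycle on different subsets of the data and aggregate the results, usually as the mean. Doing this manually would require writing a loop; mlr3 does this work for us: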

resampling = rsmp("subsampling", repeats=10)
rr = resample(task, learner = learner_logreg, resampling = resampling)
rr$aggregate()
## classif.ce 
##  0.2564565
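
For comparison, here is the manual loop that mlr3 spares us: repeated holdout written out in base R. The data are simulated (sim, x and y are hypothetical), and a glm stands in for the learner. The sketch uses the same 2/3 training share as the holdout above and 10 repetitions.

```r
# Minimal sketch: repeated holdout ("subsampling") written as an explicit loop
set.seed(3)
n   <- 600
sim <- data.frame(x = rnorm(n))
sim$y <- factor(ifelse(runif(n) < plogis(1.5 * sim$x), "good", "bad"),
                levels = c("bad", "good"))

repeats    <- 10
errors     <- numeric(repeats)
test_count <- integer(n)   # how often each row ends up in a test set
for (i in seq_len(repeats)) {
  train_idx <- sample(n, round(2/3 * n))   # new random 2/3 split in every repetition
  fit  <- glm(y ~ x, family = binomial, data = sim[train_idx, ])
  test <- sim[-train_idx, ]
  pred <- ifelse(predict(fit, newdata = test, type = "response") > 0.5, "good", "bad")
  errors[i] <- mean(pred != test$y)        # classification error of this repetition
  test_count[-train_idx] <- test_count[-train_idx] + 1L
}
mean(errors)        # aggregated score, analogous to rr$aggregate()
range(test_count)   # test sets overlap: some rows are tested often, some possibly never
```

Because every repetition draws a fresh split, the test sets overlap. A given row may be tested several times or not at all; cross-validation avoids this.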

Cross-validation can be used as well:

resampling = rsmp("cv", folds = 10)
rr = resample(task, learner = learner_logreg, resampling = resampling)
rr$aggregate()
## classif.ce 
##      0.246
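
Cross-validation instead partitions the rows into folds, so every row is used for testing exactly once. Below is a base-R sketch of 10-fold cross-validation. The data are simulated (sim, x and y are hypothetical), and a glm stands in for the learner.

```r
# Minimal sketch: 10-fold cross-validation written as an explicit loop
set.seed(5)
n   <- 500
sim <- data.frame(x = rnorm(n))
sim$y <- factor(ifelse(runif(n) < plogis(1.5 * sim$x), "good", "bad"),
                levels = c("bad", "good"))

folds   <- 10
fold_id <- sample(rep(seq_len(folds), length.out = n))  # each row gets exactly one fold
errors  <- numeric(folds)
for (k in seq_len(folds)) {
  test_idx <- which(fold_id == k)                       # fold k is the test set
  fit  <- glm(y ~ x, family = binomial, data = sim[-test_idx, ])
  test <- sim[test_idx, ]
  pred <- ifelse(predict(fit, newdata = test, type = "response") > 0.5, "good", "bad")
  errors[k] <- mean(pred != test$y)
}
table(fold_id)  # 50 rows per fold
mean(errors)    # cross-validated classification error
```

Averaging the per-fold errors should mirror what rr$aggregate() reports by default.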

mlr3 offers many more measures. Here we compute the false positive rate with mlr_measures_classif.fpr and the false negative rate with mlr_measures_classif.fnr. Several measures can be passed as a list, which msrs() constructs directly:

# false positive rate
rr$aggregate(msr("classif.fpr"))
## classif.fpr 
##   0.1345898
# false positive rate and false negative rate
measures = msrs(c("classif.fpr", "classif.fnr"))
rr$aggregate(measures)
## classif.fpr classif.fnr 
##   0.1345898   0.5068602
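
To see what these rates mean, we can compute them from the holdout confusion matrix of pred_logreg shown earlier. By default, mlr3 treats the first factor level of the target, here bad, as the positive class (task$positive shows it). A "positive" prediction therefore means predicting bad:

```r
# Counts from the logistic regression's confusion matrix on the 80/20 split;
# "bad" is the positive class
tp <- 28   # truth bad,  predicted bad
fn <- 29   # truth bad,  predicted good
fp <- 26   # truth good, predicted bad
tn <- 117  # truth good, predicted good

fpr <- fp / (fp + tn)   # share of good applicants wrongly flagged as bad
fnr <- fn / (fn + tp)   # share of bad applicants that slip through as good
round(c(fpr = fpr, fnr = fnr), 4)
```

On that single split the FNR is about 0.51, close to the cross-validated 0.507 above: roughly half of the bad credit risks are classified as good. The FPR differs more between the single split and cross-validation (0.18 vs. 0.13).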

There are many more resampling methods and quite a few measures (implemented in mlr3measures):

mlr_resamplings
## <DictionaryResampling> with 8 stored values
## Keys: bootstrap, custom, cv, holdout, insample, loo, repeated_cv,
##   subsampling
# available performance measures
mlr_measures
## <DictionaryMeasure> with 54 stored values
## Keys: classif.acc, classif.auc, classif.bacc, classif.bbrier,
##   classif.ce, classif.costs, classif.dor, classif.fbeta, classif.fdr,
##   classif.fn, classif.fnr, classif.fomr, classif.fp, classif.fpr,
##   classif.logloss, classif.mbrier, classif.mcc, classif.npv,
##   classif.ppv, classif.prauc, classif.precision, classif.recall,
##   classif.sensitivity, classif.specificity, classif.tn, classif.tnr,
##   classif.tp, classif.tpr, debug, oob_error, regr.bias, regr.ktau,
##   regr.mae, regr.mape, regr.maxae, regr.medae, regr.medse, regr.mse,
##   regr.msle, regr.pbias, regr.rae, regr.rmse, regr.rmsle, regr.rrse,
##   regr.rse, regr.rsq, regr.sae, regr.smape, regr.srho, regr.sse,
##   selected_features, time_both, time_predict, time_train

3.6 Performance Comparison and Benchmarks

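Learners could be compared by running resample() manually for each one. However, benchmark() automatically performs the resampling evaluation for multiple learners and tasks. benchmark_grid() creates a fully crossed design: multiple learners on multiple tasks, each evaluated with multiple resamplings.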

learners = lrns(c("classif.log_reg", "classif.ranger"), predict_type = "prob")

bm_design = benchmark_grid(
  tasks = task,
  learners = learners,
  resamplings = rsmp("cv", folds = 50)
)

bmr = benchmark(bm_design)
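
benchmark_grid() returns one row per combination of task, learner and resampling. The base-R expand.grid() sketch below mimics that layout with plain strings instead of mlr3 objects, to show how the design above expands:

```r
# Plain-string stand-in for the benchmark design: every task x learner x resampling
design <- expand.grid(
  task       = "germancredit",
  learner    = c("classif.log_reg", "classif.ranger"),
  resampling = "cv (50 folds)",
  stringsAsFactors = FALSE
)
print(design)
# Each row is resampled with 50 folds, so the benchmark fits 2 * 50 models in total
nrow(design) * 50
```

benchmark_grid() instantiates each resampling once per task, so both learners are evaluated on the same 50 folds.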

A benchmark can be compared on different measures. Here we look at the misclassification rate and the AUC:

measures = msrs(c("classif.ce", "classif.auc"))
performances = bmr$aggregate(measures)
performances[, c("learner_id", "classif.ce", "classif.auc")]
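
classif.auc needs predicted probabilities, which is why the learners were created with predict_type = "prob". The AUC is the probability that a randomly chosen positive case (here bad) receives a higher score than a randomly chosen negative case. The sketch below computes it in base R with the rank-sum (Mann-Whitney) formula on six hypothetical scores:

```r
# AUC via the rank-sum formula: share of (positive, negative) pairs ranked correctly
auc <- function(score, is_positive) {
  r     <- rank(score)            # average ranks handle ties
  n_pos <- sum(is_positive)
  n_neg <- sum(!is_positive)
  (sum(r[is_positive]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Hypothetical prob.bad scores for three bad (positive) and three good applicants
score       <- c(0.9, 0.8, 0.4, 0.7, 0.3, 0.2)
is_positive <- c(TRUE, TRUE, TRUE, FALSE, FALSE, FALSE)
auc(score, is_positive)  # 8 of the 9 bad/good pairs are ordered correctly: 8/9
```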



3.7 Hyperparameters: Deviating from the Defaults

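The techniques shown so far form the backbone of a machine learning workflow with mlr3. In practice, however, one would rarely proceed exactly as we did. Many R packages choose their default settings carefully, but no default works optimally in every situation. We can usually choose the values of such hyperparameters. A learner's (hyper)parameters are accessed and set via its ParamSet, $param_set: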

learner_rf$param_set
## <ParamSet>
##                               id    class lower upper nlevels        default
##  1:                        alpha ParamDbl  -Inf   Inf     Inf            0.5
##  2:       always.split.variables ParamUty    NA    NA     Inf <NoDefault[3]>
##  3:                class.weights ParamDbl  -Inf   Inf     Inf               
##  4:                      holdout ParamLgl    NA    NA       2          FALSE
##  5:                   importance ParamFct    NA    NA       4 <NoDefault[3]>
##  6:                   keep.inbag ParamLgl    NA    NA       2          FALSE
##  7:                    max.depth ParamInt  -Inf   Inf     Inf               
##  8:                min.node.size ParamInt     1   Inf     Inf              1
##  9:                     min.prop ParamDbl  -Inf   Inf     Inf            0.1
## 10:                      minprop ParamDbl  -Inf   Inf     Inf            0.1
## 11:                         mtry ParamInt     1   Inf     Inf <NoDefault[3]>
## 12:            num.random.splits ParamInt     1   Inf     Inf              1
## 13:                  num.threads ParamInt     1   Inf     Inf              1
## 14:                    num.trees ParamInt     1   Inf     Inf            500
## 15:                    oob.error ParamLgl    NA    NA       2           TRUE
## 16:        regularization.factor ParamUty    NA    NA     Inf              1
## 17:      regularization.usedepth ParamLgl    NA    NA       2          FALSE
## 18:                      replace ParamLgl    NA    NA       2           TRUE
## 19:    respect.unordered.factors ParamFct    NA    NA       3         ignore
## 20:              sample.fraction ParamDbl     0     1     Inf <NoDefault[3]>
## 21:                  save.memory ParamLgl    NA    NA       2          FALSE
## 22: scale.permutation.importance ParamLgl    NA    NA       2          FALSE
## 23:                    se.method ParamFct    NA    NA       2        infjack
## 24:                         seed ParamInt  -Inf   Inf     Inf               
## 25:         split.select.weights ParamDbl     0     1     Inf <NoDefault[3]>
## 26:                    splitrule ParamFct    NA    NA       2           gini
## 27:                      verbose ParamLgl    NA    NA       2           TRUE
## 28:                 write.forest ParamLgl    NA    NA       2           TRUE
##                               id    class lower upper nlevels        default
##        parents       value
##  1:                       
##  2:                       
##  3:                       
##  4:                       
##  5:            permutation
##  6:                       
##  7:                       
##  8:                       
##  9:                       
## 10:                       
## 11:                       
## 12:  splitrule            
## 13:                      1
## 14:                       
## 15:                       
## 16:                       
## 17:                       
## 18:                       
## 19:                       
## 20:                       
## 21:                       
## 22: importance            
## 23:                       
## 24:                       
## 25:                       
## 26:                       
## 27:                       
## 28:                       
##        parents       value
learner_rf$param_set$values = list(verbose = FALSE)

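There are two ways to choose parameters for our learner. If we have prior knowledge of how the learner should be (hyper)parameterized, the way to go is to enter the parameters manually into the parameter set. In most cases, however, we want to tune the learner so that it searches for "good" model configurations by itself. For now, we only want to compare a few models.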

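To find out which parameters can be set, we can consult the documentation of the underlying package (ranger) or inspect the learner's parameter set: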

as.data.table(learner_rf$param_set)[,.(id, class, lower, upper)]



For random forests, two meaningful parameters that control model complexity are num.trees and mtry. num.trees defaults to 500, and mtry defaults to floor(sqrt(ncol(data) - 1)), which is 4 in our example.
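The arithmetic behind the default mtry of 4 can be checked directly. The sketch assumes, as the text implies, 21 columns in the German credit data (20 features plus the target). A split can consider at most all 20 features, so the mtry = 11 of the "high" learner below uses a little over half of them.

```r
# Column count of the German credit data: 20 features + 1 target
n_col      <- 21
n_features <- n_col - 1

# ranger's default: mtry = floor(sqrt(number of features))
mtry_default <- floor(sqrt(n_features))
mtry_default  # sqrt(20) is about 4.47, so this is 4

# mtry values used later in this section, as a share of all features
c(low = 2, med = mtry_default, high = 11) / n_features
```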

Our goal below is to train three different learners:

  1. A random forest with default settings.
  2. A random forest with low num.trees and low mtry.
  3. A random forest with high num.trees and high mtry.

We will benchmark their performance on the German credit dataset. To do so, we construct the three learners and set their parameters accordingly:

rf_med = lrn("classif.ranger", id = "med", predict_type = "prob")

rf_low = lrn("classif.ranger", id = "low", predict_type = "prob",
  num.trees = 5, mtry = 2)

rf_high = lrn("classif.ranger", id = "high", predict_type = "prob",
  num.trees = 1000, mtry = 11)

Once the learners are defined, we can benchmark them:

learners = list(rf_low, rf_med, rf_high)
bm_design = benchmark_grid(
  tasks = task,
  learners = learners,
  resamplings = rsmp("cv", folds = 10)
)
bmr = benchmark(bm_design)
print(bmr)
## <BenchmarkResult> of 30 rows with 3 resampling runs
##  nr      task_id learner_id resampling_id iters warnings errors
##   1 germancredit        low            cv    10        0      0
##   2 germancredit        med            cv    10        0      0
##   3 germancredit       high            cv    10        0      0

We compare the misclassification rate and the AUC of the different learners:

measures = msrs(c("classif.ce", "classif.auc"))
performances = bmr$aggregate(measures)
performances[, .(learner_id, classif.ce, classif.auc)]



autoplot(bmr)

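The "low" setting appears to underfit somewhat, and the "high" setting shows a larger standard deviation across folds than the default "med" setting. Of the three parameter settings compared here, the model with default parameters is therefore the best choice.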
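The comparison above rests on two summaries of the per-fold scores. One is their mean, which bmr$aggregate() reports by default. The other is their spread across folds, which is visible in autoplot(bmr). The base-R sketch computes both per learner from a long table of per-fold scores. The numbers are made-up placeholders, not results from this benchmark. With mlr3, a comparable per-fold table would come from bmr$score(msr("classif.ce")), which returns one row per resampling iteration.

```r
# Made-up per-fold classification errors (3 folds per learner), illustration only
scores <- data.frame(
  learner_id = rep(c("low", "med", "high"), each = 3),
  classif.ce = c(0.30, 0.28, 0.32,   0.24, 0.25, 0.23,   0.20, 0.27, 0.25)
)

# Mean (the aggregated score) and standard deviation (spread across folds)
ce_mean <- tapply(scores$classif.ce, scores$learner_id, mean)
ce_sd   <- tapply(scores$classif.ce, scores$learner_id, sd)
round(cbind(mean = ce_mean, sd = ce_sd), 3)
```

In this toy table, "med" and "high" share the same mean error, but "high" varies more from fold to fold. This is the kind of difference a mean alone would hide.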

Session info

## R version 4.0.3 (2020-10-10)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Catalina 10.15.7
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] mlr3viz_0.5.3      mlr3learners_0.4.5 mlr3_0.11.0        data.table_1.14.0 
##  [5] forcats_0.5.1      stringr_1.4.0      dplyr_1.0.5        purrr_0.3.4       
##  [9] readr_1.4.0        tidyr_1.1.3        tibble_3.1.1       ggplot2_3.3.3     
## [13] tidyverse_1.3.0   
## 
## loaded via a namespace (and not attached):
##  [1] httr_1.4.2           sass_0.3.1           jsonlite_1.7.2      
##  [4] modelr_0.1.8         bslib_0.2.4          assertthat_0.2.1    
##  [7] lgr_0.4.2            highr_0.9            cellranger_1.1.0    
## [10] yaml_2.2.1           mlr3misc_0.8.0       globals_0.14.0      
## [13] pillar_1.6.0         backports_1.2.1      lattice_0.20-41     
## [16] glue_1.4.2           uuid_0.1-4           digest_0.6.27       
## [19] checkmate_2.0.0      rvest_1.0.0          colorspace_2.0-0    
## [22] htmltools_0.5.1.1    Matrix_1.2-18        pkgconfig_2.0.3     
## [25] mlr3measures_0.3.1   broom_0.7.6.9001     listenv_0.8.0       
## [28] haven_2.3.1          scales_1.1.1         ranger_0.12.1       
## [31] farver_2.0.3         generics_0.1.0       ellipsis_0.3.1      
## [34] withr_2.4.1          repr_1.1.3           skimr_2.1.3         
## [37] cli_2.4.0            magrittr_2.0.1       crayon_1.4.1        
## [40] readxl_1.3.1         paradox_0.7.1        evaluate_0.14       
## [43] future_1.21.0        fs_1.5.0             fansi_0.4.2         
## [46] parallelly_1.24.0    xml2_1.3.2           palmerpenguins_0.1.0
## [49] tools_4.0.3          hms_1.0.0            lifecycle_1.0.0     
## [52] munsell_0.5.0        reprex_2.0.0         compiler_4.0.3      
## [55] jquerylib_0.1.3      rlang_0.4.10         grid_4.0.3          
## [58] rstudioapi_0.13      base64enc_0.1-3      labeling_0.4.2      
## [61] rmarkdown_2.7        codetools_0.2-16     gtable_0.3.0        
## [64] DBI_1.1.1            R6_2.5.0             lubridate_1.7.9.2   
## [67] knitr_1.33           future.apply_1.7.0   utf8_1.2.1          
## [70] stringi_1.5.3        parallel_4.0.3       Rcpp_1.0.6          
## [73] vctrs_0.3.7          dbplyr_2.1.0         tidyselect_1.1.0    
## [76] xfun_0.22
