【Machine Learning in R - Next Generation • mlr3】

本篇主要介绍mlr3包的基本使用。

一个简单的机器学习流程在mlr3中可被分解为以下几个部分:

  • 创建任务
    比如回归、分裂、生存分析、降维、密度任务等等
  • 挑选学习器(算法/模型)
    比如随机森林、决策树、SVM、KNN等等
  • 训练和预测

创建任务

本次示例将使用mtcars数据集创建一个回归任务,结果变量(或者叫因变量等等)是mpg

# 首先加载数据
data("mtcars",package = "datasets")
data <- mtcars[,1:3]
str(data)
## 'data.frame':	32 obs. of  3 variables:
##  $ mpg : num  21 21 22.8 21.4 18.7 18.1 14.3 24.4 22.8 19.2 ...
##  $ cyl : num  6 6 4 6 8 6 8 4 4 6 ...
##  $ disp: num  160 160 108 258 360 ...

使用as_task_regr()创建回归任务,as_task_classif()可创建分类任务。

library(mlr3)task_mtcars <- as_task_regr(data,target = "mpg",id="cars") # id是随便起一个名字
print(task_mtcars)
## <TaskRegr:cars> (32 x 3)
## * Target: mpg
## * Properties: -
## * Features (2):
##   - dbl (2): cyl, disp

可以看到数据以供32行,3列,target是mpg,feature是cyl和disp,都是bdl类型。

在创建模型前先探索数据:

library("mlr3viz") # 使用此包可视化数据
autoplot(task_mtcars, type = "pairs") # 基于GGally,我之前介绍过
## Registered S3 method overwritten by 'GGally':
##   method from   
##   +.gg   ggplot2

plot of chunk unnamed-chunk-3

如果你觉得每次加载1个R包很烦,可以直接使用library(mlr3verse)加载所有基础包!

如果你想使用自带数据集进行学习,此包也自带了很多流行的机器学习数据集。

查看内置数据集:

as.data.table(mlr_tasks)
##                key task_type nrow ncol properties lgl int dbl chr fct ord pxc
##  1: boston_housing      regr  506   19              0   3  13   0   2   0   0
##  2:  breast_cancer   classif  683   10   twoclass   0   0   0   0   0   9   0
##  3:  german_credit   classif 1000   21   twoclass   0   3   0   0  14   3   0
##  4:           iris   classif  150    5 multiclass   0   0   4   0   0   0   0
##  5:         mtcars      regr   32   11              0   0  10   0   0   0   0
##  6:       penguins   classif  344    8 multiclass   0   3   2   0   2   0   0
##  7:           pima   classif  768    9   twoclass   0   0   8   0   0   0   0
##  8:          sonar   classif  208   61   twoclass   0   0  60   0   0   0   0
##  9:           spam   classif 4601   58   twoclass   0   0  57   0   0   0   0
## 10:           wine   classif  178   14 multiclass   0   2  11   0   0   0   0
## 11:            zoo   classif  101   17 multiclass  15   1   0   0   0   0   0

结果很详细,给出了任务类型,行列数,变量类型等。

如果想要使用内置数据集,可使用以下代码:

task_penguin <- tsk("penguins")
print(task_penguin)
## <TaskClassif:penguins> (344 x 8)
## * Target: species
## * Properties: multiclass
## * Features (7):
##   - int (3): body_mass, flipper_length, year
##   - dbl (2): bill_depth, bill_length
##   - fct (2): island, sex

可以非常方便的取子集查看:

library("mlr3verse")
as.data.table(mlr_tasks)[, 1:4]
##                key task_type  nrow ncol
##  1:           actg      surv  1151   13
##  2:   bike_sharing      regr 17379   14
##  3: boston_housing      regr   506   19
##  4:  breast_cancer   classif   683   10
##  5:       faithful      dens   272    1
##  6:           gbcs      surv   686   10
##  7:  german_credit   classif  1000   21
##  8:          grace      surv  1000    8
##  9:           ilpd   classif   583   11
## 10:           iris   classif   150    5
## 11:     kc_housing      regr 21613   20
## 12:           lung      surv   228   10
## 13:      moneyball      regr  1232   15
## 14:         mtcars      regr    32   11
## 15:      optdigits   classif  5620   65
## 16:       penguins   classif   344    8
## 17:           pima   classif   768    9
## 18:         precip      dens    70    1
## 19:           rats      surv   300    5
## 20:          sonar   classif   208   61
## 21:           spam   classif  4601   58
## 22:        titanic   classif  1309   11
## 23:   unemployment      surv  3343    6
## 24:      usarrests     clust    50    4
## 25:           whas      surv   481   11
## 26:           wine   classif   178   14
## 27:            zoo   classif   101   17
##                key task_type  nrow ncol

支持非常多探索数据的操作:

task_penguin$ncol
## [1] 8
task_penguin$nrow
## [1] 344
task_penguin$feature_names
## [1] "bill_depth"     "bill_length"    "body_mass"      "flipper_length"
## [5] "island"         "sex"            "year"
task_penguin$feature_types
##                id    type
## 1:     bill_depth numeric
## 2:    bill_length numeric
## 3:      body_mass integer
## 4: flipper_length integer
## 5:         island  factor
## 6:            sex  factor`在这里插入代码片`
## 7:           year integer
task_penguin$target_names
## [1] "species"
task_penguin$task_type
## [1] "classif"
task_penguin$data()
##        species bill_depth bill_length body_mass flipper_length    island    sex
##   1:    Adelie       18.7        39.1      3750            181 Torgersen   male
##   2:    Adelie       17.4        39.5      3800            186 Torgersen female
##   3:    Adelie       18.0        40.3      3250            195 Torgersen female
##   4:    Adelie         NA          NA        NA             NA Torgersen   <NA>
##   5:    Adelie       19.3        36.7      3450            193 Torgersen female
##  ---                                                                           
## 340: Chinstrap       19.8        55.8      4000            207     Dream   male
## 341: Chinstrap       18.1        43.5      3400            202     Dream female
## 342: Chinstrap       18.2        49.6      3775            193     Dream   male
## 343: Chinstrap       19.0        50.8      4100            210     Dream   male
## 344: Chinstrap       18.7        50.2      3775            198     Dream female
##      year
##   1: 2007
##   2: 2007
##   3: 2007
##   4: 2007
##   5: 2007
##  ---     
## 340: 2009
## 341: 2009
## 342: 2009
## 343: 2009
## 344: 2009
task_penguin$head(3)
##    species bill_depth bill_length body_mass flipper_length    island    sex
## 1:  Adelie       18.7        39.1      3750            181 Torgersen   male
## 2:  Adelie       17.4        39.5      3800            186 Torgersen female
## 3:  Adelie       18.0        40.3      3250            195 Torgersen female
##    year
## 1: 2007
## 2: 2007
## 3: 2007
# 还有很多行列选择操作、改变变量的id(比如某个变量不参与模型训练)等多种操作

可视化数据:很多都是基于GGally包,可以看我之前的介绍

autoplot(task_penguin)
  • 1

plot of chunk unnamed-chunk-8

autoplot(task_penguin, type = "pairs")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

plot of chunk unnamed-chunk-9

autoplot(task_penguin, type = "duo")
  • 1

plot of chunk unnamed-chunk-10

创建learner

所有的学习器都通过以下2个步骤工作:
image-20220227212306631

mlr3verse只支持常见的学习器,比如随机森林、决策树、SVM、KNN等,如果想要查看所有的学习器,可以安装mlr3extralearners

查看所有的支持的learner:
All learners

# 加载R包,常见的算法
library("mlr3verse")
mlr_learners
## <DictionaryLearner> with 53 stored values
## Keys: classif.cv_glmnet, classif.debug, classif.featureless,
##   classif.glmnet, classif.kknn, classif.lda, classif.log_reg,
##   classif.multinom, classif.naive_bayes, classif.nnet, classif.qda,
##   classif.ranger, classif.rpart, classif.svm, classif.xgboost,
##   clust.agnes, clust.ap, clust.cmeans, clust.cobweb, clust.dbscan,
##   clust.diana, clust.em, clust.fanny, clust.featureless, clust.ff,
##   clust.hclust, clust.kkmeans, clust.kmeans, clust.MBatchKMeans,
##   clust.meanshift, clust.pam, clust.SimpleKMeans, clust.xmeans,
##   dens.hist, dens.kde, regr.cv_glmnet, regr.debug, regr.featureless,
##   regr.glmnet, regr.kknn, regr.km, regr.lm, regr.ranger, regr.rpart,
##   regr.svm, regr.xgboost, surv.coxph, surv.cv_glmnet, surv.glmnet,
##   surv.kaplan, surv.ranger, surv.rpart, surv.xgboost

创建learner

# 决策树
learner = lrn("classif.rpart")
print(learner)
## <LearnerClassifRpart:classif.rpart>
## * Model: -
## * Parameters: xval=0
## * Packages: mlr3, rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
##   twoclass, weights

查看支持的超参数

learner$param_set
## <ParamSet>
##                 id    class lower upper nlevels        default value
##  1:             cp ParamDbl     0     1     Inf           0.01      
##  2:     keep_model ParamLgl    NA    NA       2          FALSE      
##  3:     maxcompete ParamInt     0   Inf     Inf              4      
##  4:       maxdepth ParamInt     1    30      30             30      
##  5:   maxsurrogate ParamInt     0   Inf     Inf              5      
##  6:      minbucket ParamInt     1   Inf     Inf <NoDefault[3]>      
##  7:       minsplit ParamInt     1   Inf     Inf             20      
##  8: surrogatestyle ParamInt     0     1       2              0      
##  9:   usesurrogate ParamInt     0     2       3              2      
## 10:           xval ParamInt     0   Inf     Inf             10     0

一目了然,方便使用,记不住了可以看看,毕竟太多了,这一点比tidymodels贴心。

设定超参数的值

learner$param_set$values = list(cp = 0.01, xval = 0)
learner
## <LearnerClassifRpart:classif.rpart>
## * Model: -
## * Parameters: cp=0.01, xval=0
## * Packages: mlr3, rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
##   twoclass, weights

也可以在指定learner时设定

learner = lrn("classif.rpart", xval=0, cp = 0.001)
learner$param_set$values
## $xval
## [1] 0
## 
## $cp
## [1] 0.001

训练、预测和性能评价

创建任务,选择模型

library("mlr3verse")task = tsk("penguins") # 使用内置数据集
learner = lrn("classif.rpart") #决策树分类

划分训练集和测试集

spilt <- partition(task,ratio = 0.6, stratify = T)
spilt$train
##   [1]   2   3   4   5   7   8  10  11  12  14  15  16  17  19  23  25  26  27
##  [19]  28  30  31  33  34  36  37  40  42  45  46  48  50  51  53  56  59  60
##  [37]  61  62  64  66  67  68  69  71  73  75  78  82  83  84  88  89  91  94
##  [55]  96  97  99 100 101 102 104 107 108 113 114 115 118 120 121 123 125 126
##  [73] 127 128 129 130 131 132 133 135 136 137 138 139 142 143 145 149 150 151
##  [91] 152 154 156 157 159 160 163 169 170 171 172 173 175 176 179 180 181 182
## [109] 183 186 187 188 189 193 194 197 199 200 201 203 206 208 210 211 212 213
## [127] 214 215 216 218 219 220 222 223 224 225 226 228 229 230 233 236 237 239
## [145] 240 241 242 243 247 248 249 252 253 254 255 256 257 259 260 262 266 271
## [163] 272 273 274 277 279 280 285 288 290 291 293 294 295 296 297 299 300 301
## [181] 302 304 305 306 309 310 312 313 317 319 321 322 323 324 325 328 330 331
## [199] 332 334 337 338 339 340 341 342

训练模型

learner$train(task, row_ids = spilt$train)
print(learner$model)
## n= 206 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
## 1) root 206 115 Adelie (0.44174757 0.19902913 0.35922330)  
##   2) flipper_length< 207.5 128  39 Adelie (0.69531250 0.30468750 0.00000000)  
##     4) bill_length< 42.35 86   0 Adelie (1.00000000 0.00000000 0.00000000) *
##     5) bill_length>=42.35 42   3 Chinstrap (0.07142857 0.92857143 0.00000000) *
##   3) flipper_length>=207.5 78   4 Gentoo (0.02564103 0.02564103 0.94871795) *

预测

prediction <- learner$predict(task, row_ids = spilt$test)
print(prediction)
## <PredictionClassif> for 138 observations:
##     row_ids     truth  response
##           1    Adelie    Adelie
##           6    Adelie    Adelie
##           9    Adelie    Adelie
## ---                            
##         336 Chinstrap Chinstrap
##         343 Chinstrap    Gentoo
##         344 Chinstrap Chinstrap

混淆矩阵

prediction$confusion
##            truth
## response    Adelie Chinstrap Gentoo
##   Adelie        53         1      0
##   Chinstrap      8        24      2
##   Gentoo         0         2     48

可视化

autoplot(prediction)
  • 1

plot of chunk unnamed-chunk-21

模型评价

先查看下支持的评价指标

mlr_measures
## <DictionaryMeasure> with 87 stored values
## Keys: aic, bic, classif.acc, classif.auc, classif.bacc, classif.bbrier,
##   classif.ce, classif.costs, classif.dor, classif.fbeta, classif.fdr,
##   classif.fn, classif.fnr, classif.fomr, classif.fp, classif.fpr,
##   classif.logloss, classif.mbrier, classif.mcc, classif.npv,
##   classif.ppv, classif.prauc, classif.precision, classif.recall,
##   classif.sensitivity, classif.specificity, classif.tn, classif.tnr,
##   classif.tp, classif.tpr, clust.ch, clust.db, clust.dunn,
##   clust.silhouette, clust.wss, debug, dens.logloss, oob_error,
##   regr.bias, regr.ktau, regr.mae, regr.mape, regr.maxae, regr.medae,
##   regr.medse, regr.mse, regr.msle, regr.pbias, regr.rae, regr.rmse,
##   regr.rmsle, regr.rrse, regr.rse, regr.rsq, regr.sae, regr.smape,
##   regr.srho, regr.sse, selected_features, sim.jaccard, sim.phi,
##   surv.brier, surv.calib_alpha, surv.calib_beta, surv.chambless_auc,
##   surv.cindex, surv.dcalib, surv.graf, surv.hung_auc, surv.intlogloss,
##   surv.logloss, surv.mae, surv.mse, surv.nagelk_r2, surv.oquigley_r2,
##   surv.rmse, surv.schmid, surv.song_auc, surv.song_tnr, surv.song_tpr,
##   surv.uno_auc, surv.uno_tnr, surv.uno_tpr, surv.xu_r2, time_both,
##   time_predict, time_train

这里我们选择accuracy

measure <- msr("classif.acc")
prediction$score(measure)
## classif.acc 
##   0.9057971

选择多个指标:

measures <- msrs(c("classif.acc","classif.auc","classif.ce"))
prediction$score(measures)
## classif.acc classif.auc  classif.ce 
##   0.9057971         NaN   0.0942029

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/146943.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言每日一题(32)环形链表

力扣网 141.环形链表 题目描述 给你一个链表的头节点 head &#xff0c;判断链表中是否有环。 如果链表中有某个节点&#xff0c;可以通过连续跟踪 next 指针再次到达&#xff0c;则链表中存在环。 为了表示给定链表中的环&#xff0c;评测系统内部使用整数 pos 来表示链表尾…

LLM大模型4位量化实战【GPTQ】

权重量化方面的最新进展使我们能够在消费类硬件上运行大量大型语言模型&#xff0c;例如 RTX 3090 GPU 上的 LLaMA-30B 模型。 这要归功于性能下降最小的新型 4 位量化技术&#xff0c;例如 GPTQ、GGML 和 NF4。 在上一篇文章中&#xff0c;我们介绍了简单的 8 位量化技术和出…

GZ038 物联网应用开发赛题第10套

2023年全国职业院校技能大赛 高职组 物联网应用开发 任 务 书 &#xff08;第10套卷&#xff09; 工位号&#xff1a;______________ 第一部分 竞赛须知 一、竞赛要求 1、正确使用工具&#xff0c;操作安全规范&#xff1b; 2、竞赛过程中如有异议&#xff0c;可向现场考…

Spring学习③__Bean管理

目录 IOC接口ApplicationContext 详解IOC操作Bean管理基于xml方式基于xml方式创建对象基于xml方式注入属性使用set方法进行注入通过有参数的构造进行注入p 名称空间注入&#xff08;了解&#xff09; 基于xml方式注入其他类型属性xml 注入数组类型属性 IOC接口 IOC思想基于IOC…

Linux 无名管道实现文件复制

无名管道 通过一个管道&#xff08;假象&#xff09;进行传输数据&#xff0c;但是这个管道的传输方式是单工&#xff08;半双工&#xff09;的&#xff0c;就是这个管道允许进行发送和接受数据&#xff0c;不过不能同时进行。 创建无名管道 这里用到一个pipe&#xff08;&…

代码随想录算法训练营第三十九天【动态规划part02】 | 62.不同路径、63. 不同路径 II

62.不同路径 题目链接&#xff1a; 力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 求解思路&#xff1a; 动规五部曲 确定dp数组及其下标含义&#xff1a;dp[i][j] 表示从&#xff08;0,0&#xff09;出发&#xff0c;到&#xff08;i,j&#x…

性能测试【第三篇】Jmeter的使用

线程数:10 ,设置10个并发 Ramp-Up时间(秒):所有线程在多少时间内启动,如果设置5,那么每秒启动2个线程 循环次数:请求的重复次数,如果勾选"永远"将一直发送请求 持续时间时间:设置场景运行的时间 启动延迟:设置场景延迟启动时间 响应断言 响应断言模式匹配规则 包括…

AIGC创作系统ChatGPT源码,AI绘画源码,支持最新GPT-4-Turbo模型,支持DALL-E3文生图

一、AI创作系统 SparkAi创作系统是基于OpenAI很火的ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统&#xff0c;支持OpenAI-GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美&#xff0c;可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如…

​软考-高级-系统架构设计师教程(清华第2版)【第7章 系统架构设计基础知识(263~285)-思维导图】​

软考-高级-系统架构设计师教程&#xff08;清华第2版&#xff09;【第7章 系统架构设计基础知识&#xff08;263~285&#xff09;-思维导图】 课本里章节里所有蓝色字体的思维导图

LeetCode(18)整数转罗马数字【数组/字符串】【中等】

目录 1.题目2.答案3.提交结果截图 链接&#xff1a; 12. 整数转罗马数字 1.题目 罗马数字包含以下七种字符&#xff1a; I&#xff0c; V&#xff0c; X&#xff0c; L&#xff0c;C&#xff0c;D 和 M。 字符 数值 I 1 V 5 X …

原型网络Prototypical Network的python代码逐行解释,新手小白也可学会!!-----系列4

文章目录 原型网络进行分类的基本流程一、原始代码---计算欧氏距离&#xff0c;设计原型网络&#xff08;计算原型开始训练&#xff09;二、每一行代码的详细解释总结 原型网络进行分类的基本流程 利用原型网络进行分类&#xff0c;基本流程如下&#xff1a; 1.对于每一个样本…

Confluence 快速安装教程

安装jdk yum install -y java-1.8.0-openjdk.x86_64 java -version 安装MySQL mkdir -p /data/mysql/data chmod 777 /data/mysql/datadocker rm -f mysql docker run -d --name mysql \-p 3306:3306 \-e MYSQL_ROOT_PASSWORDfingard1 \-v /data/mysql/data:/var/lib/mysql …

​软考-高级-系统架构设计师教程(清华第2版)【第18章 安全架构设计理论与实践(P648~690)-思维导图】​

软考-高级-系统架构设计师教程&#xff08;清华第2版&#xff09;【第18章 安全架构设计理论与实践&#xff08;P648~690&#xff09;-思维导图】 课本里章节里所有蓝色字体的思维导图

视频剪辑技巧:简单步骤,批量剪辑并随机分割视频

随着社交媒体平台的广泛普及和视频制作需求的急剧增加&#xff0c;视频剪辑已经成为了当今社会一项不可或缺的技能。然而&#xff0c;对于许多初学者来说&#xff0c;视频剪辑可能是一项令人望而生畏的复杂任务。可能会面临各种困难&#xff0c;如如何选择合适的软件和硬件、如…

VBA技术资料MF84:判断文件夹是否存在并创建

我给VBA的定义&#xff1a;VBA是个人小型自动化处理的有效工具。利用好了&#xff0c;可以大大提高自己的工作效率&#xff0c;而且可以提高数据的准确度。我的教程一共九套&#xff0c;分为初级、中级、高级三大部分。是对VBA的系统讲解&#xff0c;从简单的入门&#xff0c;到…

【Qt-23】基于QCharts绘制曲线图

一、QChart简介 QChart是Qt中专门用于绘制图表的模块&#xff0c;支持折线图、柱状图、饼图等常见类型。其主要组成部分有&#xff1a; QChart&#xff1a;整个图表的容器&#xff0c;管理图表中的所有数据和图形属性QChartView&#xff1a;继承自QGraphicsView&#xff0c;用于…

基于单片机C51全自动洗衣机仿真设计

**单片机设计介绍&#xff0c; 基于单片机C51全自动洗衣机仿真设计 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于单片机C51的全自动洗衣机仿真设计是一个复杂的项目&#xff0c;它涉及到硬件和软件的设计和实现。以下是对这…

镀膜与干刻中的平均自由程是什么?

在芯片制造中&#xff0c;镀膜和干刻是其中的重要环节&#xff0c;通常要用到CVD&#xff0c;RIE等技术&#xff0c;对材料表面进行纳米级的精细操作。在这些工序中&#xff0c;原子&#xff0c;分子&#xff0c;离子等&#xff0c;会在气体或真空中进行自由运动&#xff0c;直…

IDEA 高分辨率卡顿优化

VM设置优化 -Dsun.java2d.uiScale.enabledfalse 增加该条设置&#xff0c;关闭高分切换 https://intellij-support.jetbrains.com/hc/en-us/articles/115001260010-Troubleshooting-IDE-scaling-DPI-issues-on-Windows​intellij-support.jetbrains.com/hc/en-us/articles/1…