第100+12步 ChatGPT学习:R实现KNN分类

基于R 4.2.2版本演示

一、写在前面

有不少大佬问做机器学习分类能不能用R语言,不想学Python咯。

答曰:可!用GPT或者Kimi转一下就得了呗。

加上最近也没啥内容写了,就帮各位搬运一下吧。

二、R代码实现KNN分类

(1)导入数据

我习惯用RStudio自带的导入功能:

(2)建立KNN模型

# Load necessary libraries
library(caret)
library(pROC)
library(ggplot2)# Assume 'data' is your dataframe containing the data
# Set seed to ensure reproducibility
set.seed(123)# Split data into training and validation sets (80% training, 20% validation)
trainIndex <- createDataPartition(data$X, p = 0.8, list = FALSE)
trainData <- data[trainIndex, ]
validData <- data[-trainIndex, ]# Convert the target variable to a factor for classification
trainData$X <- as.factor(trainData$X)
validData$X <- as.factor(validData$X)# Define control method for training with cross-validation
trainControl <- trainControl(method = "cv", number = 10)# Fit KNN model on the training set
model <- train(X ~ ., data = trainData, method = "knn", trControl = trainControl, preProcess = "scale")# Predict on the training and validation sets
trainPredict <- predict(model, trainData, type = "prob")[,2]
validPredict <- predict(model, validData, type = "prob")[,2]# Convert true values to factor for ROC analysis
trainData$X <- as.factor(trainData$X)
validData$X <- as.factor(validData$X)# Calculate ROC curves and AUC values
trainRoc <- roc(response = trainData$X, predictor = trainPredict)
validRoc <- roc(response = validData$X, predictor = validPredict)# Plot ROC curves with AUC values
ggplot(data = data.frame(fpr = trainRoc$specificities, tpr = trainRoc$sensitivities), aes(x = 1 - fpr, y = tpr)) +geom_line(color = "blue") +geom_area(alpha = 0.2, fill = "blue") +geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black") +ggtitle("Training ROC Curve") +xlab("False Positive Rate") +ylab("True Positive Rate") +annotate("text", x = 0.5, y = 0.1, label = paste("Training AUC =", round(auc(trainRoc), 2)), hjust = 0.5, color = "blue")ggplot(data = data.frame(fpr = validRoc$specificities, tpr = validRoc$sensitivities), aes(x = 1 - fpr, y = tpr)) +geom_line(color = "red") +geom_area(alpha = 0.2, fill = "red") +geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black") +ggtitle("Validation ROC Curve") +xlab("False Positive Rate") +ylab("True Positive Rate") +annotate("text", x = 0.5, y = 0.2, label = paste("Validation AUC =", round(auc(validRoc), 2)), hjust = 0.5, color = "red")# Calculate confusion matrices based on 0.5 cutoff for probability
confMatTrain <- table(trainData$X, trainPredict >= 0.5)
confMatValid <- table(validData$X, validPredict >= 0.5)# Function to plot confusion matrix using ggplot2
plot_confusion_matrix <- function(conf_mat, dataset_name) {conf_mat_df <- as.data.frame(as.table(conf_mat))colnames(conf_mat_df) <- c("Actual", "Predicted", "Freq")p <- ggplot(data = conf_mat_df, aes(x = Predicted, y = Actual, fill = Freq)) +geom_tile(color = "white") +geom_text(aes(label = Freq), vjust = 1.5, color = "black", size = 5) +scale_fill_gradient(low = "white", high = "steelblue") +labs(title = paste("Confusion Matrix -", dataset_name, "Set"), x = "Predicted Class", y = "Actual Class") +theme_minimal() +theme(axis.text.x = element_text(angle = 45, hjust = 1), plot.title = element_text(hjust = 0.5))print(p)
}
# Now call the function to plot and display the confusion matrices
plot_confusion_matrix(confMatTrain, "Training")
plot_confusion_matrix(confMatValid, "Validation")# Extract values for calculations
a_train <- confMatTrain[1, 1]
b_train <- confMatTrain[1, 2]
c_train <- confMatTrain[2, 1]
d_train <- confMatTrain[2, 2]a_valid <- confMatValid[1, 1]
b_valid <- confMatValid[1, 2]
c_valid <- confMatValid[2, 1]
d_valid <- confMatValid[2, 2]# Training Set Metrics
acc_train <- (a_train + d_train) / sum(confMatTrain)
error_rate_train <- 1 - acc_train
sen_train <- d_train / (d_train + c_train)
sep_train <- a_train / (a_train + b_train)
precision_train <- d_train / (b_train + d_train)
F1_train <- (2 * precision_train * sen_train) / (precision_train + sen_train)
MCC_train <- (d_train * a_train - b_train * c_train) / sqrt((d_train + b_train) * (d_train + c_train) * (a_train + b_train) * (a_train + c_train))
auc_train <- roc(response = trainData$X, predictor = trainPredict)$auc# Validation Set Metrics
acc_valid <- (a_valid + d_valid) / sum(confMatValid)
error_rate_valid <- 1 - acc_valid
sen_valid <- d_valid / (d_valid + c_valid)
sep_valid <- a_valid / (a_valid + b_valid)
precision_valid <- d_valid / (b_valid + d_valid)
F1_valid <- (2 * precision_valid * sen_valid) / (precision_valid + sen_valid)
MCC_valid <- (d_valid * a_valid - b_valid * c_valid) / sqrt((d_valid + b_valid) * (d_valid + c_valid) * (a_valid + b_valid) * (a_valid + c_valid))
auc_valid <- roc(response = validData$X, predictor = validPredict)$auc# Print Metrics
cat("Training Metrics\n")
cat("Accuracy:", acc_train, "\n")
cat("Error Rate:", error_rate_train, "\n")
cat("Sensitivity:", sen_train, "\n")
cat("Specificity:", sep_train, "\n")
cat("Precision:", precision_train, "\n")
cat("F1 Score:", F1_train, "\n")
cat("MCC:", MCC_train, "\n")
cat("AUC:", auc_train, "\n\n")cat("Validation Metrics\n")
cat("Accuracy:", acc_valid, "\n")
cat("Error Rate:", error_rate_valid, "\n")
cat("Sensitivity:", sen_valid, "\n")
cat("Specificity:", sep_valid, "\n")
cat("Precision:", precision_valid, "\n")
cat("F1 Score:", F1_valid, "\n")
cat("MCC:", MCC_valid, "\n")
cat("AUC:", auc_valid, "\n")

在R语言中,caret包提供了一个通用的接口来训练KNN模型。使用caret的train函数来训练KNN模型时,可以调整多种参数来优化模型的性能:

基本参数:

①formula: 指定模型的公式,如Y ~ .,表示使用数据框中的所有其他变量来预测Y。

②data: 提供包含训练数据的数据框。

③method: 对于KNN模型,这个参数应设置为"knn"。

④preProcess: 预处理步骤,常用的包括标准化("scale")和中心化("center"),对于KNN这一步非常重要因为KNN依赖于变量的距离度量。

⑤trControl: 一个trainControl对象,定义了模型训练的各种控制策略,如交叉验证的类型和重复次数。

trainControl 函数的参数:

①method: 训练的方法,如交叉验证("cv"),重复交叉验证("repeatedcv"),留一交叉验证("LOOCV")等。

②number: 对于"cv"和"repeatedcv",这个参数定义了折数。

③repeats: 当使用"repeatedcv"时,定义重复的次数。

④search: 参数搜索方法,默认为"grid"。也可以设置为"random"进行随机搜索。

⑤savePredictions: 是否保存预测结果,通常用于后续分析。

模型性能调整参数:

使用KNN时,最关键的参数之一是邻居的数量(K值)。这可以通过train函数的以下参数来调整:

①tuneLength: 这个参数决定了在参数搜索中考虑多少个不同的K值。

②tuneGrid: 这是一个数据框,可以自定义K值的具体范围,例如expand.grid(k = c(1, 5, 10))

结果输出(默认参数):

三、KNN调参方法

如前所述,KNN的关键参数就是K值,所以可以对其进行一个暴力测试,比如取值1到10:

# 定义交叉验证的控制方法,启用网格搜索
trainControl <- trainControl(method = "cv", number = 10)
# 定义K值的网格搜索范围
tuneGrid <- expand.grid(k = 1:10)
# 在训练集上拟合KNN模型,指定网格搜索的K值
model <- train(X ~ ., data = trainData, method = "knn", trControl = trainControl,tuneGrid = tuneGrid, preProcess = "scale")
# 查看模型结果,找出最优的K值
print(model)

解读:

定义交叉验证的控制方法:使用trainControl函数设定交叉验证的详细参数。

定义K值的网格:使用tuneGrid参数在train函数中指定K值的范围。

拟合模型:使用train函数训练模型,同时应用预处理步骤(比如标准化数据),以确保每个特征在距离计算中具有等同的权重。

结果输出:

注意:用了caret包的train函数,并且通过网格搜索指定了一系列的参数(如K值的范围),那么这个函数会自动选择表现最好的参数配置来训练最终的模型。train函数的输出即是基于你提供的训练数据和参数搜索范围内表现最优的模型。因此,当你调用predict函数进行预测时,使用的就是这个最优化的模型。所以,下面的代码不变。

结果吧,跟之前的完全一样:

因为caret包对于KNN模型默认进行一系列的K值尝试,通常这个范围是1到最多的邻居数,但具体的最大K值依赖于caret的内部设置。在大多数情况下,它会尝试如1, 5, 7, 9等常用的K值。所以,我们默认参数的时候,其实软件自动给我们寻找最优K值了。可以用这个代码输出最有K值:

# Print the best K value used by the model
best_k <- model$bestTune$k
cat("The best K value found is:", best_k, "\n")

K值就是9,跟我们自行调参的一致。

那我们猛点,把K的范围设置的宽一些:

# 定义交叉验证的控制方法,启用网格搜索
trainControl <- trainControl(method = "cv", number = 10)
# 定义K值的网格搜索范围
tuneGrid <- expand.grid(k = 1:20)
# 在训练集上拟合KNN模型,指定网格搜索的K值
model <- train(X ~ ., data = trainData, method = "knn", trControl = trainControl,tuneGrid = tuneGrid, preProcess = "scale")
# 查看模型结果,找出最优的K值
print(model)

结果:

K=19,性能指标如下,似乎大同小异:

四、最后

数据嘛:

链接:https://pan.baidu.com/s/1rEf6JZyzA1ia5exoq5OF7g?pwd=x8xm

提取码:x8xm

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/35046.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Docker】Consul 和API

目录 一、Consul 1. 拉取镜像 2. 启动第一个consul服务&#xff1a;consul1 3. 查看consul service1 的ip地址 4. 启动第二个consul服务&#xff1a;consul2&#xff0c; 并加入consul1&#xff08;使用join命令&#xff09; 5. 启动第三个consul服务&#xff1a;consul3&…

攻击者开始使用 XLL 文件进行攻击

近期&#xff0c;研究人员发现使用恶意 Microsoft Excel 加载项&#xff08;XLL&#xff09;文件发起攻击的行动有所增加&#xff0c;这项技术的 MITRE ATT&CK 技术项编号为 T1137.006。 这些加载项都是为了使用户能够利用高性能函数&#xff0c;为 Excel 工作表提供 API …

【SQL Server数据库】关系模式与关系代数

目录 一、请用关系代数完成下列查询 1. 求 供应工程J1 零件P1的供应商号码SNO&#xff1b; 2. 求 供应工程J1 零件&#xff08;P&#xff09;为红色 的供应商号码SNO&#xff1b; 3. 求 没有使用 天津供应商&#xff08;P&#xff09;生产的红色零件&#xff08;S&#xff0…

pycharm中的使用技巧

1、更改主题&#xff1a;找到设置&#xff0c;然后更改主题 点击选择自己喜欢的主题&#xff0c;然后就可以更改主题了 2、设置字体的快捷键 找到设置&#xff0c;如下&#xff1a; 找到increase&#xff0c;如下&#xff1a; 右键选择&#xff0c;增加字体快捷键 按住ctrl滑轮…

Excel 查找后隐去右边列

Excel 有几列数字 ABC11002042002202100102326027010841199100512100100 当给定参数时&#xff0c;请从每行找到该参数&#xff0c;隐去右边的列。如果某行不含该参数&#xff0c;则隐去整行。当参数是 100 时&#xff0c;结果如下&#xff1a; ABC710082021009119910010121…

shell之免交互

免交互 交互&#xff1a;发出指令控制指令的运行&#xff0c;程序再接收到指令的效果做出对应的反应。 免交互&#xff1a;间接的&#xff0c;通过第三方的方式把指令传送给程序&#xff0c;不用直接的下达指令 Hhere Document 免交互 这是命令行格式&#xff0c;也可以写在脚本…

QTableWidget的使用

使用QTableWidget&#xff0c;初始化数据、设置列头及格式&#xff0c;设置行数&#xff0c;设置每个单元格的编辑&#xff0c;间隔行底色变换、行选择 &#xff0c;模式&#xff0c;单元格选择模式、插入行 、追加行、删除行&#xff0c;单元格加图标&#xff0c;单元格显示ch…

好记性不如烂笔头(三)——文件保存后打开呈现乱码问题

现象 请随博主进行下列操作&#xff0c;神奇的事情会发生—— 1、新建记事本&#xff0c;里面输入“同”字&#xff0c;保存为ANSI格式 2、再次打开会发现&#xff0c;“同”已经变成了乱码 3、类似的字还有很多&#xff0c;例如“同学”的“学”。而有些字则不会出现这种情况…

3_电机的发展及学习方法

一、电机组成及发展 1、什么是励磁&#xff1f; 在电磁学中&#xff0c;励磁是通过电流产生磁场的过程。 发电机或电动机由在磁场中旋转的转子组成。磁场可以由 永磁体或励磁线圈产生。对于带有励磁线圈的机器&#xff0c;电流必须在线圈中流动才能产生&#xff08;激发&#x…

香港服务器托管对外贸行业必要性和优势

在当今全球化的经济环境下&#xff0c;外贸企业面临着前所未有的机遇与挑战。其中&#xff0c;服务器托管的选择对于外贸企业的运营效率和市场拓展具有举足轻重的作用。香港服务器&#xff0c;凭借其独特的地理位置、优质的网络环境和卓越的服务性能&#xff0c;一直是外贸企业…

“Hello, World” 的历史

“Hello, World!” —— 初学者进入编程世界的第一步 由布莱恩柯林汉 撰写的“Hello, world”程序 (1978年) 布莱恩W.克尼汉&#xff08;Brian W. Kernighan&#xff09;—— Unix 和 C 语言背后的巨人 布莱恩W.克尼汉 布莱恩W.克尼汉在 1942 年出生在加拿大多伦多&#xff…

OS中断机制-嵌套和竞争

对于FreeRTOS最好不去用中断嵌套,中断嵌套会增加堆栈空间的使用,因为每个中断服务程序都需要保存和恢复寄存器状态,这可能会耗尽有限的堆栈空间,从而导致系统故障。以及中断嵌套时,不同的中断服务程序可能会竞争访问共享资源,从而增加死锁的风险。这可能会导致系统出现故…

Verilog进行结构描述(structural modeling)(一):基本概念

目录 1.结构描述(structural modeling)的内容&#xff1a;2.实例 微信公众号获取更多FPGA相关源码&#xff1a; 1.结构描述(structural modeling)的内容&#xff1a; 用门来描述器件的功能基于基本元件和底层模块例化语句最接近实际的硬件结构主要使用元件的定义、使用声明以…

Flink——最流批的大数据框架(流批一体)

Apache Flink基础教程 资料来源&#xff1a;Apache Flink Tutorial (tutorialspoint.com) Apache Flink是Apache Hadoop的开源本地分析数据库。它由Cloudera、MapR、Oracle和Amazon等供应商提供。本教程中提供的示例是使用Cloudera Apache Flink开发的。 本教程是为那些想要学…

fork 是一个创建新进程的系统调用

在计算机科学中&#xff0c;fork 是一个创建新进程的系统调用。具体来说&#xff0c;fork 调用会创建一个与当前进程几乎完全相同的副本&#xff0c;包括父进程的内存布局、环境变量、打开的文件描述符等。这个新的进程被称为子进程&#xff0c;而原始进程被称为父进程。 以下…

光伏开发有没有难点?如何解决?

随着全球对可再生能源的日益重视&#xff0c;光伏技术作为其中的佼佼者&#xff0c;已成为实现能源转型的关键手段。然而&#xff0c;光伏开发并非一帆风顺&#xff0c;其过程中也面临着诸多难点和挑战。本文将对这些难点进行探讨&#xff0c;并提出相应的解决策略。 一、光伏开…

12 学习总结:操作符

目录 一、操作符的分类 二、二进制和进制转换 &#xff08;一&#xff09;概念 &#xff08;二&#xff09;二进制 &#xff08;三&#xff09;进制转换 1、2进制与10进制的互换 &#xff08;1&#xff09;2进制转化10进制 &#xff08;2&#xff09;10进制转化2进制 2…

解决vs2022scanf报错问题

vs2022scanf报错问题 大家下完vs2022之后,开心的写下一段简单的代码: #include <stdio.h> #include <stdlib.h>int main() {int a;scanf("%d", &a);printf("%d", a);return 0; } vs2022会毫不犹豫的报错,下面是报错信息: 翻译过来就是v…

探究InnoDB Compact行格式背后

目录 一、InnoDB 行格式数据准备 二、COMPACT行格式整体说明 三、记录的额外信息 &#xff08;一&#xff09;变长字段长度列表 数据结构 存储过程 读取过程 变长字段长度列表存储示例 &#xff08;二&#xff09;NULL 值位图 数据结构 存储过程 读取过程 NULL 值…

【MySQL进阶之路 | 高级篇】索引的声明与使用

1. 索引的分类 MySQL的索引包括普通索引&#xff0c;唯一性索引&#xff0c;全文索引&#xff0c;单列索引和空间索引. 从功能逻辑上说&#xff0c;索引主要分为普通索引&#xff0c;唯一索引&#xff0c;主键索引和全文索引.按物理实现方式&#xff0c;索引可以分为聚簇索引…