【python】利用 GridSearchCV 和 SVM 进行学生成绩预测

在机器学习领域,寻找最优模型参数是一个重要的步骤,它直接影响模型的泛化能力和预测准确性。本文将通过一个具体案例介绍如何使用支持向量机(SVM)和网格搜索(GridSearchCV)来预测学生的成绩,并通过调整参数来优化模型性能。

数据集:公众号“码银学编程”后台回复:学生成绩-SVM

引言

学生的成绩预测对于教育领域来说是一个重要的问题,它可以帮助教师更好地了解学生的学习情况,从而进行针对性的教学改进。在本案例中,我们将使用 SVM 作为分类器,并利用 GridSearchCV 对模型参数进行调优,以期获得更好的预测效果。

正文 

1. 数据准备

首先,我们从 CSV 文件中加载学生成绩数据集。数据集包含了多门课程的成绩,我们将这些成绩转换为等级,并使用 LabelEncoder 将分类数据转换为数值形式,以便 SVM 模型能够处理。

stu_grade = pd.read_csv('student-mat.csv')
print(stu_grade.head())

  1. age-学生年龄(数字:从15岁到22岁)
  2. Medu - 母亲教育(数字:0 -无,1 -小学教育(四年级),2 -" 5 - 9年级,3 -"中等教育或,4 -"高等教育)
  3. Fedu - 父亲教育(数字:0 -无,1 -小学教育(4年级),2 â€" 5 - 9年级,3 â€"中等教育或4 â€"高等教育)
  4. traveltime - 从家到学校的旅行时间(数字:1 - <15分钟,2 - 15 - 30分钟,3 - 30分钟到1小时,或4 - >1小时)
  5. studytime - 每周学习时间(数字:1 - <2小时,2 - 2 - 5小时,3 - 5 - 10小时,或4 - >10小时)
  6. failures - 过去班级失败的次数(数值:如果1<=n<3,则为n,否则为4)
  7. famrel - 家庭关系质量(数字:1 -非常差到5 -极好)
  8. freetime - 放学后的空闲时间(数字:从1 -非常少到5 -非常多)
  9. goout - 和朋友出去(数字:从1 -非常低到5 -非常高)
  10. Dalc - 工作日酒精消耗量(数字:1 -极低至5 -极高)
  11. Walc - 周末饮酒(数字:1 -极低至5 -极高)
  12. health - 当前健康状况(数字:从1-非常差到5-非常好)
  13. absences - 学校缺勤次数(数字:从0到93)
  14. G1 -第一阶段等级(数字:从0到20)
  15. G2 -第二阶段等级(数字:从0到20)
  16. G3 -最终等级(数字:从0到20,输出目标)

2. 模型选择与训练

SVM 是一种强大的分类算法,适用于各种类型的数据。在本案例中,我们选择了 SVC 类,它实现了支持向量分类算法。模型训练前,我们通过 train_test_split 将数据集划分为训练集和测试集。

X_train, X_test, Y_train, Y_test = train_test_split(stu_data.drop('G3', axis=1),  # 特征stu_data['G3'],  # 目标变量test_size=0.3, random_state=5)svm_model = SVC(random_state=6)
# 训练模型
svm_model.fit(X_train, Y_train)# 使用训练好的模型进行预测
Y_pred = svm_model.predict(X_test)
print(Y_pred)

3. 模型评估

在初始模型训练完成后,我们使用准确度(accuracy_score)和均方误差(mean_squared_error, MSE)来评估模型性能。准确度是衡量分类模型性能的常用指标,而 MSE 则提供了模型预测值与实际值之间差异的量化度量。

A = mean_absolute_error(Y_test, Y_pred)
S = mean_squared_error(Y_test, Y_pred)# 输出训练集和测试集的分数
print("训练集准确度:", accuracy_score(Y_train, svm_model.predict(X_train)))
print("测试集准确度:", accuracy_score(Y_test, Y_pred))
print(f"Scores:  MAE={A}, MSE={S}")

4. 参数调优

为了进一步提升模型性能,我们采用了 GridSearchCV 进行参数调优。GridSearchCV 是 scikit-learn 提供的一个工具,它通过遍历给定的参数网格,使用交叉验证来寻找最佳的参数组合。在本案例中,我们调整了 SVM 的 `C`、`gamma` 和 `kernel` 参数。

# 参数网格,用于GridSearchCV寻找最佳参数
param_grid = {'C': [0.01, 0.1, 0.2 , 0.3 , 0.4 , 0.5, 1, 10],'gamma': [1, 0.1, 0.01, 0.001],'kernel': ['rbf', 'poly', 'sigmoid']
}# 创建GridSearchCV实例
grid_search = GridSearchCV(estimator=svm_model, param_grid=param_grid, cv=5, return_train_score=True)

5. 最佳模型选择

经过 GridSearchCV 的搜索,我们得到了最佳的参数组合,并使用这些参数重新训练了 SVM 模型。再次使用准确度和 MSE 对优化后的模型进行评估,以验证参数调优的效果。

# 执行网格搜索找到最佳参数
grid_search.fit(X_train, Y_train)# 输出最佳参数和对应的最佳分数
print("最佳参数:", grid_search.best_params_)
print("最佳分数:", grid_search.best_score_)
# 使用最佳参数创建新的SVM模型
best_svm_model = grid_search.best_estimator_# 使用最佳模型进行预测
best_Y_pred = best_svm_model.predict(X_test)# 计算并输出最佳模型的测试集准确度
print("优化后测试集准确度:", accuracy_score(Y_test, best_Y_pred))MAE = mean_absolute_error(Y_test, best_Y_pred)
MSE = mean_squared_error(Y_test, best_Y_pred)print(f"优化后Scores:  MAE={MAE}, MSE={MSE}")

6. 结果分析

通过比较优化前后的模型性能,我们可以得出参数调优是否有效的结论。在本案例中,我们发现通过调整参数,模型的准确度和 MSE 都有所改善,这表明 GridSearchCV 是一个有效的工具,可以帮助我们找到更好的模型参数。

小结

本案例展示了如何使用 SVM 和 GridSearchCV 对学生成绩进行预测,并优化模型参数。通过实验,我们证明了参数调优对于提升模型性能的重要性。

尽管本案例取得了一定的成果,但仍有改进的空间。例如,可以尝试更多的机器学习算法,或者使用更复杂的特征工程技术来进一步提升模型性能。此外,对于数据集的深入分析,如探索不同课程成绩之间的关联,也可能是一个有价值的研究方向。

参考文章

GitHub地址

完整代码与运行结果

import pandas as pd
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.svm import SVC
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
from sklearn.metrics import mean_absolute_error, mean_squared_error
import warnings# 忽略警告
warnings.filterwarnings("ignore", category=DeprecationWarning)# 读取数据集
stu_grade = pd.read_csv('student.csv')
print(stu_grade.head())# 特征选取
new_data = stu_grade.iloc[:, :] #所有的行和列
print(new_data.head())# 将成绩转换为等级
def convert(x):x = int(x)if x < 5:return 'bad'elif x >= 5 and x < 10:return 'medium'elif x >= 10 and x < 15:return 'good'else:return 'excellent'stu_data = new_data.copy()
stu_data['G1'] = stu_data['G1'].map(lambda x: convert(x))
stu_data['G2'] = stu_data['G2'].map(lambda x: convert(x))
stu_data['G3'] = stu_data['G3'].map(lambda x: convert(x))
print(stu_data.head())# 将分类特征转换为数值形式
label_encoders = {}
for column in ['G1', 'G2', 'G3']:le = LabelEncoder()stu_data[column] = le.fit_transform(stu_data[column])label_encoders[column] = le# 划分训练集和测试集
X_train, X_test, Y_train, Y_test = train_test_split(stu_data.drop('G3', axis=1),  # 特征stu_data['G3'],  # 目标变量test_size=0.3, random_state=5)# 创建SVM模型实例
svm_model = SVC(random_state=6)# 训练模型
svm_model.fit(X_train, Y_train)# 使用训练好的模型进行预测
Y_pred = svm_model.predict(X_test)
print(Y_pred)A = mean_absolute_error(Y_test, Y_pred)
S = mean_squared_error(Y_test, Y_pred)# 输出训练集和测试集的分数
print("训练集准确度:", accuracy_score(Y_train, svm_model.predict(X_train)))
print("测试集准确度:", accuracy_score(Y_test, Y_pred))
print(f"Scores:  MAE={A}, MSE={S}")# 参数网格,用于GridSearchCV寻找最佳参数
param_grid = {'C': [0.01, 0.1, 0.2 , 0.3 , 0.4 , 0.5, 1, 10],'gamma': [1, 0.1, 0.01, 0.001],'kernel': ['rbf', 'poly', 'sigmoid']
}# 创建GridSearchCV实例
grid_search = GridSearchCV(estimator=svm_model, param_grid=param_grid, cv=5, return_train_score=True)# 执行网格搜索找到最佳参数
grid_search.fit(X_train, Y_train)# 输出最佳参数和对应的最佳分数
print("最佳参数:", grid_search.best_params_)
print("最佳分数:", grid_search.best_score_)# 使用最佳参数创建新的SVM模型
best_svm_model = grid_search.best_estimator_# 使用最佳模型进行预测
best_Y_pred = best_svm_model.predict(X_test)# 计算并输出最佳模型的测试集准确度
print("优化后测试集准确度:", accuracy_score(Y_test, best_Y_pred))MAE = mean_absolute_error(Y_test, best_Y_pred)
MSE = mean_squared_error(Y_test, best_Y_pred)print(f"优化后Scores:  MAE={MAE}, MSE={MSE}")

结果: 

   age  Medu  Fedu  traveltime  studytime  ...  health  absences  G1  G2  G3
0   18     4     4           2          2  ...       3         6   5   6   6
1   17     1     1           1          2  ...       3         4   5   5   6
2   15     1     1           1          2  ...       3        10   7   8  10
3   15     4     2           1          3  ...       5         2  15  14  15
4   16     3     3           1          2  ...       5         4   6  10  10[5 rows x 16 columns]age  Medu  Fedu  traveltime  studytime  ...  health  absences  G1  G2  G3
0   18     4     4           2          2  ...       3         6   5   6   6
1   17     1     1           1          2  ...       3         4   5   5   6
2   15     1     1           1          2  ...       3        10   7   8  10
3   15     4     2           1          3  ...       5         2  15  14  15
4   16     3     3           1          2  ...       5         4   6  10  10[5 rows x 16 columns]age  Medu  Fedu  traveltime  ...  absences         G1      G2         G3
0   18     4     4           2  ...         6     medium  medium     medium
1   17     1     1           1  ...         4     medium  medium     medium
2   15     1     1           1  ...        10     medium  medium       good
3   15     4     2           1  ...         2  excellent    good  excellent
4   16     3     3           1  ...         4     medium    good       good[5 rows x 16 columns]
[2 0 2 2 2 2 3 2 2 2 2 0 2 1 2 1 2 2 2 2 2 2 2 2 2 2 1 2 2 2 0 3 2 2 1 3 22 1 2 2 1 2 2 2 3 2 3 2 2 1 2 2 2 2 2 2 3 3 3 2 2 3 2 2 2 1 2 2 2 2 2 2 23 2 2 1 3 1 2 1 3 0 1 0 3 2 2 2 0 2 2 2 2 1 2 3 2 3 2 2 2 2 0 2 2 2 3 2 22 2 2 2 3 2 1 3]
训练集准确度: 0.927536231884058
测试集准确度: 0.6218487394957983
Scores:  MAE=0.4957983193277311, MSE=0.7983193277310925
最佳参数: {'C': 0.3, 'gamma': 0.01, 'kernel': 'poly'}
最佳分数: 0.8115942028985508
优化后测试集准确度: 0.8235294117647058
优化后Scores:  MAE=0.2857142857142857, MSE=0.5882352941176471Process finished with exit code 0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/4447.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

可审批可审计追溯的单网络导出文件方案,了解一下

在物理隔离状态下&#xff0c;单网络导出文件是一个重要的安全需求&#xff0c;特别是在处理敏感数据时。在这种环境下&#xff0c;数据导出需要采取特殊的安全措施&#xff0c;以确保数据传输的安全性和合规性。需要考虑以下因素&#xff1a; 安全性&#xff1a;确保传输过程加…

筛选日志并生成序列化文件

1.在idea中创建项目 selectData. 2.添加依赖&#xff0c;插件包&#xff0c;指定打包方式&#xff0c;日志文件 大家可以直接从前面项目复制。 3.本次只需要进行序列化操作&#xff0c;所以不需要Reducer模块&#xff0c;编写Mapper模块 package com.maidu.selectdata;import…

Bert基础(十八)--Bert实战:NER命名实体识别

1、命名实体识别介绍 1.1 简介 命名实体识别&#xff08;NER&#xff09;是自然语言处理&#xff08;NLP&#xff09;中的一项关键技术&#xff0c;它的目标是从文本中识别出具有特定意义或指代性强的实体&#xff0c;并对这些实体进行分类。这些实体通常包括人名、地名、组织…

极简shell制作

&#x1f30e;自定义简单shell制作 &#xff08;ps: 文末有完整代码&#xff09; 文章目录&#xff1a; 自定义简单shell制作 简单配置Linux文件 自定义Shell编写 命令行解释器       获取输入的命令       字符串分割       子进程进行进程替换 内建命令…

28.Gateway-网关过滤器

GatewayFilter是网关中提供的一种过滤器&#xff0c;可以多进入网关的请求和微服务返回的响应做处理。 GatewayFilter(当前路由过滤器&#xff0c;DefaultFilter) spring中提供了31种不同的路由过滤器工厂。 filters针对部分路由的过滤器。 default-filters针对所有路由的默认…

opencv基础篇 ——(九)图像几何变换

图像几何变换是通过对图像的几何结构进行变换来改变图像的形状、大小、方向或者透视关系。常见的图像几何变换包括缩放、旋转、平移、仿射变换和透视变换等。下面对这些几何变换进行简要介绍&#xff1a; 矩阵的转置&#xff08;transpose &#xff09;&#xff1a; 对于图像来…

微服务之SpringCloud AlibabaNacos服务注册和配置中心

一、概述 1.1注册中心原理 在微服务远程调用的过程中&#xff0c;包括两个角色&#xff1a; 服务提供者&#xff1a;提供接口供其它微服务访问&#xff0c;比如item-service 服务消费者&#xff1a;调用其它微服务提供的接口&#xff0c;比如cart-service 在大型微服务项目…

符合医药行业规范的液氮罐运输和存储温度监测解决方案

API原料药、冻干物质和人体样本必须在玻璃相中以尽可能低的温度运输和存储。专门的低温容器——干式液氮罐——可通过液氮&#xff08;LN2&#xff09;将温度保持在-196 C。由于温度极低&#xff0c;低温容器的温度数据监测不仅具有挑战性&#xff0c;而且还需要更复杂的过程&a…

Linux下的常用基本指令

基本指令 前言ls 指令语法功能常用选项举例注意要点关于拼接关于 -a关于文件ls与/的联用ls与根目录ls与任意文件夹ls与常用选项与路径 ls -d与ls -ldls与ll pwd命令语法功能常用选项注意要点window与Linux文件路径的区别家目录 cd 指令语法功能举例注意要点cd路径.. .相对路径与…

Cesium116版本安装跑错,注意Node版本

SyntaxError: Unexpected token ?? at Loader.moduleStrategy (internal/modules/esm/translators.js:149:18) 无法解析ES node.js本本过低 nvm use无效NVM踩坑不完全指南&#xff0c;nvm use没有*_nvm use 无效-CSDN博客

决策树模型示例

通过5个条件判定一件事情是否会发生&#xff0c;5个条件对这件事情是否发生的影响力不同&#xff0c;计算每个条件对这件事情发生的影响力多大&#xff0c;写一个决策树模型pytorch程序,最后打印5个条件分别的影响力。 一 决策树模型是一种非参数监督学习方法&#xff0c;主要…

centos7 openresty lua 自适应webp和缩放图片

目录 背景效果图准备安装cwebp等命令&#xff0c;转换文件格式安装ImageMagick&#xff0c;压缩文件下载Lua API 操控ImageMagick的依赖包 代码参考 背景 缩小图片体积&#xff0c;提升加载速度&#xff0c;节省流量。 效果图 参数格式 &#xff1a; ?image_processformat,…

Llama-7b-Chinese本地推理

Llama-7b-Chinese 本地推理 基础环境信息&#xff08;wsl2安装Ubuntu22.04 miniconda&#xff09; 使用miniconda搭建环境 (base) :~$ conda create --name Llama-7b-Chinese python3.10 Channels:- defaults Platform: linux-64 Collecting package metadata (repodata.js…

Linux下软硬链接和动静态库制作详解

目录 前言 软硬链接 概念 软链接的创建 硬链接的创建 软硬链接的本质区别 理解软链接 理解硬链接 小结 动静态库 概念 动静态库的制作 静态库的制作 动态库的制作 前言 本文涉及到inode和地址空间等相关概念&#xff0c;不知道的小伙伴可以先阅读以下两篇文章…

智慧校园建设指导

智慧校园是一个庞大的业务系统&#xff0c;他涉及到校园事务的各个方面&#xff0c;包括教务&#xff0c;考务&#xff0c;教工&#xff0c;学工&#xff0c;办公&#xff0c;科研等。因此&#xff0c;建设符合学校业务需求的智慧校园平台&#xff0c;不仅需要做到认真负责外&a…

C语言位运算详解(移位操作符、位操作符)

目录 一、整数在内存中的存储方式 二、移位操作符 1、左移操作符 2、右移操作符 a.逻辑右移 b.算数右移 ps、移位操作符使用警告 三、位操作符 用例代码&#xff1a; a.按位与&#xff08;&&#xff09; b.按位或&#xff08;|&#xff09; c.按位异或&#xf…

【笔试强训】Day4 --- Fibonacci数列 + 单词搜索 + 杨辉三角

文章目录 1. Fibonacci数列2. 单词搜索3. 杨辉三角 1. Fibonacci数列 【链接】&#xff1a;Fibonacci数列 解题思路&#xff1a;简单模拟题&#xff0c;要最少的步数就是找离N最近的Fibonacci数&#xff0c;即可能情况只有比他小的最大的那个Fibonacci数以及比他大的最小的那…

《软件设计师教程:计算机网络浅了解计算机之间相互运运作的模式》

​ 个人主页&#xff1a;李仙桎 &#x1f525; 个人专栏: 《软件设计师》 ⛺️生活的理想&#xff0c;就是为了理想的生活! ​ ⛺️前言&#xff1a;各位铁汁们好啊&#xff01;&#xff01;&#xff01;&#xff0c;今天开始继续学习中级软件设计师考试相关的内容&#xff0…

Java高阶私房菜:JVM垃圾回收机制及算法原理探究

目录 垃圾回收机制 什么是垃圾回收机制 JVM的自动垃圾回收机制 垃圾回收机制的关键知识点 初步了解判断方法-引用计数法 GCRoot和可达性分析算法 什么是可达性分析算法 什么是GC Root 对象回收的关键知识点 标记对象可回收就一定会被回收吗&#xff1f; 可达性分析算…

【免费源码下载】完美运营版商城 虚拟商品全功能商城 全能商城小程序 智慧商城系统 全品类百货商城php+uniapp

简介 完美运营版商城/拼团/团购/秒杀/积分/砍价/实物商品/虚拟商品等全功能商城 干干净净 没有一丝多余收据 还没过手其他站 还没乱七八走的广告和后门 后台可以自由拖曳修改前端UI页面 还支持虚拟商品自动发货等功能 挺不错的一套源码 前端UNIAPP 后端PHP 一键部署版本&am…