【Sklearn】基于决策树算法的数据分类预测(Excel可直接替换数据)
- 1.模型原理
- 1.1 模型原理
- 1.2 数学模型
- 2.模型参数
- 3.文件结构
- 4.Excel数据
- 5.下载地址
- 6.完整代码
- 7.运行结果
1.模型原理
决策树是一种基于树状结构的分类和回归模型,它通过一系列的决策规则来将数据划分为不同的类别或预测值。决策树的模型原理和数学模型如下:
1.1 模型原理
决策树的基本思想是从根节点开始,通过一系列的节点和分支,根据不同特征的取值将数据集划分成不同的子集,直到达到叶节点,然后将每个叶节点分配到一个类别或预测值。决策树的构建过程就是确定如何选择特征以及如何划分数据集的过程。
决策树的主要步骤:
-
选择特征: 从所有特征中选择一个最佳特征作为当前节点的划分特征,这个选择通常基于某个度量(如信息增益、基尼系数)来评估不同特征的重要性。
-
划分数据集: 根据选择的特征,将数据集划分成多个子集,每个子集对应一个分支。
-
递归构建: 对每个子集递归地重复步骤1和步骤2,直到满足停止条件,如达到最大深度、样本数不足等。
-
叶节点赋值: 在构建过程中,根据训练数据的真实标签或均值等方式,为叶节点分配类别或预测值。
1.2 数学模型
在数学上,决策树可以表示为一个树状结构,其中每个节点表示一个特征的划分,每个分支代表一个特征取值的分支。具体来说,每个节点可以由以下元素定义:
-
划分特征: 表示选择哪个特征进行划分。
-
划分阈值: 表示在划分特征上的取值阈值,用于将数据分配到不同的子集。
-
叶节点值: 表示在达到叶节点时所预测的类别或预测值。
在决策树的训练过程中,我们寻找最优的划分特征和划分阈值,以最大程度地减少不纯度(或最大程度地增加信息增益、降低基尼指数等)。
数学模型可以用以下形式表示:
f ( x ) = { C 1 , if x belongs to region R 1 C 2 , if x belongs to region R 2 ⋮ ⋮ C k , if x belongs to region R k f(x) = \begin{cases} C_1, & \text{if } x \text{ belongs to region } R_1 \\ C_2, & \text{if } x \text{ belongs to region } R_2 \\ \vdots & \vdots \\ C_k, & \text{if } x \text{ belongs to region } R_k \end{cases} f(x)=⎩ ⎨ ⎧C1,C2,⋮Ck,if x belongs to region R1if x belongs to region R2⋮if x belongs to region Rk
其中, C i C_i Ci表示叶节点的类别或预测值, R i R_i Ri表示根据特征划分得到的子集。
总之,决策树通过递归地选择最佳特征和阈值,将数据集划分为多个子集,最终形成一个树状结构的模型,用于分类或回归预测。
2.模型参数
DecisionTreeClassifier
是scikit-learn
库中用于构建决策树分类器的类。它具有多个参数,用于调整决策树的构建和性能。以下是一些常用的参数及其说明:
-
criterion: 衡量分割质量的标准。可以是"gini"(基尼系数)或"entropy"(信息熵)。默认为"gini"。
-
splitter: 用于选择节点分割的策略。可以是"best"(选择最优的分割)或"random"(随机选择分割)。默认为"best"。
-
max_depth: 决策树的最大深度。如果为None,则节点会扩展,直到所有叶节点都是纯的,或者包含少于min_samples_split个样本。默认为None。
-
min_samples_split: 节点分裂所需的最小样本数。如果一个节点的样本数少于这个值,就不会再分裂。默认为2。
-
min_samples_leaf: 叶节点所需的最小样本数。如果一个叶节点的样本数少于这个值,可以合并到一个叶节点。默认为1。
-
min_weight_fraction_leaf: 叶节点所需的最小权重分数总和。与min_samples_leaf类似,但是使用样本权重而不是样本数量。默认为0。
-
max_features: 寻找最佳分割时要考虑的特征数量。可以是整数、浮点数、字符串或None。默认为None。
-
random_state: 随机数生成器的种子,用于随机性控制。默认为None。
-
max_leaf_nodes: 最大叶节点数。如果设置,算法会通过去掉最不重要的叶节点来合并其他节点。默认为None。
-
min_impurity_decrease: 分割需要达到的最小不纯度减少量。如果分割不会降低不纯度超过这个阈值,则节点将被视为叶节点。默认为0。
-
class_weight: 类别权重,用于处理不平衡数据集。
这些是DecisionTreeClassifier
中一些常用的参数。根据你的数据和问题,你可以根据需要调整这些参数的值,以获得更好的模型性能。在实际应用中,根据数据的特点进行调参非常重要。
3.文件结构
iris.xlsx % 可替换数据集
Main.py % 主函数
4.Excel数据
5.下载地址
- 资源下载地址
6.完整代码
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as snsdef decision_tree_classification(data_path, test_size=0.2, random_state=42):# 加载数据data = pd.read_excel(data_path)# 分割特征和标签X = data.iloc[:, :-1] # 所有列除了最后一列y = data.iloc[:, -1] # 最后一列# 划分训练集和测试集X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)# 创建决策树分类器# 1. ** criterion: ** 衡量分割质量的标准。可以是"gini"(基尼系数)或"entropy"(信息熵)。默认为"gini"。# 2. ** splitter: ** 用于选择节点分割的策略。可以是"best"(选择最优的分割)或"random"(随机选择分割)。默认为"best"。# 3. ** max_depth: ** 决策树的最大深度。如果为None,则节点会扩展,直到所有叶节点都是纯的,或者包含少于min_samples_split个样本。默认为None。# 4. ** min_samples_split: ** 节点分裂所需的最小样本数。如果一个节点的样本数少于这个值,就不会再分裂。默认为2。# 5. ** min_samples_leaf: ** 叶节点所需的最小样本数。如果一个叶节点的样本数少于这个值,可以合并到一个叶节点。默认为1。# 6. ** min_weight_fraction_leaf: ** 叶节点所需的最小权重分数总和。与min_samples_leaf类似,但是使用样本权重而不是样本数量。默认为0。# 7. ** max_features: ** 寻找最佳分割时要考虑的特征数量。可以是整数、浮点数、字符串或None。默认为None。# 8. ** random_state: ** 随机数生成器的种子,用于随机性控制。默认为None。# 9. ** max_leaf_nodes: ** 最大叶节点数。如果设置,算法会通过去掉最不重要的叶节点来合并其他节点。默认为None。# 10. ** min_impurity_decrease: ** 分割需要达到的最小不纯度减少量。如果分割不会降低不纯度超过这个阈值,则节点将被视为叶节点。默认为0。# 11. ** class_weight: ** 类别权重,用于处理不平衡数据集。# 使用gini作为分割标准,设置最大深度为3,最小样本数为5model = DecisionTreeClassifier(criterion='gini', max_depth=3, min_samples_split=5)# 在训练集上训练模型model.fit(X_train, y_train)# 在测试集上进行预测y_pred = model.predict(X_test)# 计算准确率accuracy = accuracy_score(y_test, y_pred)return confusion_matrix(y_test, y_pred), y_test.values, y_pred, accuracyif __name__ == "__main__":# 使用函数进行分类任务data_path = "iris.xlsx"confusion_mat, true_labels, predicted_labels, accuracy = decision_tree_classification(data_path)print("真实值:", true_labels)print("预测值:", predicted_labels)print("准确率:{:.2%}".format(accuracy))# 绘制混淆矩阵plt.figure(figsize=(8, 6))sns.heatmap(confusion_mat, annot=True, fmt="d", cmap="Blues")plt.title("Confusion Matrix")plt.xlabel("Predicted Labels")plt.ylabel("True Labels")plt.show()# 用圆圈表示真实值,用叉叉表示预测值# 绘制真实值与预测值的对比结果plt.figure(figsize=(10, 6))plt.plot(true_labels, 'o', label="True Labels")plt.plot(predicted_labels, 'x', label="Predicted Labels")plt.title("True Labels vs Predicted Labels")plt.xlabel("Sample Index")plt.ylabel("Label")plt.legend()plt.show()
7.运行结果