支持向量机 (SVM):初学者指南

照片由 Unsplash上的 vackground.com提供

一、说明

        SVM(支持向量机)简单而优雅用于分类和回归的监督机器学习方法。该算法试图找到一个超平面,将数据分为不同的类,并具有尽可能最大的边距。本篇我们将介绍如果最大边距不存在的时候,如何创造最大边距。

二、让我们逐步了解 SVM 

        假设我们有一维湿度数据,红点代表不下雨的日子,蓝点代表下雨的日子。

虚拟一维分类数据。图片由作者提供。

      

        根据我们拥有的一维观测数据,我们可以确定阈值。该阈值将充当分类器。由于我们的数据是一维的,分类器将有一个阈值。如果我们的数据是二维的,我们会使用一条线。

        观察到的数据(最近的数据点)与分类器阈值之间的最短距离称为边距。能够提供最大margin的阈值称为Maximal Margin Classifier (Hyperplane)。在我们的例子中,它将位于双方最接近数据的中点。

最大边际分类器。图片由作者提供。

       

        最大保证金在实践中不太适用。因为它对异常值没有抵抗力。想象一下,我们有一个具有蓝色值的离群红点。在这种情况下,分类器将非常接近蓝点,远离红点。

对异常值敏感。图片由作者提供。

为了改善这一点,我们应该允许异常值和错误分类。我们在系统中引入偏差(并减少方差)。现在,边距称为软边距。使用软间隔的分类器称为支持向量分类器或软间隔分类器。边缘上和软边缘内的数据点称为支持向量。

支持向量。图片由作者提供。

        我们使用交叉验证来确定软边距应该在哪里。

        在 2D 数据中,支持向量分类器是一条线。在 3D 中,它是一个平面。在 4 个或更多维度中,支持向量分类器是一个超平面。从技术上讲,所有 SVC 都是超平面,但在 2D 情况下更容易将它们称为平面。

2D 和 3D。资料来源: https: 

//www.analyticsvidhya.com/blog/2021/05/support-vector-machines/和https://www.sciencedirect.com/topics/computer-science/support-vector-machine

        正如我们在上面看到的,支持向量分类器可以处理异常值并允许错误分类。但是,我们如何处理如下所示的重叠数据呢?

数据重叠。图片由作者提供。

        这就是支持向量机发挥作用的地方。让我们为问题添加另一个维度。我们有特征 X,作为新的维度,我们取 X 的平方并将其绘制在 y 轴上。

由于现在的数据是二维的,我们可以画一条支持向量分类器线。

将问题二维化。图片由作者提供。

        支持向量机获取低维数据,将其移至更高维度,并找到支持向量分类器。

        与我们上面所做的类似,支持向量机使用核函数来查找更高维度的支持向量分类器。核函数是一种函数,它采用原始输入空间中的两个输入数据点,并计算变换后(高维)特征空间中它们对应的特征向量的内积。

内部产品。图片由作者提供。

        核函数允许 SVM 在变换后的特征空间中运行,而无需显式计算变换后的特征向量,这对于大型数据集或复杂的变换来说计算成本可能很高。相反,核函数直接在原始输入空间中计算特征向量之间的内积。这称为内核技巧

三、多项式核

        多项式核用于将输入数据从低维空间变换到高维空间,在高维空间中使用线性决策边界更容易分离类。

        多项式核。

        a和b是两个不同的观测值,r是多项式系数,d是多项式的次数。假设d为 2,r为 1/2。

数学。

        我们最终得到一个点积。第一项(ab)是 x 轴,第二项()是 y 轴。因此,我们需要做的就是计算每对点之间的点积。例如更高维度中两点之间的关系;a = 9,b = 14 => (9 x 114 + 1/2)² = 16000,25。

四、径向内核 (RBF)

        径向核在无限维度中查找支持向量分类器。

        它为距离测试点较近的点分配较高的权重,为较远的点(如最近的邻居)分配较低的权重。较远的观察对数据点的分类影响相对较小。

内核函数。

        它计算两个数据之间的平方距离。Gamma 由交叉验证确定,它会缩放平方距离,这意味着它会缩放两个点彼此之间的影响。在此公式中,随着两点之间的距离增加,该值将接近于零。

        当类之间的决策边界是非线性且复杂的时,径向核特别有用,因为它可以捕获输入特征之间的复杂关系。

五、Python实现

        我们可以使用支持向量机sklearn.

from sklearn.svm import SVC

具有不同内核的 SVC。来源

SVC接受一些参数:

  • C是正则化参数。较大的值会使模型在训练数据上犯更多错误(错误分类)。因此,它的目的是有一个更好的概括。默认值为 1。
  • kernel设置核函数。默认为rbf。其他选择是:Linearpolysigmoidprecompulated。此外,您还可以传递自己的内核函数。
  • degree指定多项式核的次数。仅当内核是多项式时它才可用。默认值为 3。
  • gamma控制核函数的形状。它可用于rbfpolysigmoid内核,较小的 gamma 值使决策边界更平滑,较大的值使决策边界更复杂。默认值是比例,等于 1 / (n_features x X.var())。auto是 1 / n_features。或者您可以传递一个浮点值。
  • coef0仅用于 poly 和 sigmoid 内核。它控制多项式核函数中高阶项的影响。默认值为 0。
  • shrinking控制是否使用收缩启发式。这是一个加速启发式过程。
  • tol是停止标准的容差。当目标函数的变化小于tol时,优化过程将停止。
  • class_weight平衡分类问题中类别的权重。可以将其设置为平衡,以根据课程频率自动调整权重。默认值为“无”。
  • max_iter是迭代极限。-1 表示无限制(默认)。
  • probability指定是否启用概率估计。当它设置为 True 时,估计器将估计类概率,而不仅仅是返回预测的类标签。当probability设置为 True 时,可以使用predict_proba该类的方法来获取新数据点的类标签的估计概率。SVC
  • cache_size用于设置SVM算法使用的内核缓存的大小。当训练样本数量非常大或者内核计算成本很高时,内核缓存会很有用。通过将核评估存储在缓存中,SVM 算法可以在计算正则化参数 C 的不同值的决策函数时重用结果。

SVC 使用具有不同参数的 RBF 内核。来源

一个简单的实现:

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC# cancer data
cancer = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(cancer.data, cancer.target, random_state=42)# parameters
params = {'C': 1.0, 'kernel': 'rbf', 'gamma': 'scale', 
'probability': False, 'cache_size': 200}# training
svc = SVC(**params)
svc.fit(X_train, y_train)# we can use svc's own score function
score = svc.score(X_test, y_test)
print("Accuracy on test set: {:.2f}".format(score))
#Accuracy on test set: 0.95

六、回归

        我们也可以在回归问题中使用支持向量机。

from sklearn.svm import SVR

  • epsilon是指定回归线周围容差大小的参数。回归线由 SVR 算法确定,使其在一定的误差范围内拟合训练数据,该误差范围由参数定义epsilon
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.svm import SVR
from sklearn.metrics import mean_squared_error# the California Housing dataset
california = fetch_california_housing()
X_train, X_test, y_train, y_test = train_test_split(california.data, california.target, random_state=42)# training
svr = SVR(kernel='rbf', C=1.0, epsilon=0.1)
svr.fit(X_train, y_train)# Evaluate the model on the testing data
y_pred = svr.predict(X_test)
mse = mean_squared_error(y_test, y_pred)print("MSE on test set: {:.2f}".format(mse))
#MSE on test set: 1.35

        我们还可以使用 来绘制边界matplotlib

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.svm import SVC# Load the Iris dataset
iris = load_iris()# Extract the first two features (sepal length and sepal width)
X = iris.data[:, :2]
y = iris.target# Create an SVM classifier
svm = SVC(kernel='linear', C=1.0)
svm.fit(X, y)# Create a mesh of points to plot in
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),np.arange(y_min, y_max, 0.02))# Plot the decision boundary
Z = svm.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.contourf(xx, yy, Z, cmap=plt.cm.Paired, alpha=0.8)# Plot the training points
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired)
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.xticks(())
plt.yticks(())
plt.title('SVM decision boundary for Iris dataset')plt.show()

边界。图片由作者提供。

SVM 是一种相对较慢的方法。

import time
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split# Load the breast cancer dataset
data = load_breast_cancer()
X, y = data.data, data.target# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)# Fit a logistic regression model and time it
start_time = time.time()
lr = LogisticRegression(max_iter=1000)
lr.fit(X_train, y_train)
end_time = time.time()
lr_runtime = end_time - start_time# Fit an SVM model and time it
start_time = time.time()
svm = SVC(kernel='linear', C=1.0)
svm.fit(X_train, y_train)
end_time = time.time()
svm_runtime = end_time - start_time# Print the runtimes
print("Logistic regression runtime: {:.3f} seconds".format(lr_runtime))
print("SVM runtime: {:.3f} seconds".format(svm_runtime))"""
Logistic regression runtime: 0.112 seconds
SVM runtime: 0.547 seconds
"""

支持向量机 (SVM) 可能会很慢,原因如下:

  • SVM 是计算密集型的:SVM 涉及解决凸优化问题,对于具有许多特征的大型数据集来说,计算成本可能很高。SVM 的时间复杂度通常至少为 O(n²),其中 n 是数据点的数量,对于非线性内核来说,时间复杂度可能要高得多。
  • 用于调整超参数的交叉验证:SVM需要调整超参数,例如正则化参数C和核超参数,这涉及使用交叉验证来评估不同的超参数设置。这可能非常耗时,尤其是对于大型数据集或复杂模型。
  • 大量支持向量:对于非线性SVM,支持向量的数量会随着数据集的大小或模型的复杂性而快速增加。这可能会减慢预测时间,尤其是在模型需要频繁重新训练的情况下。

我们可以通过尝试以下一些方法来加速 SVM:

  • 使用线性核:线性 SVM 的训练速度比非线性 SVM 更快,因为优化问题更简单。如果您的数据是线性可分的或者不需要高度复杂的模型,请考虑使用线性核。
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC, LinearSVC
import time# Load MNIST digits dataset
mnist = fetch_openml('mnist_784', version=1)
data, target = mnist['data'], mnist['target']
X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2, random_state=42)# Train linear SVM
start_time = time.time()
linear_svc = LinearSVC()
linear_svc.fit(X_train, y_train)
linear_train_time = time.time() - start_time# Train non-linear SVM with RBF kernel
start_time = time.time()
rbf_svc = SVC(kernel='rbf')
rbf_svc.fit(X_train, y_train)
rbf_train_time = time.time() - start_timeprint('Linear SVM training time:', linear_train_time)
print('Non-linear SVM training time:', rbf_train_time)"""
Linear SVM training time: 109.03955698013306
Non-linear SVM training time: 165.98812198638916
"""
  • 使用较小的数据集:如果您的数据集非常大,请考虑使用较小的数据子集进行训练。您可以使用随机抽样或分层抽样等技术来确保子集代表完整数据集。
  • 使用特征选择:如果您的数据集具有许多特征,请考虑使用特征选择技术来减少特征数量。这可以降低问题的维度并加快训练速度。
  • 使用较小的值C:正则化参数C控制最大化边际和最小化分类误差之间的权衡。较小的值C可以产生具有较少支持向量的更简单的模型,这可以加速训练和预测。
import time
from sklearn.datasets import load_breast_cancer
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split# Load the breast cancer dataset
data = load_breast_cancer()
X, y = data.data, data.target# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)for C in [0.1, 1, 10]:start_time = time.time()svm = SVC(kernel='linear', C=C, random_state=42)svm.fit(X_train, y_train)train_time = time.time() - start_timeprint('Training time with C={}: {:.2f}s'.format(C, train_time))"""
Training time with C=0.1: 0.08s
Training time with C=1: 0.55s
Training time with C=10: 0.90s
"""
  • 使用缓存:SVM 涉及计算数据点对之间的内积,这可能会导致计算成本高昂。Scikit-learn 的 SVM 实现包括一个缓存,用于存储常用数据点的内积值,这可以加快训练和预测速度。您可以使用参数调整缓存的大小cache_size
from sklearn.datasets import load_breast_cancer
from sklearn.svm import SVC
import time# Load the dataset
X, y = load_breast_cancer(return_X_y=True)# Train the model without a cache
start_time = time.time()
clf = SVC(kernel='linear', cache_size=1).fit(X, y)
end_time = time.time()
print(f"Training time without cache: {end_time - start_time:.3f} seconds")# Train the model with a cache of 200 MB
start_time = time.time()
clf_cache = SVC(kernel='linear', cache_size=200, max_iter=10000).fit(X, y)
end_time = time.time()
print(f"Training time with cache: {end_time - start_time:.3f} seconds")"""
Training time without cache: 0.535 seconds
Training time with cache: 0.014 seconds
"""

七、结论

        一般来说,SVM 适用于特征数量与样本数量相比相对较少且不同类之间有明显分离余量的分类任务。SVM 还可以处理高维数据以及特征和目标变量之间的非线性关系。然而,SVM 可能不适合非常大的数据集,因为它们可能是计算密集型的并且需要大量内存。

参考文章:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/135599.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android修行手册-实现利用POI将图片插入到Excel中(文末送书)

点击跳转>Unity3D特效百例点击跳转>案例项目实战源码点击跳转>游戏脚本-辅助自动化点击跳转>Android控件全解手册点击跳转>Scratch编程案例点击跳转>软考全系列 👉关于作者 专注于Android/Unity和各种游戏开发技巧,以及各种资源分享&…

ZZ308 物联网应用与服务赛题第G套

2023年全国职业院校技能大赛 中职组 物联网应用与服务 任 务 书 (G卷) 赛位号:______________ 竞赛须知 一、注意事项 1.检查硬件设备、电脑设备是否正常。检查竞赛所需的各项设备、软件和竞赛材料等; 2.竞赛任务中所使用…

lsky Pci-go nas个人图床搭建

①安装PicGo 应用 http://192.168.50.249:18189/api/v1 上传电脑需要有node 和 npm环境,官网下载最新安装板node后,自动会配置npm环境。 ②安装 Lsky-Uploader 获取token: http://www.metools.info/code/post278.html 服务器域名为 Lsky p…

UICollectionView左上对齐布局

最近完成的项目需要左上对齐的瀑布流,每个格子的尺寸不同,可以使用UICollectionView定义不同的尺寸,但是CollectionView的格子高度是相同的,我想要的是这样 左上对齐分别是0、1、2;3、4; 当前只能自定义一个…

OpenCV 输出文本

PutText() 输出文本 OpenCV5 将支持中文字符的输出, 当前版本OpenCV4原生不支持, 可以使用Contrib包FreeType方式实现, 不过比较麻烦.为了省事, 也可以通过将Mat转成bitmap,然后使用GDI方式输出中文字符. 示例代码 /// <summary>/// OpenCV暂时不能支持中文字符输出,显示…

Node.js中的回调地狱

聚沙成塔每天进步一点点 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 欢迎来到前端入门之旅&#xff01;感兴趣的可以订阅本专栏哦&#xff01;这个专栏是为那些对Web开发感兴趣、刚刚踏入前端领域的朋友们量身打造的。无论你是完全的新手还是有一些基础的开发…

JAVA将List转成Tree树形结构数据和深度优先遍历

引言&#xff1a; 在日常开发中&#xff0c;我们经常会遇到需要将数据库中返回的数据转成树形结构的数据返回&#xff0c;或者需要对转为树结构后的数据绑定层级关系再返回&#xff0c;比如需要统计当前节点下有多少个节点等&#xff0c;因此我们需要封装一个ListToTree的工具类…

Python 海龟绘图基础教学教案(二十六)

Python 海龟绘图——第 49 题 题目&#xff1a;绘制下面的图形 解析&#xff1a; 使用二重循环绘制六叶长方形风车。 答案&#xff1a; Python 海龟绘图——第 50 题 题目&#xff1a;绘制下面的图形 解析&#xff1a;使用二重循环绘制由四个相同大小菱形组成的四叶风车图案…

opengauss权限需求

创建角色 "u_rts" 并授予对数据库 "rts_opsdb" 的只读权限&#xff1a; CREATE ROLE u_rts LOGIN PASSWORD Cloud1234; GRANT CONNECT ON DATABASE rts_opsdb TO u_rts; GRANT USAGE ON SCHEMA public TO u_rts; GRANT SELECT ON ALL TABLES IN SCHEMA pub…

STM32MPU6050角度的读取(STM32驱动MPU6050)

注&#xff1a;文末附STM32驱动MPU6050代码工程链接&#xff0c;需要的读者请自取。 一、MPU6050介绍 MPU6050是一款集成了三轴陀螺仪和三轴加速度计的传感器芯片&#xff0c;由英国飞利浦半导体&#xff08;现为恩智浦半导体&#xff09;公司生产。它通过电子接口&#xff08…

多测师肖sir_高级金牌讲师_jenkins搭建

jenkins操作手册 一、jenkins介绍 1、持续集成&#xff08;CI&#xff09; Continuous integration 持续集成 团队开发成员每天都有集成他们的工作&#xff0c;通过每个成员每天至少集成一次&#xff0c;也就意味着一天有可 能多次集成。在工作中我们引入持续集成&#xff0c;通…

大模型时代,开发者成长指南 | 新程序员

【编者按】GPT 系列的面世影响了全世界、各个行业&#xff0c;对于开发者们的感受则最为深切。以 ChatGPT、Github Copilot 为首&#xff0c;各类 AI 编程助手层出不穷。编程范式正在发生前所未有的变化&#xff0c;从汇编到 Java 等高级语言&#xff0c;再到今天以自然语言为特…

Python高级语法----Python多线程与多进程

文章目录 多线程多进程注意事项多线程与多进程是提高程序性能的两种常见方法。在深入代码之前,让我们先用一个简单的比喻来理解它们。 想象你在一家餐厅里工作。如果你是一个服务员,同时负责多个桌子的顾客,这就类似于“多线程”——同一个人(程序)同时进行多项任务(线程…

SSM 线上知识竞赛系统-计算机毕设 附源码 27170

SSM线上知识竞赛系统 摘 要 科技进步的飞速发展引起人们日常生活的巨大变化&#xff0c;电子信息技术的飞速发展使得电子信息技术的各个领域的应用水平得到普及和应用。信息时代的到来已成为不可阻挡的时尚潮流&#xff0c;人类发展的历史正进入一个新时代。在现实运用中&#…

虚幻引擎:如何使用 独立进程模式进行模拟

第一步:先更改配置 第二步,在启动的两个玩家里面,一个设为服务器,一个链接进去地图就可以了 1.设置服务器 2.另一个玩家链接

企业级低代码开发,科技赋能让企业具备“驾驭软件的能力”

科技作为第一生产力&#xff0c;其强大的影响力在各个领域中都有所体现。数字技术&#xff0c;作为科技领域中的一股重要力量&#xff0c;正在对传统的商业模式进行深度的变革&#xff0c;为各行业注入新的生命力。随着数字技术的不断发展和应用&#xff0c;企业数字化转型的趋…

远程运维用什么软件?可以保障更安全?

远程运维顾名思义就是通过远程的方式IT设备等运行、维护。远程运维适用场景包含因疫情居家办公&#xff0c;包含放假期间出现运维故障远程解决&#xff0c;包含项目太远需要远程操作等等。但远程运维过程存在一定风险&#xff0c;安全性无法保障&#xff0c;所以一定要选择靠谱…

如何快速教你看自己电脑cpu是几核几线程

目录 一、我们日常中说的电脑多少核多少线程&#xff0c;很多人具体不知道什么意思&#xff0c;下面举例4核和4线程什么意思。二、那么4线程又是怎么回事呢&#xff1f;三、那么知道了上面的介绍后怎么看一台电脑是几核&#xff0c;几线程呢&#xff1f; 一、我们日常中说的电脑…

​软考-高级-信息系统项目管理师教程 第四版【第24章-法律法规与标准规范-思维导图】​

软考-高级-信息系统项目管理师教程 第四版【第24章-法律法规与标准规范-思维导图】 课本里章节里所有蓝色字体的思维导图

springboot 项目升级 2.7.16 踩坑

记录一下项目更新版本依赖踩坑 这个是项目最早的版本依赖 这里最初是最初是升级到 2.5.7 偷了个懒 这个版本的兼容性比较强 就选了这版本 也不用去修改就手动的去换了一下RabbitMQ的依赖 因为这边项目有AMQP 风险预警 1.spring-amqp版本低于2.4.17的用户应升级到2.4.17 2.spri…