本周学习内容为SVM的基本原理和运用。
参考资料:耳东陈:零基础学SVM—Support Vector Machine(一)
1、什么是SVM
SVM的全称是Support Vector Machine,即支持向量机,主要用于解决模式识别领域中的数据分类问题,属于有监督学习算法的一种。SVM要解决的问题可以用一个经典的二分类问题加以描述。如图1所示,红色和蓝色的二维数据点显然是可以被一条直线分开的,在模式识别领域称为线性可分问题。然而将两类数据点分开的直线显然不止一条,这是线性分类器的特征。不同的分类方案,分类直线到两类向量之间的面积是不一样的,术语称为“决策面”,svm试图找到一个直线量两个类分类,并且面积最大的最优的解,而这个直线取决于两个类型最边界的三个向量。
svm的优缺点:
优点:模型训练鲁棒性很强,对特征向量、噪音数据不敏感,预测快,占内存少。
缺点:训练耗时,时间复杂度是O(N的立方),至少是N的平方,对于太多维度的寻里昂非常耗时和耗费资源。
svm是一种非常优秀的分类器,老师说在深度学习出现之前,它统治机器学习领域20年。
2、svm的优化
现实场景中,使用svm训练可能遇到很多的问题,比如奇怪的数据,比如异常数据outlier的样本数据,一种分类中存在另一种分类的样本,我们需要通过带松弛变脸的svm进行处理,避免无解。松弛变量是常数C,常数C可控制松弛的量度。
SVM虽然解决的是二分类问题,但可以扩展到多个分类问题。
1)、可以对每个需要识别的类型分别训练一个分类模型,用于预测的时候,哪个分类器的预测值高(wtix),就取哪个。
2)、对K个类别,训练k*(k-1)/2个Svm,预测的时候采用投票方式决定,哪个次数多选哪个。
以上方法适用于每种分类模型。
3、作业:
用特征数据预测蘑菇是否有毒。
import pandas as pd
import numpy as np# 导入数据
mush_df = pd.read_csv('./data/mushrooms.csv')# 将值从字母转换为
mush_df_encoded = pd.get_dummies(mush_df) ##独热(ont-hot)编码,将离散的值转(ABC)换为数字mush_df.head()# 将特征和类别标签分布赋值给 X 和 y
X_mush = mush_df_encoded.iloc[:,2:]
y_mush = mush_df_encoded.iloc[:,1]
#查看特征数据
X_mush.head()
#查看标签
y_mush.head()
#训练svm
from sklearn.svm import SVC
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline# TODO
pca = PCA(n_components=100, whiten=True, random_state=42)
##这里采用的核函数为线性分类器linear,经过测试效果比非线性的好
svc = SVC( kernel='linear',class_weight='balanced')
model = make_pipeline(pca, svc)
from sklearn.model_selection import train_test_split
Xtrain, Xtest, ytrain, ytest = train_test_split(X_mush, y_mush, random_state=41)
from sklearn.model_selection import GridSearchCV# TODO
param_grid = {'svc__C': [1, 5, 10, 50]}
grid = GridSearchCV(model, param_grid)%time grid.fit(Xtrain, ytrain)
print(grid.best_params_)
# TODO
model = grid.best_estimator_
yfit = model.predict(Xtest)
from sklearn.metrics import classification_report
print(classification_report(ytest, yfit))
##结果展示,全是1。。。
precision recall f1-score support0 1.00 1.00 1.00 10471 1.00 1.00 1.00 984avg / total 1.00 1.00 1.00 2031