Author:赵志乾
Date:2024-06-26
Declaration:All Right Reserved!!!
1. 基本概念
数据集的线性可分性:给定一个数据集
其中,,,,如果存在某个超平面S
能够将数据集的正实例点和负实例点完全正确地划分到超平面的两侧,即对所有的实例i,有,对所有的实例i,有,则称数据集T是线性可分的,否则,称数据集T线性不可分;
感知机:假设输入空间(特征空间)是,输出空间是。输入表示实例的特征向量,对应于输入空间的点;输出表示实例的类别。由输入空间到输出空间的如下函数:
称为感知机。其中w和b为感知机的模型参数,叫作权向量,叫做偏置,表示w和x的内积。sign是符号函数,即
2. 学习策略
2.1 前提假设
训练数据集是线性可分的。
2.2 学习目标
求得一个能够将训练数据集的正实例点和负实例点完全正确分开的分离超平面,即确定感知机模型参数w、b;
2.3 学习策略
将误分类点到超平面S的总距离作为损失函数并将损失函数极小化。最小化损失函数采用随机梯度下降法。
3. 学习算法
3.1 算法描述
输入:训练数据集, 其中,,,;学习率;
输出:w,b;感知机模型为。
(1) 选取初值;
(2) 在训练集中选取数据;
(3) 如果,
(4) 转至(2),直至训练集中没有误分类点。
3.2 算法实现
// step0: 输入数据构造
List<List<Double>> points = new ArrayList<>();
List<Integer> ys = new ArrayList<>();
List<Double> point = new ArrayList<>();
point.add(3d);
point.add(3d);
points.add(point);
ys.add(1);point = new ArrayList<>();
point.add(4d);
point.add(3d);
points.add(point);
ys.add(1);point = new ArrayList<>();
point.add(1d);
point.add(1d);
points.add(point);
ys.add(-1);// 学习率
double eta=1;//**************************学习过程*******************************
// step1: 初始化w和b
List<Double> w = new ArrayList<>();
Double b = 0d;
for (int index = 0; index < points.get(0).size(); index++) {w.add(0d);
}// step2: 迭代学习
Random random = new Random();
// step2.1: 获取误分类点的下标
List<Integer> misclassifiedPointIndexes = getMisclassifiedPointIndexes(w,b,points,ys);
while(misclassifiedPointIndexes.size()>0){int index = misclassifiedPointIndexes.get(random.nextInt(misclassifiedPointIndexes.size()));// 更新w和bfor(int i = 0; i< w.size(); i++){w.set(i,w.get(i)+eta*ys.get(index)*points.get(index).get(i));}b = b+ eta*ys.get(index);misclassifiedPointIndexes = getMisclassifiedPointIndexes(w,b,points,ys);
}// step3: 输出结果
StringBuilder resultStr = new StringBuilder(" w=[");
for(Double weight : w){resultStr.append(weight).append(",");
}
resultStr.setCharAt(resultStr.length()-1,']');
System.out.println("结果:w="+resultStr+" b="+b);//*****************************函数封装***************************************
// 获取误分类点下表
public static List<Integer> getMisclassifiedPointIndexes(List<Double> w,Double b,List<List<Double>> points,List<Integer> ys){List<Integer> result = new ArrayList<>();for(int index=0; index<points.size();index++){if(isMisclassifiedPoint(w,b,points.get(index),ys.get(index))){result.add(index);}}return result;
}// 判定是否为误分类点
public static boolean isMisclassifiedPoint(List<Double> w,Double b, List<Double> point, Integer y){double error = 0;for(int index=0; index< point.size(); index++){error += w.get(index)*point.get(index);}error += b;error *= y;return error<=0;
}
4. 应用场景
感知机是一种二分类的线性分类模型,属于判别模型,是神经网络和支持向量机的基础;其应用分两个过程:
- 学习过程:求出将训练数据集进行线性划分的分离超平面,即确定模型参数w和b;
- 预测过程:用学习到的模型对新的输入实例进行分类;
注意事项:感知机模型使用的前提条件是训练数据集是线性可分的,否则感知机学习算法不收敛,学习过程发生震荡;