目录
一、实验数学原理
二、算法实现步骤
三、实例分析
四、实验结果及分析
一、实验数学原理
线性判别分析((Linear Discriminant Analysis ,简称 LDA)是一种经典的线性学习方法,在二分类问题上因为最早由 [Fisher,1936] 提出,亦称 ”Fisher 判别分析“。并且LDA也是一种监督学习的降维技术,也就是说它的数据集的每个样本都有类别输出。
其思想主要为:给定训练样例集,设法将样例投影到一条直线上,使得同样样例的投影尽可能接近、异样样例的投影点尽可能远离;在对新样本进行分类时,将其投影到同样的这条直线上,再根据投影点的位置来确定新样本的类别。其实可以用一句话概括:就是“投影后类内方差最小,类间方差最大”。
可以用二维图表示如下:其中“+”,“-”分别表示两种不同类的数据集,椭圆表示数据集的外轮廓,虚线表示点的投影线的最短距离,虚线与实线的交点为数据集的投影点,绿色方块和蓝色方块分别这两类原始数据的中心点,实心三角和圆分别为投影后数据的中心点。
给定一个数据集D,将数据集划分训练集和测试集,根据类别对数据集再次进行划分,然后对每个类别的数据集进行求值
如:
二、算法实现步骤
二分类LDA线性判别实验算法:
本次二分类LDA线性判别实验步骤:
1.创建主程序LDA.m,拟合函数程序LDAfit.m,预测函数程序LDApredict.m,绘制分类图程序LDAplot.m
2.对完整数据集进行读取和划分,以ID1-ID75样本作为为训练集,ID76-ID95样本作为测试集,ID96-ID115为待预测集。
3.对训练集样本进行拟合求出w
4.对测试集样本进行预测并对预测结果进行评价
5.对待预测集样本进行分类预测
三、实例分析
部分数据展示如下:
ID | x | y | lable |
1 | 8.122 | 3.251 | 1 |
2 | 1.647 | 4.796 | 0 |
3 | 2.270 | 7.128 | 0 |
4 | 5.680 | 4.563 | 1 |
5 | 7.915 | 4.184 | 1 |
6 | 5.565 | 3.513 | 1 |
7 | 2.892 | 6.166 | 0 |
8 | 6.624 | 4.329 | 1 |
9 | 1.740 | 5.262 | 0 |
10 | 6.417 | 3.134 | 1 |
在本次实验中,只讲最大特征值法求解,其他的看后续情况吧,如普通法、奇异值分解法等。
主程序LDA.m 创建:
close all;clc;clear;
%% 第一步:读取数据
data = xlsread('D:\桌面\LDAdata.xlsx');%% 第二步:分离数据集
%分离训练集
train_X = data(1:75,2:3);
train_lei = data(1:75,4);
%分测试集
test_X = data(76:95,2:3);
test_lei = data(76:95,4);
%分离预测集
predict_X = data(96:end,2:3);%% 第三步:对训练集进行训练和预测
% 求解w
[x0,x1,y0,y1,w1]=LDAfit(train_X,train_lei);
% 定义分类点
u = mean(train_X*w1);
% 预测结果的分类情况
[y1_pred,lei1] = LDApredict(train_X,w1,1,u);
% 对训练集分类绘图
LDAplot(x0,x1,y0,y1,w1,0,u)
title('训练集分类图')
% 预测结果的准确率
accu=1-sum(abs(lei1'-train_lei))./length(train_lei);
fprintf('\n训练集的预测的准确率为:%d\n',accu)
fprintf('\n')
%
%% 第四步:对测试集集预测并输出结果
% 此处不求解w,仍用上面训练集那个
[x0,x1,y0,y1,w]=LDA2fit(test_X,test_lei);
% 对测试集的预测结果进行分类
[y2_pred,lei2] = LDApredict(test_X,w1,76,u);
% 预测结果的准确率
accu=1-sum(abs(lei2'-test_lei))./length(test_lei);
fprintf('\n测试集的预测的准确率为:%d\n',accu)
fprintf('\n')
% 对测试集分类进行绘图
figure() %重新开一个图,不在上一个图上继续绘制
LDAplot(x0,x1,y0,y1,w1,0,u)
title('测试集集分类图')%% 第五步:对预测集进行预测并输出预测结果
[y3_pred,lei3] = LDApredict(predict_X,w1,96,u);
%对预测集进行绘图
figure() %重新开一个图,不在上一个图上继续绘制
%由结果清楚知道前10为类比0,后10为类别1
x0 = data(96:105,2);
x1 = data(106:end,2);
y0 = data(96:105,3);
y1 = data(106:end,3);
LDAplot(x0,x1,y0,y1,w1,0,u)
title('预测集的分类图')
拟合函数程序LDAfit.m 创建:
%% X为数据集在此数据集只有2个变量,lei代表类别,此只适用于2二分类
% 如果数据集的变量个数有变化,请自行改正
%%
function [x0,x1,y0,y1,w]=LDA2fit(X,lei)
%% 变量解释
% X为待拟合的数据集集,lei为待拟合数据集中的类别列数据集
%% 第一步:按类别分离X0,X1
x = X(:,1);
Y = X(:,2);
x0 = x(find(lei==0));
x1 = x(find(lei==1));
y0 = Y(find(lei==0));
y1 = Y(find(lei==1));
X0= [x0,y0];
X1= [x1,y1];%% 第二步:求均值
u0 = mean(X0,1);
u1 = mean(X1,1);%% 第三步:求X0,X1的行数
% ,并按其构建相应均值矩阵
% 求行数的小,1代表列,2代表行
n0 = size(X0,1);
n1 = size(X1,1);
% % 构建均值矩阵
% U0 = repmat(u0,n0,1);
% U1 = repmat(u1,n1,1);%% 第四步:计算协方差矩阵
%X0的协方差矩阵
% E0 = sum((X0-U0)'*(X0-U0));
E0 = cov(X0,0);
%X1的协方差矩阵
% E1 = sum((X1-U1)'*(X1-U1));
E1 = cov(X1,0);%% 第五步:求类内散度矩阵
% Sw = (n0*E0 + n1*E1)/(n0+n1);
Sw = E0 + E1;%% 第六步:求类间散度矩阵
Sb = (u0-u1)'*(u0-u1);%% 第七步:求最大特征向量和特征值
[V,D] = eig(inv(Sw)*Sb);
[a1,a2] = max(max(D));
%% 第八步:求解w
w = V(:,a2);end
预测函数程序LDApredict.m 创建:
function [y_pred,lei] = LDApredict(x,w,a,u)
%% 变量解释
% x 为待预测集 ; w 为训练集的w ; a 为变量起始序列号
% u 为分类点
%% 对预测结果进行分类
disp('----------------本次预测开始--------------------')
lei=[];
y_pred=[];
for i = 1:size(x,1)h = w' * [x(i,1),x(i,2)]';y_pred(i) = h;lei(i) = 1*(h>u);
% fprintf('第%d个数据的类别属于:%d\n',a,lei(i))
% a = a+1;
end
fprintf('第%d号到%d号的数据的预测分类结果为:\n',a,a+length(x)-1)
disp(lei)
disp('----------------本次预测结束--------------------')
绘制分类图程序LDAplot.m 创建:
function LDAplot(x0,x1,y0,y1,w,b,u)
%% 变量解释
% x0,x1,y0,y1 为待预测集中的x0,x1,y0,y1 ;
% w 为训练集的w ; b 为方程的截距
% u 为分类点
%% 第一步:绘制类别0的原始数据集
plot(x0,y0,'or')
%% 第二步:绘制类别1的原始数据集
hold on
plot(x1,y1,'ob')
%% 第三步:绘制原始数据的各类别的中心点
hold on
plot(mean(x0),mean(y0),'*c')
plot(mean(x1),mean(y1),'*g')
%% 第四步:绘制投影线
hold on
k=w(2)/w(1);
x = -3:5;
yy = k*x+b;
plot(x,yy,'k')
%% 第五步:绘制投影线的垂线
xx = linspace(0,6,80);
yyy = (u-xx*w(1))/w(2);
plot(xx,yyy,'--m')
%% 第六步:绘制各类别的投影点
plot((k*(y0-b)+x0)/(k^2+1),k*(k*(y0-b)+x0)/(k^2+1)+b,'+r');
plot((k*(y1-b)+x1)/(k^2+1),k*(k*(y1-b)+x1)/(k^2+1)+b,'+b');
%% 第七步:绘制各类别的投影点的中心点
plot(mean((k*(y0-b)+x0)/(k^2+1)),mean(k*(k*(y0-b)+x0)/(k^2+1)+b),'<c','markerfacecolor','c');
plot(mean((k*(y1-b)+x1)/(k^2+1)),mean(k*(k*(y1-b)+x1)/(k^2+1)+b),'>g','markerfacecolor','g');
%% 第八步:对x轴,y轴添加标签
xlabel('变量x')
ylabel('变量y')
%% 第九步:对以上各部分添加图例
legend('类别0','类别1','类别0中心点','类别1中心点','投影线','分类线','类别0投影点','类别1投影点','类别0投影点中心点','类别1投影点中心点','Location','EastOutside')end
四、实验结果及分析
结果1:
根据给出数据的ID1-ID75样本为训练集进行学习,得到如下学习样本图形输出
结果2:
通过由训练集得到的w值对测试集进行预测,得到预测结果如下:
让预测分类结果与测试集原分类结果进行比较,得到准确率:1
结果3:
继续以训练集得到w值对测试集进行预测,得预测结果如下:
本次分享就到这里了,如果有任何错误请及时联系哦,如有转载请标明原处!