采用辛烷值数据集“spectra_data.mat”(任意数据集均可),介绍贝叶斯线性回归模型的构建和使用流程。
运行结果如下:
训练集预测精度指标如下:
训练集数据的R2为: 1
训练集数据的MAE为: 0.00067884
训练集数据的RMSE为: 0.00088939
测试集预测精度指标如下:
测试集数据的R2为: 0.97755
测试集数据的MAE为: 0.17304
测试集数据的RMSE为: 0.23149
具体代码如下:
%% I. 清空环境变量
clear all
clc%% II. 训练集/测试集产生
%%
% 1. 导入数据
load spectra_data.mat
%plot(NIR)表示x轴为0-60,y轴为矩阵内的值;plot(NIR')表示x轴为0-401;y轴为矩阵内的值%%
% 2. 随机产生训练集和测试集
temp = randperm(size(NIR,1));
% 训练集——50个样本
P_train = NIR(temp(1:50),:)';
T_train = octane(temp(1:50),:)';
N1 = size(P_train,2);
% 测试集——10个样本
P_test = NIR(temp(51:end),:)';
T_test = octane(temp(51:end),:)';
N2 = size(P_test,2);%% III. 数据归一化
[p_train, ps_input] = mapminmax(P_train,0,1);
p_test = mapminmax('apply',P_test,ps_input);
p_train = p_train';
p_test = p_test';
[t_train, ps_output] = mapminmax(T_train,0,1);
t_train = t_train';
%% IV. 贝叶斯回归模型创建、训练及仿真测试
PriorMdl = bayeslm(401, 'ModelType', 'conjugate');
PosteriorMdl = estimate(PriorMdl,p_train,t_train, 'Display', true);% 仿真测试
pre1 = forecast(PosteriorMdl,p_train);
pre2 = forecast(PosteriorMdl,p_test);
%% V. 性能评价
% 数据反归一化
T_pre1 = mapminmax('reverse',pre1,ps_output);
T_pre2 = mapminmax('reverse',pre2,ps_output);
% RMSE
RMSE1 = sqrt(mean((pre1 - T_train').^2));
RMSE2 = sqrt(mean((pre2 - T_test').^2));
% 决定系数R^2
R2_1 = 1 - norm(T_train' - T_pre1)^2 / norm(T_train' - mean(T_train'))^2;
R2_2 = 1 - norm(T_test' - T_pre2)^2 / norm(T_test' - mean(T_test'))^2;
% MAE
mae1 = mean(abs(T_train' - T_pre1));
mae2 = mean(abs(T_test' - T_pre2));
%%
% 3. 结果对比
disp('训练集预测精度指标如下:')
disp(['训练集数据的R2为: ', num2str(R2_1)])
disp(['训练集数据的MAE为: ', num2str(mae1)])
disp(['训练集数据的RMSE为: ' ,num2str(RMSE1) ])
disp( '测试集预测精度指标如下:')
disp(['测试集数据的R2为: ' , num2str(R2_2)])
disp(['测试集数据的MAE为: ' , num2str(mae2)])
disp(['测试集数据的RMSE为: ', num2str(RMSE2)])%% VI. 绘图
figure(1)
plot(1:N1,T_train,'b:*',1:N1,T_pre1,'r-o')
legend('真实值','预测值')
xlabel('预测样本')
ylabel('辛烷值')
string = {'训练集辛烷值含量预测结果对比';['R^2=' num2str(R2_1)]};
title(string)figure(2)
plot(1:N2,T_test,'b:*',1:N2,T_pre2,'r-o')
legend('真实值','预测值')
xlabel('预测样本')
ylabel('辛烷值')
string = {'测试集辛烷值含量预测结果对比';['R^2=' num2str(R2_2)]};
title(string)