ANN(MLP) 三种预测

目录

介绍: 

 一、Mlp for binary classification

数据: 

模型: 

预测:

二、Mlp for Multiclass Classification

数据:

模型:

预测:

三、MLP for Regression

数据:

模型: 

 预测:

介绍: 

多层感知器(Multilayer Perceptron,MLP)是一种基于人工神经网络的机器学习算法。它由多个神经元(也称为节点)组成,这些神经元排列在不同的层中,并且每个神经元都与上一层的神经元相连。

MLP的基本结构包括输入层、输出层和一个或多个隐藏层。输入层接收输入数据,输出层产生最终的输出结果。隐藏层在输入层和输出层之间,它们的作用是对输入数据进行抽象和特征提取。

每个神经元都有一个与之关联的权重,这些权重用于计算神经元的加权和。加权和经过激活函数的处理,最终产生神经元的输出。常见的激活函数包括Sigmoid函数、ReLU函数、Tanh函数等。激活函数的作用是引入非线性,以增加模型的表达能力。

MLP的训练过程主要涉及两个步骤:前向传播和反向传播。在前向传播中,输入数据通过网络,每个神经元计算加权和并通过激活函数传递给下一层。在反向传播中,根据网络输出和真实标签之间的误差,通过梯度下降法调整权重,以使预测结果尽可能接近真实值。

MLP可以用于分类和回归问题。在分类问题中,MLP可以通过输出层的激活函数(通常是Softmax函数)将输入数据映射到不同的类别。在回归问题中,MLP可以通过输出层的线性激活函数(通常是恒等函数)来预测连续值。

MLP具有一些优点,如能够学习复杂的非线性关系,适用于大量数据和特征的情况,并且能够处理缺失数据。但是,MLP也存在一些缺点,如对初始权重的依赖性,容易过拟合和计算复杂性较高。

总而言之,MLP是一种强大的机器学习算法,可以应用于各种任务,包括图像和语音识别、自然语言处理、推荐系统等。

 一、Mlp for binary classification

数据: 

# mlp for binary classification
from pandas import read_csv
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense
# load the dataset
df = read_csv('ionosphere.csv', header=None)

 

模型: 

# split into input and output columns
X, y = df.values[:, :-1], df.values[:, -1]# ensure all data are floating point values
X = X.astype('float32')y = LabelEncoder().fit_transform(y)#改成0、1# split into train and test datasets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)
print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)# determine the number of input features
n_features = X_train.shape[1]# define model
model = Sequential()#串型
model.add(Dense(10, activation='relu', kernel_initializer='he_normal', input_shape=(n_features,)))
model.add(Dense(8, activation='relu', kernel_initializer='he_normal'))
model.add(Dense(1, activation='sigmoid'))
# compile the model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])# fit the model
model.fit(X_train, y_train, epochs=150, batch_size=32, verbose=1)# evaluate the model
loss, acc = model.evaluate(X_test, y_test, verbose=1)
print('Test Accuracy: %.3f' % acc)

预测:

# make a prediction
row = [1,0,0.99539,-0.05889,0.85243,0.02306,0.83398,-0.37708,1,0.03760,0.85243,-0.17755,0.59755,-0.44945,0.60536,-0.38223,0.84356,-0.38542,0.58212,-0.32192,0.56971,-0.29674,0.36946,-0.47357,0.56811,-0.51171,0.41078,-0.46168,0.21266,-0.34090,0.42267,-0.54487,0.18641,-0.45300]
yhat = model.predict([row])
print('Predicted: %.3f' % yhat)
if yhat >= 1/2: yhat = 'G'
else:yhat = 'B'
print('Predicted: ', yhat)

二、Mlp for Multiclass Classification

数据:

#ANN(MLP) for Multiclass Classification 预测蓝蝴蝶花品种 ('setosa', 'versicolor', 'virginica')
from numpy import argmax
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns; sns.set(style='white')
%matplotlib inline
from sklearn import decomposition
from sklearn import datasets# Loading the dataset
iris = datasets.load_iris()
X = iris.data
y = iris.target

模型:

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)
print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)
# determine the number of input features
n_features = X_train.shape[1]# define model
model = Sequential()
model.add(Dense(10, activation='relu', kernel_initializer='he_normal', input_shape=(n_features,)))
model.add(Dense(8, activation='relu', kernel_initializer='he_normal'))
model.add(Dense(3, activation='softmax'))
# compile the model
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])# fit the model
model.fit(X_train, y_train, epochs=150, batch_size=32, verbose=1)

预测:

# evaluate the model
loss, acc = model.evaluate(X_test, y_test, verbose=0)
print('Test Accuracy: %.3f' % acc)
# make a prediction
row = [8.1,3.8,8.4,8.2]
#row = [2.1,3.5,3.4,2.2]
#row = [6.1,6.5,6.4,6.2]
yhat = model.predict([row])
print('Predicted: %s (class=%d)' % (yhat, argmax(yhat)))

三、MLP for Regression

数据:

from numpy import sqrt
import pandas as pd
from sklearn.model_selection import train_test_split
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense
# load the dataset
boston=pd.read_csv('boston.csv')

模型: 

y=boston["MEDV"]X=boston.iloc[:,:-1]# define model
model = Sequential()
model.add(Dense(10, activation='relu', kernel_initializer='he_normal', input_shape=(n_features,)))
model.add(Dense(8, activation='relu', kernel_initializer='he_normal'))
model.add(Dense(8, activation='relu', kernel_initializer='he_normal'))
model.add(Dense(1))
# compile the model
model.compile(optimizer='adam', loss='mse')
# fit the model
model.fit(X_train, y_train, epochs=150, batch_size=32, verbose=0)#loss, acc = model.evaluate(X_test, y_test, verbose=0)
#print('Test Accuracy: %.3f' % acc)# evaluate the model
error = model.evaluate(X_test, y_test, verbose=0)
print('MSE: %.3f, RMSE: %.3f' % (error, sqrt(error)))

 

 预测:

# make a prediction
row = [0.00632,18.00,2.310,0,0.5380,6.5750,65.20,4.0900,1,296.0,15.30,396.90,4.98]
yhat = model.predict([row])
print('Predicted: %.3f' % yhat)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/665301.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

中科大计网学习记录笔记(五):协议层次和服务模型

前言: 学习视频:中科大郑烇、杨坚全套《计算机网络(自顶向下方法 第7版,James F.Kurose,Keith W.Ross)》课程 该视频是B站非常著名的计网学习视频,但相信很多朋友和我一样在听完前面的部分发现信…

《最新出炉》系列入门篇-Python+Playwright自动化测试-10-标签页操作(tab)

1.简介 标签操作其实也是基于浏览器上下文(BrowserContext)进行操作的,而且宏哥在之前的BrowserContext也有提到过,但是有的童鞋或者小伙伴还是不清楚怎么操作,或者思路有点模糊,因此今天单独来对其进行讲…

苹果的ipad可能会缓存vue项目的数据或者pinia数据

如果你发现开发的vue项目在ipad上出现了异常,比如数据出现NaN的情况,或者computed计算属性没生效,或者pinia里面的数据没生效,可能就是ipad浏览器safari缓存了数据导致的,只需要清空safari里面缓存的数据就可以了&…

RT-Thread线程管理(使用篇)

layout: post title: “RT-Thread线程管理” date: 2024-1-26 15:39:08 0800 tags: RT-Thread 线程管理(使用篇) 之后会做源码分析 线程是任务的载体,是RTT中最基本的调度单位。 线程执行时的运行环境称为上下文,具体来说就是各个变量和数据&#xff0c…

Kotlin-集成SpringBoot+MyBatis+代码生成器

目录 一、相关版本 二、Maven因引入相关依赖 三、SpringBoot配置文件 四、代码生成工具 五、实现用户服务模块案例 1、Controller 2、Service 3、Entity 4、Mapper 5、接口测试 一、相关版本 工具版本Idea2022.3.2Springboot2.7.12MyBatis3.5.3.1MySQL8.0.28JDK1.8 …

Python详细教程

一、Python简历 Python 是一个高层次的结合了解释性、编译性、互动性和面向对象的脚本语言。 Python 的设计具有很强的可读性,相比其他语言经常使用英文关键字,其他语言的一些标点符号,它具有比其他语言更有特色语法结构。 Python 是一种解…

MySQL原理(五)事务

一、介绍: 1、介绍: 在计算机术语中,事务(Transaction)是访问并可能更新数据库中各种数据项的一个程序执行单元(unit)。事务是恢复和并发控制的基本单位。 2、事务的4大特性 原子性、一致性、隔离性、持久性。这四个属性通常称为ACID特性…

LaTeX表格:合并单元格、文字旋转90度并居中

在LaTeX表格中,如何使用\multirow合并单元格,并将单元格中的文字旋转九十度,并且居中呢? 首先引入graphicx、multirow和array包: \usepackage{graphicx} \usepackage{multirow} \usepackage{booktabs}然后定义一种新…

DoubleEnsemble:基于样本重加权和特征选择的金融数据分析方法

现代机器学习模型(如深度神经网络和梯度提升决策树)由于其提取复杂非线性模式的优越能力,在金融市场预测中越来越受欢迎。然而,由于金融数据集的信噪比非常低,并且是非平稳的,复杂的模型往往很容易过拟合。…

「递归算法」:Pow(x,n)

一、题目 实现 pow(x, n) ,即计算 x 的整数 n 次幂函数(即,xn )。 示例 1: 输入:x 2.00000, n 10 输出:1024.00000示例 2: 输入:x 2.10000, n 3 输出:9…

使用Arcgis对欧洲雷达高分辨率降水数据重投影

当前需要使用欧洲高分辨雷达降水数据,但是这个数据的投影问题非常头疼。实际的投影应该长这样(https://gist.github.com/kmuehlbauer/645e42a53b30752230c08c20a9c964f9?permalink_comment_id2954366https://gist.github.com/kmuehlbauer/645e42a53b307…

深入了解 Ansible:全面掌握自动化 IT 环境的利器

本文以详尽的篇幅介绍了 Ansible 的方方面面,旨在帮助读者从入门到精通。无论您是初学者还是有一定经验的 Ansible 用户,都可以在本文中找到对应的内容,加深对 Ansible 的理解和应用。愿本文能成为您在 Ansible 自动化旅程中的良师益友&#…

故障诊断 | 一文解决,LSTM长短期记忆神经网络故障诊断(Matlab)

文章目录 效果一览文章概述专栏介绍模型描述源码设计参考资料效果一览 文章概述 故障诊断模型 | Maltab实现LSTM长短期记忆神经网络故障诊断 专栏介绍 订阅【故障诊断】专栏,不定期更新机器学习和深度学习在故障诊断中的应用;订阅

[基础IO]文件描述符{重定向/perror/磁盘结构/inode/软硬链接}

文章目录 1. 再识重定向2.浅谈perror()3.初始文件系统4.软硬链接 1. 再识重定向 图解./sf > file.txt 2>&1 1中内容拷贝给2 使得2指向file 再学一个 把file的内容传给cat cat拿到后再给file2 2.浅谈perror() open()接口调用失败返回-1,并且错误码errno被适当的设置,…

虚拟机Windows Server 2016 安装 MySQL8

目录 一、下载MySQL8 1.下载地址: 2.创建my.ini文件 二、安装步骤 第一步:命令窗口 第二步:切换目录 第三步:安装服务 第四步:生成临时密码 第五步:启动服务 第六步: 修改密码 三…

【服务器搭建】快速完成幻兽帕鲁服务器的搭建及部署【零基础上手】

推荐阅读 CSDN主页GitHub开源地址Unity3D插件分享简书地址我的个人博客 大家好,我是佛系工程师☆恬静的小魔龙☆,不定时更新Unity开发技巧,觉得有用记得一键三连哦。 一、前言 教程详戳:不需要懂技术,1分钟幻兽帕鲁服…

stable diffusion学习笔记——高清修复

ai画图中通常存在以下痛点: 受限于本地设备的性能(主要是显卡显存),无法跑出分辨率较高的图片。生图的时候分辨率一调大就爆显存。即便显存足够。目前主流的模型大多基于SD1.0和SD1.5,这些模型在训练的时候通常使用小…

【Git】01 Git介绍与安装

文章目录 一、版本控制系统二、Git三、Windows安装Git3.1 下载Git3.2 安装3.3 检查 四、Linux安装Git4.1 YUM安装4.2 源码安装 五、配置Git5.1 配置用户名和邮箱5.2 配置级别5.3 查看配置 六、总结 一、版本控制系统 版本控制系统,Version Control System&#xff…

大数据分析|大数据分析的三类核心技术

文献来源:Saggi M K, Jain S. A survey towards an integration of big data analytics to big insights for value-creation[J]. Information Processing & Management, 2018, 54(5): 758-790. 下载链接:链接:https://pan.baidu.com/s/1…

2024.2.3 寒假训练记录(17)

补一下牛客,菜得发昏了,F搞了两个小时都没搞出来,不如去开H了 还没补完 剩下的打了atc再来 文章目录 牛客 寒假集训1A DFS搜索牛客 寒假集训1B 关鸡牛客 寒假集训1C 按闹分配牛客 寒假集训1D 数组成鸡牛客 寒假集训1E 本题又主要考察了贪心牛…