ANN(MLP) 三种预测

目录

介绍: 

 一、Mlp for binary classification

数据: 

模型: 

预测:

二、Mlp for Multiclass Classification

数据:

模型:

预测:

三、MLP for Regression

数据:

模型: 

 预测:

介绍: 

多层感知器(Multilayer Perceptron,MLP)是一种基于人工神经网络的机器学习算法。它由多个神经元(也称为节点)组成,这些神经元排列在不同的层中,并且每个神经元都与上一层的神经元相连。

MLP的基本结构包括输入层、输出层和一个或多个隐藏层。输入层接收输入数据,输出层产生最终的输出结果。隐藏层在输入层和输出层之间,它们的作用是对输入数据进行抽象和特征提取。

每个神经元都有一个与之关联的权重,这些权重用于计算神经元的加权和。加权和经过激活函数的处理,最终产生神经元的输出。常见的激活函数包括Sigmoid函数、ReLU函数、Tanh函数等。激活函数的作用是引入非线性,以增加模型的表达能力。

MLP的训练过程主要涉及两个步骤:前向传播和反向传播。在前向传播中,输入数据通过网络,每个神经元计算加权和并通过激活函数传递给下一层。在反向传播中,根据网络输出和真实标签之间的误差,通过梯度下降法调整权重,以使预测结果尽可能接近真实值。

MLP可以用于分类和回归问题。在分类问题中,MLP可以通过输出层的激活函数(通常是Softmax函数)将输入数据映射到不同的类别。在回归问题中,MLP可以通过输出层的线性激活函数(通常是恒等函数)来预测连续值。

MLP具有一些优点,如能够学习复杂的非线性关系,适用于大量数据和特征的情况,并且能够处理缺失数据。但是,MLP也存在一些缺点,如对初始权重的依赖性,容易过拟合和计算复杂性较高。

总而言之,MLP是一种强大的机器学习算法,可以应用于各种任务,包括图像和语音识别、自然语言处理、推荐系统等。

 一、Mlp for binary classification

数据: 

# mlp for binary classification
from pandas import read_csv
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense
# load the dataset
df = read_csv('ionosphere.csv', header=None)

 

模型: 

# split into input and output columns
X, y = df.values[:, :-1], df.values[:, -1]# ensure all data are floating point values
X = X.astype('float32')y = LabelEncoder().fit_transform(y)#改成0、1# split into train and test datasets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)
print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)# determine the number of input features
n_features = X_train.shape[1]# define model
model = Sequential()#串型
model.add(Dense(10, activation='relu', kernel_initializer='he_normal', input_shape=(n_features,)))
model.add(Dense(8, activation='relu', kernel_initializer='he_normal'))
model.add(Dense(1, activation='sigmoid'))
# compile the model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])# fit the model
model.fit(X_train, y_train, epochs=150, batch_size=32, verbose=1)# evaluate the model
loss, acc = model.evaluate(X_test, y_test, verbose=1)
print('Test Accuracy: %.3f' % acc)

预测:

# make a prediction
row = [1,0,0.99539,-0.05889,0.85243,0.02306,0.83398,-0.37708,1,0.03760,0.85243,-0.17755,0.59755,-0.44945,0.60536,-0.38223,0.84356,-0.38542,0.58212,-0.32192,0.56971,-0.29674,0.36946,-0.47357,0.56811,-0.51171,0.41078,-0.46168,0.21266,-0.34090,0.42267,-0.54487,0.18641,-0.45300]
yhat = model.predict([row])
print('Predicted: %.3f' % yhat)
if yhat >= 1/2: yhat = 'G'
else:yhat = 'B'
print('Predicted: ', yhat)

二、Mlp for Multiclass Classification

数据:

#ANN(MLP) for Multiclass Classification 预测蓝蝴蝶花品种 ('setosa', 'versicolor', 'virginica')
from numpy import argmax
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns; sns.set(style='white')
%matplotlib inline
from sklearn import decomposition
from sklearn import datasets# Loading the dataset
iris = datasets.load_iris()
X = iris.data
y = iris.target

模型:

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)
print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)
# determine the number of input features
n_features = X_train.shape[1]# define model
model = Sequential()
model.add(Dense(10, activation='relu', kernel_initializer='he_normal', input_shape=(n_features,)))
model.add(Dense(8, activation='relu', kernel_initializer='he_normal'))
model.add(Dense(3, activation='softmax'))
# compile the model
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])# fit the model
model.fit(X_train, y_train, epochs=150, batch_size=32, verbose=1)

预测:

# evaluate the model
loss, acc = model.evaluate(X_test, y_test, verbose=0)
print('Test Accuracy: %.3f' % acc)
# make a prediction
row = [8.1,3.8,8.4,8.2]
#row = [2.1,3.5,3.4,2.2]
#row = [6.1,6.5,6.4,6.2]
yhat = model.predict([row])
print('Predicted: %s (class=%d)' % (yhat, argmax(yhat)))

三、MLP for Regression

数据:

from numpy import sqrt
import pandas as pd
from sklearn.model_selection import train_test_split
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense
# load the dataset
boston=pd.read_csv('boston.csv')

模型: 

y=boston["MEDV"]X=boston.iloc[:,:-1]# define model
model = Sequential()
model.add(Dense(10, activation='relu', kernel_initializer='he_normal', input_shape=(n_features,)))
model.add(Dense(8, activation='relu', kernel_initializer='he_normal'))
model.add(Dense(8, activation='relu', kernel_initializer='he_normal'))
model.add(Dense(1))
# compile the model
model.compile(optimizer='adam', loss='mse')
# fit the model
model.fit(X_train, y_train, epochs=150, batch_size=32, verbose=0)#loss, acc = model.evaluate(X_test, y_test, verbose=0)
#print('Test Accuracy: %.3f' % acc)# evaluate the model
error = model.evaluate(X_test, y_test, verbose=0)
print('MSE: %.3f, RMSE: %.3f' % (error, sqrt(error)))

 

 预测:

# make a prediction
row = [0.00632,18.00,2.310,0,0.5380,6.5750,65.20,4.0900,1,296.0,15.30,396.90,4.98]
yhat = model.predict([row])
print('Predicted: %.3f' % yhat)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/665301.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

中科大计网学习记录笔记(五):协议层次和服务模型

前言: 学习视频:中科大郑烇、杨坚全套《计算机网络(自顶向下方法 第7版,James F.Kurose,Keith W.Ross)》课程 该视频是B站非常著名的计网学习视频,但相信很多朋友和我一样在听完前面的部分发现信…

《最新出炉》系列入门篇-Python+Playwright自动化测试-10-标签页操作(tab)

1.简介 标签操作其实也是基于浏览器上下文(BrowserContext)进行操作的,而且宏哥在之前的BrowserContext也有提到过,但是有的童鞋或者小伙伴还是不清楚怎么操作,或者思路有点模糊,因此今天单独来对其进行讲…

苹果的ipad可能会缓存vue项目的数据或者pinia数据

如果你发现开发的vue项目在ipad上出现了异常,比如数据出现NaN的情况,或者computed计算属性没生效,或者pinia里面的数据没生效,可能就是ipad浏览器safari缓存了数据导致的,只需要清空safari里面缓存的数据就可以了&…

开始学习第三十天

最近天气有点不太好 都是雨雪 今天也就正式学了第一个月咯 希望以后接着加油 今天是小年 大家小年快乐呀 明天要去奶奶家 在那边就没办法学习了 等年三十之后回来接着学习!

RT-Thread线程管理(使用篇)

layout: post title: “RT-Thread线程管理” date: 2024-1-26 15:39:08 0800 tags: RT-Thread 线程管理(使用篇) 之后会做源码分析 线程是任务的载体,是RTT中最基本的调度单位。 线程执行时的运行环境称为上下文,具体来说就是各个变量和数据&#xff0c…

Kotlin-集成SpringBoot+MyBatis+代码生成器

目录 一、相关版本 二、Maven因引入相关依赖 三、SpringBoot配置文件 四、代码生成工具 五、实现用户服务模块案例 1、Controller 2、Service 3、Entity 4、Mapper 5、接口测试 一、相关版本 工具版本Idea2022.3.2Springboot2.7.12MyBatis3.5.3.1MySQL8.0.28JDK1.8 …

Python详细教程

一、Python简历 Python 是一个高层次的结合了解释性、编译性、互动性和面向对象的脚本语言。 Python 的设计具有很强的可读性,相比其他语言经常使用英文关键字,其他语言的一些标点符号,它具有比其他语言更有特色语法结构。 Python 是一种解…

20240203进程间通信的7种方式

内核提供的原始通信方式有三种: ①无名管道:没有名字的管道,是一个特殊的文件,并且存储在内存上,不在文件系统中展示,无名管道打开后,会返回两个文件描述符,一个是读端,…

MySQL原理(五)事务

一、介绍: 1、介绍: 在计算机术语中,事务(Transaction)是访问并可能更新数据库中各种数据项的一个程序执行单元(unit)。事务是恢复和并发控制的基本单位。 2、事务的4大特性 原子性、一致性、隔离性、持久性。这四个属性通常称为ACID特性…

LaTeX表格:合并单元格、文字旋转90度并居中

在LaTeX表格中,如何使用\multirow合并单元格,并将单元格中的文字旋转九十度,并且居中呢? 首先引入graphicx、multirow和array包: \usepackage{graphicx} \usepackage{multirow} \usepackage{booktabs}然后定义一种新…

类银河恶魔城学习记录1-4 PlayerJumpState基本源代码 P31

Alex教程每一P的教程原代码加上我自己的理解初步理解写的注释,可供学习Alex教程的人参考 【Unity教程】从0编程制作类银河恶魔城游戏_哔哩哔哩_bilibili Player.cs using System.Collections; using System.Collections.Generic; using Unity.VisualScripting; u…

DoubleEnsemble:基于样本重加权和特征选择的金融数据分析方法

现代机器学习模型(如深度神经网络和梯度提升决策树)由于其提取复杂非线性模式的优越能力,在金融市场预测中越来越受欢迎。然而,由于金融数据集的信噪比非常低,并且是非平稳的,复杂的模型往往很容易过拟合。…

「递归算法」:Pow(x,n)

一、题目 实现 pow(x, n) ,即计算 x 的整数 n 次幂函数(即,xn )。 示例 1: 输入:x 2.00000, n 10 输出:1024.00000示例 2: 输入:x 2.10000, n 3 输出:9…

使用Arcgis对欧洲雷达高分辨率降水数据重投影

当前需要使用欧洲高分辨雷达降水数据,但是这个数据的投影问题非常头疼。实际的投影应该长这样(https://gist.github.com/kmuehlbauer/645e42a53b30752230c08c20a9c964f9?permalink_comment_id2954366https://gist.github.com/kmuehlbauer/645e42a53b307…

个人建站前端篇(三)环境变量配置

在项目根目录下新建.env.development和.env.production文件,里面配置环境变量 .env.development内容如下 # 页面标题 VITE_APP_TITLE 云风网 # 开发环境配置 VITE_APP_ENV development # 开发环境 VITE_APP_BASE_API /api # 是否启用代理 VITE_HTTP_PROXY tru…

深入了解 Ansible:全面掌握自动化 IT 环境的利器

本文以详尽的篇幅介绍了 Ansible 的方方面面,旨在帮助读者从入门到精通。无论您是初学者还是有一定经验的 Ansible 用户,都可以在本文中找到对应的内容,加深对 Ansible 的理解和应用。愿本文能成为您在 Ansible 自动化旅程中的良师益友&#…

故障诊断 | 一文解决,LSTM长短期记忆神经网络故障诊断(Matlab)

文章目录 效果一览文章概述专栏介绍模型描述源码设计参考资料效果一览 文章概述 故障诊断模型 | Maltab实现LSTM长短期记忆神经网络故障诊断 专栏介绍 订阅【故障诊断】专栏,不定期更新机器学习和深度学习在故障诊断中的应用;订阅

Skywalking 学习之ByteBuddy 方法执行时间监控

Skywalking git: GitHub - apache/skywalking: APM, Application Performance Monitoring System 集成入门: 10分钟3个步骤集成使用SkyWalking - 知乎 下面自己学习了一下ByteBuddy的用法,实战了一下: 入门教程: B…

[基础IO]文件描述符{重定向/perror/磁盘结构/inode/软硬链接}

文章目录 1. 再识重定向2.浅谈perror()3.初始文件系统4.软硬链接 1. 再识重定向 图解./sf > file.txt 2>&1 1中内容拷贝给2 使得2指向file 再学一个 把file的内容传给cat cat拿到后再给file2 2.浅谈perror() open()接口调用失败返回-1,并且错误码errno被适当的设置,…

虚拟机Windows Server 2016 安装 MySQL8

目录 一、下载MySQL8 1.下载地址: 2.创建my.ini文件 二、安装步骤 第一步:命令窗口 第二步:切换目录 第三步:安装服务 第四步:生成临时密码 第五步:启动服务 第六步: 修改密码 三…