Python深度学习实战-基于tensorflow原生代码搭建BP神经网络实现分类任务(附源码和实现效果)

实现功能

前面两篇文章分别介绍了两种搭建神经网络模型的方法,一种是基于tensorflow的keras框架,另一种是继承父类自定义class类,本篇文章将编写原生代码搭建BP神经网络。

实现代码

import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler# 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target# 数据预处理
scaler = StandardScaler()
X = scaler.fit_transform(X)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 设置超参数
learning_rate = 0.001
num_epochs = 100
batch_size = 32# 定义输入和输出的维度
input_dim = X.shape[1]
output_dim = len(set(y))# 定义权重和偏置项
W1 = tf.Variable(tf.random.normal(shape=(input_dim, 64), dtype=tf.float64))
b1 = tf.Variable(tf.zeros(shape=(64,), dtype=tf.float64))
W2 = tf.Variable(tf.random.normal(shape=(64, 64), dtype=tf.float64))
b2 = tf.Variable(tf.zeros(shape=(64,), dtype=tf.float64))
W3 = tf.Variable(tf.random.normal(shape=(64, output_dim), dtype=tf.float64))
b3 = tf.Variable(tf.zeros(shape=(output_dim,), dtype=tf.float64))# 定义前向传播函数
def forward_pass(X):X = tf.cast(X, tf.float64)h1 = tf.nn.relu(tf.matmul(X, W1) + b1)h2 = tf.nn.relu(tf.matmul(h1, W2) + b2)logits = tf.matmul(h2, W3) + b3return logits# 定义损失函数
def loss_fn(logits, labels):return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))# 定义优化器
optimizer = tf.optimizers.Adam(learning_rate)# 定义准确率指标
accuracy_metric = tf.metrics.SparseCategoricalAccuracy()# 定义训练步骤
def train_step(inputs, labels):with tf.GradientTape() as tape:logits = forward_pass(inputs)loss_value = loss_fn(logits, labels)gradients = tape.gradient(loss_value, [W1, b1, W2, b2, W3, b3])optimizer.apply_gradients(zip(gradients, [W1, b1, W2, b2, W3, b3]))accuracy_metric(labels, logits)return loss_value# 进行训练
for epoch in range(num_epochs):epoch_loss = 0.0accuracy_metric.reset_states()for batch_start in range(0, len(X_train), batch_size):batch_end = batch_start + batch_sizebatch_X = X_train[batch_start:batch_end]batch_y = y_train[batch_start:batch_end]loss = train_step(batch_X, batch_y)epoch_loss += losstrain_loss = epoch_loss / (len(X_train) // batch_size)train_accuracy = accuracy_metric.result()print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.4f}")# 进行评估
logits = forward_pass(X_test)
test_loss = loss_fn(logits, y_test)
test_accuracy = accuracy_metric(y_test, logits)print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

实现效果

本人读研期间发表5篇SCI数据挖掘相关论文,现在某研究院从事数据挖掘相关科研工作,对数据挖掘有一定认知和理解,会结合自身科研实践经历不定期分享关于python、机器学习、深度学习基础知识与案例。

致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。

邀请三个朋友关注V订阅号:数据杂坛,即可在后台联系我获取相关数据集和源码,送有关数据分析、数据挖掘、机器学习、深度学习相关的电子书籍。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/121133.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ssh连接远程服务器,并在终端安装anaconda

官网下载安装:anaconda2023.09版本(官网地址:https://www.anaconda.com/download#downloads) wget https://repo.anaconda.com/archive/Anaconda3-2023.09-0-Linux-x86_64.sh使用阿里云镜像下载安装,官网下载太慢。阿…

2698 求一个整数的惩罚数 (子集和,DFS)

class Solution { public:bool dfs(int target, string s, int index, int sum) {// 只有整个字符串都被分割,求和,和看结果是不是等于targetif(index s.size()) {return sum target;}int num 0; // 在现在的子集中去依次加入余下的元素// 1 2 9 6// …

支付宝证书到期更新完整过程

如果用户收到 支付宝公钥证书 到期通知后,可以根据如下指引更新证书 确认上传成功后就会生成新的证书,把新的证书替换到生产环境就可以了

浏览器事件循环 (event loop)

进程与线程 进程 进程的概念 进程是操作系统中的一个程序或者一个程序的一次执行过程,是一个动态的概念,是程序在执行过程中分配和管理资源的基本单位,是操作系统结构的基础。 简单的来说,就是一个程序运行开辟的一块内存空间&a…

RocketMQ事务消息 超时重发还是原来的消息吗?

以下面的一个demo例子来分析一下,探索RocketMQ事务消息原理。 public static final String PRODUCER_GROUP "tran-test";public static final String DEFAULT_NAMESRVADDR "127.0.0.1:9876";public static final String TOPIC "Test&qu…

【CSS】CSS 属性计算过程

1. 概述 我们所书写的任何一个 HTML 元素&#xff0c;实际上都有完整的一整套 CSS 样式。如果没有修改某样式&#xff0c;大概率可能使用默认值。 例如&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"&…

unity 基于UGUI的无限动态滚动列表

基于UGUI的动态滚动列表&#xff0c;主要支持以下功能&#xff1a; 继承自UGUI的SrollRect&#xff0c;支持ScrollRect的所有功能&#xff1b; 使用对象池来管理列表元素&#xff0c;以实现列表元素的复用&#xff1b; 支持一行多个元素或一列多个元素&#xff1b; 可使用不…

【MyBatis Plus】深入探索 MyBatis Plus 的条件构造器,自定义 SQL语句,Service 接口的实现

文章目录 前言一、条件构造器1.1 什么是条件构造器1.2 QueryWrapper1.3 UpdateWrapper1.4 LambdaWrapper 二、自定义 SQL 语句2.1 自定义 SQL 的基本用法2.2 自定义 SQL 实现多表查询 三、Service 接口3.1 对 Service 接口的认识3.2 实现 Service 接口3.3 实现增删改查功能3.4 …

el-form那些事

vue3element-plus el-form那些事 输入框后拼接文字 输入框后拼接文字 <el-form-item :label"t(location.locationLength)" prop"locationLength"><el-input v-model"form.locationLength" :placeholder"t(location.inputLocation…

Windoes定时任务、设置定时重启系统

步骤一&#xff1a; 打开计算机管理 通过&#xff1a;control(控制面板&#xff09;或者compmgmt.msc(计算机管理&#xff09;打开程序 步骤二&#xff1a;打开——>系统工具 步骤三&#xff1a; 选择——>任务计划程序 步骤四&#xff1a; 可选择创建新文件命名&…

正点原子嵌入式linux驱动开发——Linux PWM驱动

PWM是很常用到功能&#xff0c;可以通过PWM来控制电机速度&#xff0c;也可以使用PWM来控制LCD的背光亮度。本章就来学习一下如何在Linux下进行PWM驱动开发。 PWM驱动解析 不在介绍PWM是什么了&#xff0c;直接进入使用。 给LCD的背光引脚输入一个PWM信号&#xff0c;这样就…

记一次企业微信的(CorpID)和密钥(Secret)泄漏的利用案例

文章目录 一、介绍二、利用过程1、获取AccessToken2、获取企业微信接口IP段3、获取企业微信回调IP段4、通过部门ID,查看返回的ID5、通过部门ID,查看用户列表6、通过便利ID,发现用户信息泄露,可以进行提交报告7、通过添加接口,添加企业账号8、登陆企业账号进行测试三、参考…

【随机过程】布朗运动

这里写目录标题 Brownian motion Brownian motion The brownian motion 1D and brownian motion 2D functions, written with the cumsum command and without for loops, are used to generate a one-dimensional and two-dimensional Brownian motion, respectively. 使用cu…

主动调度是如何发生的

计算机主要处理计算、网络、存储三个方面。计算主要是 CPU 和内存的合作&#xff1b;网络和存储则多是和外部设备的合作&#xff1b;在操作外部设备的时候&#xff0c;往往需要让出 CPU&#xff0c;就像上面两段代码一样&#xff0c;选择调用 schedule() 函数。 上下文切换主要…

Kafka - 3.x 图解Broker总体工作流程

文章目录 Zk中存储的kafka的信息Kafka Broker总体工作流程1. broker启动后向zk中注册2. Controller谁先启动注册&#xff0c;谁说了算3. 由选举出来的Controller监听brokers节点的变化4. Controller决定leader选举5. Controller将节点信息上传到Zk中6. 其他Controller从zk中同步…

Fourier分析导论——第1章——Fourier分析的起源(E.M. Stein R. Shakarchi)

第 1 章 Fourier分析的起源 (The Genesis of Fourier Analysis) Regarding the researches of dAlembert and Euler could one not add that if they knew this expansion, they made but a very imperfect use of it. They were both persuaded that an arbitrary and d…

jenkins配置gitlab凭据

下载Credentials Binding插件&#xff08;默认是已经安装了&#xff09; 在凭据配置里添加凭据类型 点击保存 Username with password&#xff1a; 用户名和密码 SSH Username with private 在凭据管理里面添加gitlab账号和密码 点击全局 点击添加凭据&#xff08;版本不同…

Go RESTful API 接口开发

文章目录 什么是 RESTful APIGo 流行 Web 框架-GinGo HelloWorldGin 路由和控制器Gin 处理请求参数生成 HTTP 请求响应Gin 的学习内容实战用 Gin 框架开发 RESTful APIOAuth 2.0接口了解用 Go 开发 OAuth2.0 接口示例 编程有一个准则——Don‘t Repeat Yourself&#xff08;不要…

如何在Windows和Linux系统上监听文件夹的变动?

文章目录 如何在Windows和Linux系统上监听文件夹的变动&#xff1f;读写文件文件系统的操作缓冲和流文件改变事件 如何在Windows和Linux系统上监听文件夹的变动&#xff1f; libuv库实现了监听整个文件夹的修改。本文详细介绍libuv库文件读写和监听的的实现方法。libuv库开发了…

Unity的碰撞检测(六)

温馨提示&#xff1a;本文基于前一篇“Unity的碰撞检测(五)”继续探讨两个游戏对象具备刚体的BodyType均为Dynamic&#xff0c;但是Collision Detection属性不同的碰撞检测&#xff0c;阅读本文则默认已阅读前文。 &#xff08;一&#xff09;测试说明 在基于两个游戏对象都具…