【机器学习】基于tensorflow实现你的第一个DNN网络

博客导读:

《AI—工程篇》

AI智能体研发之路-工程篇(一):Docker助力AI智能体开发提效

AI智能体研发之路-工程篇(二):Dify智能体开发平台一键部署

AI智能体研发之路-工程篇(三):大模型推理服务框架Ollama一键部署

AI智能体研发之路-工程篇(四):大模型推理服务框架Xinference一键部署

AI智能体研发之路-工程篇(五):大模型推理服务框架LocalAI一键部署

《AI—模型篇》

AI智能体研发之路-模型篇(一):大模型训练框架LLaMA-Factory在国内网络环境下的安装、部署及使用

AI智能体研发之路-模型篇(二):DeepSeek-V2-Chat 训练与推理实战

AI智能体研发之路-模型篇(三):中文大模型开、闭源之争

AI智能体研发之路-模型篇(四):一文入门pytorch开发

AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比

目录

一、引言

二、tensorflow介绍

2.1 tensorflow历史

2.2 tensorflow特点

 2.3 tensorflow安装

三、tensorflow实战

3.1 引入依赖的tensorflow库

3.2 训练数据准备

3.3 创建三层DNN模型

3.4 编译模型、定义损失函数与优化器

3.5 启动训练,迭代收敛

3.6 模型评估

3.7 可以直接跑的代码 

四、总结


一、引言

上一篇AI智能体研发之路-模型篇(四):一文入门pytorch开发介绍如何使用pytorch实现一个简单的DNN网络,今天我们还是用同样的例子,看看使用tensorflow如何实现。

二、tensorflow介绍

2.1 tensorflow历史

TensorFlow由谷歌人工智能团队谷歌大脑(Google Brain)开发和维护,拥有包括TensorFlow Hub、TensorFlow Lite、TensorFlow Research Cloud在内的多个项目以及各类应用程序接口(Application Programming Interface, API)。自2015年11月9日起,TensorFlow依据阿帕奇授权协议(Apache 2.0 open source license)开放源代码。

2.2 tensorflow特点

深度学习时代,tensorflow在工业应用较为广泛,而pytorch更多应用于研究中。大模型时代,pytorch是很多项目的底层库,大有超过tensorflow的趋势。可谓并驾齐驱。

  • 生态系统更成熟:TensorFlow拥有一个庞大的社区和丰富的资源,包括大量的教程、预训练模型和工具,适合从初学者到专家的各个层次用户。
  • 生产部署友好:TensorFlow支持更多的平台和设备,包括移动设备和边缘设备,提供了TensorFlow Lite和TensorFlow.js等,便于模型的部署和优化。
  • 静态图与动态图的结合:虽然早期TensorFlow以静态图为主,但TensorFlow 2.x引入了Eager Execution,结合了动态图的易用性和静态图的高性能,同时保持了模型的可部署性。
  • Keras集成:TensorFlow内建了Keras,这是一个高级神经网络API,使得模型构建、训练和评估更加简洁直观。
  • TensorBoard:TensorFlow自带的可视化工具TensorBoard,便于可视化模型结构、训练过程中的损失和指标,帮助用户更好地理解和调试模型。
  • 广泛的工业应用支持:由于其成熟度和稳定性,TensorFlow在工业界得到了广泛的应用,特别是在大型企业中。

 2.3 tensorflow安装

与pytorch一样,还是采用conda创建环境,采用pip安装tensorflow包

1.建立名为pytrain,python版本为3.11的conda环境(这里与pytorch一样)

conda create -n pytrain python=3.11
conda activate pytrain

​  

 2.采用pip下载tensorflow以及机器学习常用的scikit-learn和numpy包

pip install tensorflow scikit-learn numpy  -i https://mirrors.cloud.tencent.com/pypi/simple

​ 

这里未指定版本,默认下载最新版本tensorflow-2.16.1以及其他tensorboard等生态包。 

三、tensorflow实战

 动手实现一个三层DNN网络:

3.1 引入依赖的tensorflow库

这里主要是tensorflow、keras、sklearn、numpy等

Keras是一个用于构建和训练深度学习模型的高级API,它设计得极其用户友好,支持快速实验。Keras可以运行在TensorFlow之上。

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import numpy as np

3.2 训练数据准备

这里采用numpy库进行数据随机生成

# 假设你已经有了特征数据 X 和标签数据 y
# X, y = ...  # 实际数据加载和预处理步骤
# 这里我们用随机数据作为示例
np.random.seed(0)
X = np.random.rand(1000, 1000)  # 1000个样本,每个样本1000个特征
y = np.random.randint(0, 2, size=(1000, 1))  # 二分类标签# 数据预处理,标准化特征
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
  • 首先,采用numpy的random随机生成X矩阵(1000行样本*1000行特征)和y矩阵(1000行0或1的label)
  • 其次,采用sklearn库中的StandardScaler将X矩阵中的每个样本特征数值标准化(将每个特征都转换为正态分布,均值为0,标准差为1),这一步骤对于机器学习算法的性能至关重要,特别是那些对输入数据的尺度敏感的算法。
  • 最后,按照2:8的比例从数据中切分出测试机与训练集

3.3 创建三层DNN模型

采用keras.sequential类,顾名思义“按顺序的”由输入至输出编排神经网络

# 创建模型
model = Sequential([Dense(512, input_shape=(X_train.shape[1],)),  # 第一层Activation('relu'),Dense(512),  # 第二层Activation('relu'),Dense(1),  # 输出层Activation('sigmoid')  # 二分类使用sigmoid
])

 Sequential是Keras中用于构建深度学习模型的一个类,特别适合于构建线性的堆叠层模型。这种模型结构是层与层直接相连,没有复杂的拓扑结构,适合于解决如图像分类、文本分类等任务

特点

  • 线性堆叠:层按照添加的顺序堆叠,每一层只与前一层有连接。
  • 易于使用:适合初学者和快速原型设计,对于复杂的网络结构可能不够灵活。
  • 灵活性限制:对于需要多输入或多输出,或者层间有复杂连接的模型,应使用更高级的模型结构,如Functional API。

3.4 编译模型、定义损失函数与优化器

不同于pytorch的实例化模型对象,这里采用compile对模型进行编译。与pytorch相同点是都要定义损失函数和优化器,方法与技巧完全相同。

# 编译模型
model.compile(optimizer=Adam(learning_rate=0.001),loss=BinaryCrossentropy(),metrics=['accuracy'])
  • optimizer=Adam(learning_rate=0.001):这里选择了Adam作为优化器。Adam(Adaptive Moment Estimation)是一种常用的优化算法,它结合了RMSprop和Momentum的优点,能够自动调整学习率。通过设置learning_rate=0.001,可以控制模型学习的速度。学习率是训练过程中的一个重要超参数,影响模型收敛的速度和最终的性能。
  • loss=BinaryCrossentropy():损失函数设置为二元交叉熵(Binary Crossentropy)。这个损失函数适用于二分类问题,它衡量了模型预测的概率分布与实际标签之间的差异。在二分类任务中,正确选择损失函数对于模型的性能至关重要。
  • metrics=['accuracy']:指定评估模型性能的指标。这里使用的是准确率(accuracy),即分类正确的比例。在训练和验证过程中,除了损失值外,还会计算并显示这个指标,帮助我们了解模型的性能。

3.5 启动训练,迭代收敛

不同于pytorch需要写两个循环处理每一行样本,tensorflow直接采用fit方法对输入的特征样本矩阵以及label矩阵进行训练

tensorflow版:

# 训练模型
history = model.fit(X_train, y_train, epochs=100, validation_split=0.1,  # 使用10%的数据作为验证集verbose=1)

pytorch版:

# 训练循环
num_epochs = 10
for epoch in range(num_epochs):model.train()  # 设置为训练模式running_loss = 0.0for i, (inputs, labels) in enumerate(data_loader, 0):optimizer.zero_grad()  # 清零梯度outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()  # 反向传播optimizer.step()  # 更新权重running_loss += loss.item()print(f'Epoch {epoch + 1}, Loss: {running_loss / len(data_loader)}')

对比来看,pytorch版的更加透明,有助于理解,tensorflow更加便捷 

运行后可以看到loss逐步收敛:​

3.6 模型评估

通过model.evaluate对模型进行评估,evaluate与fit的区别是只计算指标不进行模型更新

tensorflow版:

# 评估模型
loss, accuracy = model.evaluate(X_test, y_test, verbose=0)
print(f'Test loss: {loss}, Test accuracy: {accuracy}')

 pytorch版:

import torchmetrics # 导入torchmetricstest_num_samples = 200  # 测试样本数
test_X_train = torch.randn(test_num_samples, input_size) 
test_y_train = torch.randint(0, output_size, (test_num_samples,))# 数据加载
test_dataset = TensorDataset(test_X_train,test_y_train)
test_data_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)# 在模型训练完成后进行评估
# 首先,我们需要确保模型在评估模式下
model.eval()# 初始化准确率和召回率的计算器
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=output_size)
recall = torchmetrics.Recall(task="multiclass", num_classes=output_size)with torch.no_grad():  # 确保在评估时不进行梯度计算for inputs, labels in test_data_loader:outputs = model(inputs)preds = torch.softmax(outputs, dim=1)# 更新指标计算器accuracy.update(preds, labels)recall.update(preds, labels)# 打印准确率和召回率
print(f'Accuracy: {accuracy.compute():.4f}')
print(f'Recall: {recall.compute():.4f}')print('Evaluation finished.')

对比pytorch需要写一个循环,tensorflow.keras的封装更为简洁

运行后,可以输出模型的准确率与召回率,由于采用随机生成的测试数据且迭代轮数较少,具体数值不错参考,可以根据自己需要丰富数据。

3.7 可以直接跑的代码 

与上一篇AI智能体研发之路-模型篇(四):一文入门pytorch开发一样,附可以直接运行的代码,先跑起来,再一行行研究!

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import numpy as np# 假设你已经有了特征数据 X 和标签数据 y
# X, y = ...  # 实际数据加载和预处理步骤
# 这里我们用随机数据作为示例
np.random.seed(0)
X = np.random.rand(1000, 1000)  # 1000个样本,每个样本1000个特征
y = np.random.randint(0, 2, size=(1000, 1))  # 二分类标签# 数据预处理,标准化特征
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)# 创建模型
model = Sequential([Dense(512, input_shape=(X_train.shape[1],)),  # 第一层Activation('relu'),Dense(512),  # 第二层Activation('relu'),Dense(1),  # 输出层Activation('sigmoid')  # 二分类使用sigmoid
])# 编译模型
model.compile(optimizer=Adam(learning_rate=0.001),loss=BinaryCrossentropy(),metrics=['accuracy'])# 训练模型
history = model.fit(X_train, y_train, epochs=10, validation_split=0.1,  # 使用10%的数据作为验证集verbose=1)# 评估模型
loss, accuracy = model.evaluate(X_test, y_test, verbose=0)
print(f'Test loss: {loss}, Test accuracy: {accuracy}')

四、总结

本文先对tensorflow深度学习框架历史、特点及安装方法进行介绍,接下来基于tensorflow带读者一步步开发一个简单的三层神经网络程序,最后附可执行的代码供读者进行测试学习。个人感觉tensorflow封装程度高于pytorch,网络结构也更加清晰,但pytorch更加透明。

喜欢的话期待您的关注、点赞、收藏,您的互动是对我最大的鼓励!

如果还有时间,可以看看我的其他文章:

《AI—工程篇》

AI智能体研发之路-工程篇(一):Docker助力AI智能体开发提效

AI智能体研发之路-工程篇(二):Dify智能体开发平台一键部署

AI智能体研发之路-工程篇(三):大模型推理服务框架Ollama一键部署

AI智能体研发之路-工程篇(四):大模型推理服务框架Xinference一键部署

AI智能体研发之路-工程篇(五):大模型推理服务框架LocalAI一键部署

《AI—模型篇》

AI智能体研发之路-模型篇(一):大模型训练框架LLaMA-Factory在国内网络环境下的安装、部署及使用

AI智能体研发之路-模型篇(二):DeepSeek-V2-Chat 训练与推理实战

AI智能体研发之路-模型篇(三):中文大模型开、闭源之争

AI智能体研发之路-模型篇(四):一文入门pytorch开发

AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/844330.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

传统RNN网络及其案例--人名分类

传统的RNN模型简介 RNN 先上图 这图看起来莫名其妙,想拿着跟CNN对比着学第一眼看上去有点摸不着头脑,其实我们可以把每一个时刻的图展开来,如下 其中,为了简化计算,我们默认每一个隐层参数相同,这样看来R…

添砖Java(十二)——异常,异常捕获,常见异常方法

异常: 定义:异常通俗来讲,其实就是你写出bug来了,编译器给你报错了。 public static void main(String[] args)throws Exception {int z10/0;} 这个代码虽然说是可以运行,但是编译器会报错。 因为10不能去除以0。 异…

【C++】:vector容器的底层模拟实现迭代器失效隐藏的浅拷贝

目录 💡前言一,构造函数1 . 强制编译器生成默认构造2 . 拷贝构造3. 用迭代器区间初始化4. 用n个val值构造5. initializer_list 的构造 二,析构函数三,关于迭代器四,有关数据个数与容量五,交换函数swap六&am…

C# 数组/集合排序

一&#xff1a;基础类型集合排序 /// <summary> /// 排序 /// </summary> /// <param name"isReverse">顺序是否取反</param> public static void Sort<T>(this IList<T> array, bool isReverse false)where T : IComparable …

10种排序算法总结-(c语言实现与动画演示)

算法分类 十种常见排序算法可以分为两大类&#xff1a; 比较类排序&#xff1a;通过比较来决定元素间的相对次序&#xff0c;由于其时间复杂度不能突破O(nlogn)&#xff0c;因此也称为非线性时间比较类排序。非比较类排序&#xff1a;不通过比较来决定元素间的相对次序&#…

什么叫USDT(泰达币)的前世今生!

一、引言 在数字货币的世界里&#xff0c;USDT&#xff08;Tether USDT&#xff09;以其独特的稳定机制&#xff0c;成为了连接传统金融市场与加密货币市场的桥梁。本文将带您了解USDT的诞生背景、发展历程、技术特点以及未来展望。 二、USDT的诞生背景 USDT是Tether公司推出…

【服务器部署篇】Linux下Node.js的安装和配置

作者介绍&#xff1a;本人笔名姑苏老陈&#xff0c;从事JAVA开发工作十多年了&#xff0c;带过刚毕业的实习生&#xff0c;也带过技术团队。最近有个朋友的表弟&#xff0c;马上要大学毕业了&#xff0c;想从事JAVA开发工作&#xff0c;但不知道从何处入手。于是&#xff0c;产…

ChatGPT的工作原理,这篇文章说清楚了!

作者&#xff1a;史蒂芬沃尔弗拉姆&#xff08;Stephen Wolfram&#xff09;英、美籍 计算机科学家&#xff0c; 物理学家。他是 Mathematica 的首席设计师&#xff0c;《一种新科学》一书的作者。 ChatGPT 能够自动生成一些读起来表面上甚至像人写的文字的东西&#xff0c;这…

《庆余年算法番外篇》:范闲通过最短路径算法在阻止黑骑截杀林相

剧情背景 在《庆余年 2》22集中&#xff0c;林相跟大宝交代完为人处世的人生哲理之后&#xff0c;就要跟大宝告别了 在《庆余年 2》23集中&#xff0c;林相在告老还乡的路上与婉儿和大宝告别后 范闲也在与婉儿的对话中知道黑骑调动是绝密&#xff0c;并把最近一次告老还乡梅…

汇智知了堂实力展示:四川农业大学Python爬虫实训圆满结束

近日&#xff0c;汇智知了堂在四川农业大学举办的为期五天的校内综合项目实训活动已圆满结束。本次实训聚焦Python爬虫技术&#xff0c;旨在提升学生的编程能力和数据分析能力&#xff0c;为学生未来的职业发展打下坚实的基础。 作为一家在IT教育行业享有盛誉的机构&#xff…

C++数据结构之:队Queue

摘要&#xff1a; it人员无论是使用哪种高级语言开发东东&#xff0c;想要更高效有层次的开发程序的话都躲不开三件套&#xff1a;数据结构&#xff0c;算法和设计模式。数据结构是相互之间存在一种或多种特定关系的数据元素的集合&#xff0c;即带“结构”的数据元素的集合&am…

嵌入式不一定只能用C!

嵌入式不一定只能用C! ---------------------------------------------------------------------------------------手动分割线-------------------------------------------------------------------------------- 本文章参考了以下文章&#xff1a; 这里是引用 ------------…

现场辩论赛活动策划方案

活动目的&#xff1a; 技能竞赛中的辩论环节既可以考核员工的知识点&#xff0c;同时也可以考核员工业务办事能力&#xff0c;表达能力&#xff0c;是一种比较全面且较有深度的竞赛方式。 辩论赛细则&#xff1a; 1、时间提示 : 自由辩论阶段&#xff0c;每方使用时间剩…

【CTF-Web】XXE学习笔记(附ctfshow例题)

XXE 文章目录 XXE0x01 前置知识汇总XMLDTD &#xff08;Document Type Definition&#xff09; 0x02 XXE0x03 XXE危害0x04 攻击方式1. 通过File协议读取文件Web373(有回显)Web374(无回显) Web375Web376Web377Web378 0x01 前置知识汇总 XML 可扩展标记语言&#xff08;eXtensi…

故障诊断 | 基于KAN故障诊断模型

效果一览 文章概述 故障诊断 | 基于 KAN故障诊断模型。KAN是一种全新的神经网络架构&#xff0c;它与传统的MLP架构不同&#xff0c;能够用更少的参数量在Science领域取得惊人的表现&#xff0c;并且具备可解释性&#xff0c;有望成为深度学习模型发展的一个重要方向。运用KAN&…

从0开始学web之信息收集

web1~源代码 web1:where is flag?直接右键源代码找到。 web2~源代码 无法查看源代码确实右键不了&#xff0c;F12用不了&#xff0c; 但是还可以在URL前加上view-source: web3~HTTP响应 web3:where is flag?右键源代码没有&#xff0c;那就看看HTTP 头&#xff0c;F12抓…

数据大屏方案 : 实现数据可视化的关键一环_光点科技

在数字时代的浪潮中&#xff0c;数据已经成为企业决策和操作的重要基础。因此&#xff0c;“数据大屏方案”逐渐成为业界关注的焦点。这类方案通过将复杂的数据集合以直观的形式展现出来&#xff0c;帮助决策者快速把握信息&#xff0c;做出更加明智的决策。 数据大屏的定义及作…

Java-数组内存解析

文章目录 1.内存的主要结构&#xff1a;栈、堆2.一维数组的内存解析3.二维数组的内存解析 1.内存的主要结构&#xff1a;栈、堆 2.一维数组的内存解析 举例1&#xff1a;基本使用 举例2&#xff1a;两个变量指向一个数组 3.二维数组的内存解析 举例1&#xff1a; 举例2&am…

java生产制造执行系统MES源码:系统环境:Java EE 8、Servlet 3.0、Apache Maven 3 2;

MES系统技术选型 系统环境&#xff1a;Java EE 8、Servlet 3.0、Apache Maven 3 2&#xff1b; 主框架&#xff1a;Spring Boot 2.2.x、Spring Framework 5.2.x、Spring Security 5.2.x 3 持久层&#xff1a;Apache MyBatis 3.5.x、Hibernate Validation 6.0.x、Alibaba Dru…