TensorFlow代码逻辑 vs PyTorch代码逻辑

文章目录

  • 一、TensorFlow
    • (一)导入必要的库
    • (二)加载MNIST数据集
    • (三)数据预处理
    • (四)构建神经网络模型
    • (五)编译模型
    • (六)训练模型
    • (七)评估模型
    • (八)将模型的输出转化为概率
    • (九)预测测试集的前5个样本
  • 二、PyTorch
    • (一)导入必要的库
    • (二)定义神经网络模型
    • (三)数据预处理和加载
    • (四)初始化模型、损失函数和优化器
    • (五)训练模型
    • (六)评估模型
    • (七)设置设备为GPU或CPU
    • (八)运行训练和评估
    • (九)预测测试集的前5个样本
  • 三、TensorFlow和PyTorch代码逻辑上的对比
    • (一)模型定义
    • (二)数据处理
    • (三)训练过程
    • (四)自动求导
  • 四、TensorFlow和PyTorch的应用
  • 五、动态图计算
    • (一)TensorFlow(静态图计算):
    • (二)PyTorch(动态图计算):

一、TensorFlow

使用TensorFlow构建一个简单的神经网络来对MNIST数据集进行分类

(一)导入必要的库

import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np

(二)加载MNIST数据集

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

(三)数据预处理

将图像数据归一化到[0, 1]范围,以提高模型的训练效果

x_train, x_test = x_train / 255.0, x_test / 255.0

(四)构建神经网络模型

  • Flatten层:将输入的28x28像素图像展平成784个特征的一维向量。
  • Dense层:全连接层,包含128个神经元,使用ReLU激活函数。
  • Dropout层:在训练过程中随机丢弃20%的神经元,防止过拟合。
  • 输出层:包含10个神经元,对应10个类别(数字0-9)。
model = models.Sequential([layers.Flatten(input_shape=(28, 28)),layers.Dense(128, activation='relu'),layers.Dropout(0.2),layers.Dense(10)
])

(五)编译模型

  • optimizer=‘adam’:使用Adam优化器。
  • loss=‘SparseCategoricalCrossentropy’:使用交叉熵损失函数。
  • metrics=[‘accuracy’]:使用准确率作为评估指标。
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])

(六)训练模型

model.fit(x_train, y_train, epochs=5)

(七)评估模型

model.evaluate(x_test, y_test, verbose=2)

(八)将模型的输出转化为概率

probability_model = tf.keras.Sequential([model,tf.keras.layers.Softmax()
])

(九)预测测试集的前5个样本

predictions = probability_model.predict(x_test[:5])
print(predictions)

二、PyTorch

使用PyTorch来构建、训练和评估一个用于MNIST数据集的神经网络模型

(一)导入必要的库

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

(二)定义神经网络模型

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.flatten = nn.Flatten()self.fc1 = nn.Linear(28 * 28, 128)self.dropout = nn.Dropout(0.2)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.flatten(x)x = F.relu(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return x

(三)数据预处理和加载

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

(四)初始化模型、损失函数和优化器

model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

(五)训练模型

def train(model, device, train_loader, optimizer, epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} 'f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

(六)评估模型

def test(model, device, test_loader):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)accuracy = 100. * correct / len(test_loader.dataset)print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} 'f'({accuracy:.0f}%)\n')

(七)设置设备为GPU或CPU

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

(八)运行训练和评估

for epoch in range(1, 6):train(model, device, train_loader, optimizer, epoch)test(model, device, test_loader)

(九)预测测试集的前5个样本

model.eval()
with torch.no_grad():samples = next(iter(test_loader))[0][:5].to(device)output = model(samples)predictions = F.softmax(output, dim=1)print(predictions)

三、TensorFlow和PyTorch代码逻辑上的对比

(一)模型定义

  • 在TensorFlow中,通常使用tf.keras模块来定义模型。可以使用Sequential API或Functional API。
import tensorflow as tf# Sequential API
model = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])# Functional API
inputs = tf.keras.Input(shape=(784,))
x = tf.keras.layers.Dense(64, activation='relu')(inputs)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
model = tf.keras.Model(inputs, outputs)
  • PyTorch中,定义模型时需要继承nn.Module类并实现forward方法
import torch
import torch.nn as nnclass Model(nn.Module):def __init__(self):super(Model, self).__init__()self.dense1 = nn.Linear(784, 64)self.relu = nn.ReLU()self.dense2 = nn.Linear(64, 10)self.softmax = nn.Softmax(dim=1)def forward(self, x):x = self.relu(self.dense1(x))x = self.softmax(self.dense2(x))return xmodel = Model()

(二)数据处理

  • TensorFlow有tf.data模块来处理数据管道
import tensorflow as tfdef preprocess(data):# 数据预处理逻辑return datadataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
dataset = dataset.map(preprocess).batch(32)
  • PyTorch使用torch.utils.data.DataLoader和Dataset类来处理数据管道
import torch
from torch.utils.data import DataLoader, Datasetclass CustomDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, idx):x = self.data[idx]y = self.labels[idx]return x, ydataset = CustomDataset(X_train, y_train)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

(三)训练过程

  • TensorFlow的tf.keras提供了高阶API来进行模型编译和训练
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(X_train, y_train, epochs=5, batch_size=32)
  • PyTorch中,训练过程需要手动编写,包括前向传播、损失计算、反向传播和优化步骤
import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)for epoch in range(5):for data, labels in dataloader:optimizer.zero_grad()outputs = model(data)loss = criterion(outputs, labels)loss.backward()optimizer.step()

(四)自动求导

  • TensorFlow在后端自动处理梯度计算和应用
# 使用model.fit自动处理
  • PyTorch的自动求导功能非常灵活,可以使用autograd模块
# 使用loss.backward()和optimizer.step()手动处理

四、TensorFlow和PyTorch的应用

总体来说,PyTorch提供了更多的灵活性和控制,适合需要自定义复杂模型和训练过程的场景。而TensorFlow则更加高级和简洁,适合快速原型和标准模型的开发。

TensorFlow:

  • 高阶API:使用tf.keras简化模型定义、训练和评估,适合快速原型开发和生产部署。
  • 性能优化:支持图计算,优化执行速度和资源使用,适合大规模分布式训练。
  • 广泛生态:拥有丰富的工具和库,如TensorBoard用于可视化,TensorFlow Lite用于移动端部署。
  • 企业支持:由Google支持,广泛应用于工业界,提供稳定的长期支持和更新。

PyTorch:

  • 灵活性:采用动态图计算,代码易于调试和修改,适合研究和实验。
  • 简单直观:符合Python语言习惯,API设计简洁明了,降低学习曲线。
  • 社区活跃:由Facebook支持,拥有活跃的开源社区,快速响应用户需求和改进。
  • 科研应用:广泛应用于学术界,支持多种前沿研究,如自定义损失函数和复杂模型结构。

五、动态图计算

动态图计算是PyTorch的一个显著特点,它让模型的计算图在每次前向传播时动态生成,而不是像TensorFlow那样预先定义和编译。

动态图计算的定义与特性:

  • 动态生成:每次执行前向传播时,计算图都会根据当前输入数据动态构建。
  • 即时调试:允许在代码执行时使用标准的Python调试工具(如pdb),进行逐步调试和检查。
  • 灵活性高:支持更复杂和动态的模型结构,如条件控制流和递归神经网络,更适合研究实验和快速原型开发。

(一)TensorFlow(静态图计算):

在TensorFlow中,计算图是预先定义并编译的。在模型定义和编译之后,图结构固定,随后输入数据进行计算。

import tensorflow as tf# 定义计算图
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
logits = tf.matmul(x, W) + b
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))# 创建会话并执行
with tf.Session() as sess:sess.run(tf.global_variables_initializer())for i in range(1000):batch_x, batch_y = ...  # 获取训练数据sess.run(loss, feed_dict={x: batch_x, y: batch_y})

(二)PyTorch(动态图计算):

在PyTorch中,计算图在每次前向传播时动态构建,代码更接近标准的Python编程风格。

import torch
import torch.nn as nn
import torch.optim as optim# 定义模型
class Model(nn.Module):def __init__(self):super(Model, self).__init__()self.dense = nn.Linear(784, 10)def forward(self, x):return self.dense(x)model = Model()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练过程
for epoch in range(1000):for data, target in dataloader:optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/38473.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

@RequestMapping属性详解及案例演示

RequestMapping源码 Target({ElementType.TYPE, ElementType.METHOD}) Retention(RetentionPolicy.RUNTIME) Documented Mapping public interface RequestMapping {String name() default "";AliasFor("path")String[] value() default {};AliasFor(&quo…

智能写作与痕迹消除:AI在创意文案和论文去痕中的应用

作为一名AI爱好者,我积累了许多实用的AI生成工具。今天,我想分享一些我经常使用的工具,这些工具不仅能帮助提升工作效率,还能激发创意思维。 我们都知道,随着技术的进步,AI生成工具已经变得越来越智能&…

简单分享 for循环,从基础到高级

1. 基础篇:Hello, For Loop! 想象一下,你想给班上的每位同学发送“Hello!”,怎么办?那就是for循环啦, eg:首先有个名字的列表,for循环取出,分别打印 names ["Alice", …

Apache APISIX 介绍

Apache APISIX 是一个动态、实时、高性能的云原生API网关,属于Apache软件基金会旗下的项目。以下是对Apache APISIX的详细介绍: 一、基本概述 定义:Apache APISIX是一个提供丰富流量管理功能的云原生API网关。功能:包括负载均衡…

git出现Permission denied问题

Warning: Permanently added ‘icode.baidu.com,10.11.81.103’ (RSA) to the list of known hosts. Permission denied (baas,keyboard-interactive,publickey). fatal: Could not read from remote repository. Please make sure you have the correct access rights and the…

nodejs操作excel文件实例,读取sheets, 设置cell颜色

本代码是我帮客户做的兼职的实例,涉及用node读取excel文件,遍历sheets,给单元格设置颜色等操作,希望对大家接活有所帮助。 gen.js let dir"D:\\武汉烟厂\\山东区域\\备档资料\\销区零售终端APP维护清单\\走访档案\\2024年6月…

Spring之事务失效的场景

Spring事务失效的场景 异常捕获处理:自己处理了异常,没有抛出。解决:手动抛出抛出检查异常:配置rollbackFor属性为Excetion非public方法导致事务失效,改为public 1、异常捕获处理 示例: 张三1000元&#…

7月形势分析-您下一步该如何做,才能走出困境?

马上工程项目,再有三五天就要结束的了。即便推后也不会超过一周时间了。所以需要考虑将来干啥呢?  一方面就是继续去济宁做建筑工程的活。管吃住,但是因为至亲之间,难免咋说呢,总之还是不太舒服的样子。管事情多&…

bigNumber的部分使用方法与属性

场景:最近做IoT项目的时候碰到一个问题,涉及到双精度浮点型的数据范围的校验问题。业务上其实有三种类型:int、float和double类型三种。他们的范围分别是: //int int: [-2147483648, 2147483647],//float float: [-3402823466385…

PHP7源码结构

PHP7程序的执行过程 1.PHP代码经过词法分析转换为有意义的Token; 2.Token经过语法分析生成AST(Abstract Synstract Syntax Tree,抽象语法树); 3.AST生成对应的opcode,被虚拟机执行。 源码结构&#xff1…

一切为了安全丨2024中国应急(消防)品牌巡展武汉站成功召开!

消防品牌巡展武汉站 6月28日,由中国安全产业协会指导,中国安全产业协会应急创新分会、应急救援产业网联合主办,湖北消防协会协办的“一切为了安全”2024年中国应急(消防)品牌巡展-武汉站成功举办。该巡展旨在展示中国应急(消防&am…

qt QTreeView的简单使用(多级子节点)

MainWindow::MainWindow(QWidget *parent): QMainWindow(parent), ui(new Ui::MainWindow) {ui->setupUi(this);setWindowTitle("QTreeView的简单使用");model new QStandardItemModel;model->setHorizontalHeaderLabels(QStringList() << "left&q…

【数据结构 - 时间复杂度和空间复杂度】

文章目录 <center>时间复杂度和空间复杂度算法的复杂度时间复杂度大O的渐进表示法常见时间复杂度计算举例 空间复杂度实例 时间复杂度和空间复杂度 算法的复杂度 算法在编写成可执行程序后&#xff0c;运行时需要耗费时间资源和空间(内存)资源 。因此衡量一个算法的好坏&…

[leetcode]longest-arithmetic-subsequence-of-given-difference. 最长定差子序列

. - 力扣&#xff08;LeetCode&#xff09; class Solution { public:int longestSubsequence(vector<int> &arr, int difference) {int ans 0;unordered_map<int, int> dp;for (int v: arr) {dp[v] dp[v - difference] 1;ans max(ans, dp[v]);}return ans…

Qt源码分析:窗体绘制与响应

作为一套开源跨平台的UI代码库&#xff0c;窗体绘制与响应自然是最为基本的功能。在前面的博文中&#xff0c;已就Qt中的元对象系统(反射机制)、事件循环等基础内容进行了分析&#xff0c;并捎带阐述了窗体响应相关的内容。因此&#xff0c;本文着重分析Qt中窗体绘制相关的内容…

ECharts 快速入门

文章目录 1. 引入 ECharts2. 初始化 ECharts 实例3. 配置图表选项4. 使用配置项生成图表5. 最常用的几种图形5.1 柱状图&#xff08;Bar Chart&#xff09;5.2 折线图&#xff08;Line Chart&#xff09;5.3 饼图&#xff08;Pie Chart&#xff09;5.4 散点图&#xff08;Scatt…

如何完成域名解析验证

一&#xff1a;什么是DNS解析&#xff1a; DNS解析是互联网上将人类可读的域名&#xff08;如www.example.com&#xff09;转换为计算机可识别的IP地址&#xff08;如192.0.2.1&#xff09;的过程&#xff0c;大致遵循以下步骤&#xff1a; 查询本地缓存&#xff1a;当用户尝…

Linux内核 -- 多线程之完成量completion的使用

Linux Kernel Completion 使用指南 在Linux内核编程中&#xff0c;completion是一个用于进程同步的机制&#xff0c;常用于等待某个事件的完成。它提供了一种简单的方式&#xff0c;让一个线程等待另一个线程完成某项任务。 基本使用方法 初始化 completion结构需要在使用之…

顺序串算法库构建

学习贺利坚老师顺序串算法库 数据结构之自建算法库——顺序串_创建顺序串s1,创建顺序串s2-CSDN博客 本人详细解析博客 串的概念及操作_串的基本操作-CSDN博客 版本更新日志 V1.0: 在贺利坚老师算法库指导下, 结合本人详细解析博客思路基础上,进行测试, 加入异常弹出信息 v1.0补…

已解决java.awt.geom.NoninvertibleTransformException:在Java2D中无法逆转的转换的正确解决方法,亲测有效!!!

已解决java.awt.geom.NoninvertibleTransformException&#xff1a;在Java2D中无法逆转的转换的正确解决方法&#xff0c;亲测有效&#xff01;&#xff01;&#xff01; 目录 问题分析 出现问题的场景 报错原因 解决思路 解决方法 1. 检查缩放因子 修改后的缩放变换 …