【PyTorch 攻略】(6-7/7)

一、说明

本篇介绍模型模型的参数,模型推理和使用,保存加载。

二、训练参数和模型

        在本单元中,我们将了解如何加载模型及其持久参数状态和推理模型预测。为了加载模型,我们将定义模型类,其中包含用于训练模型的神经网络的状态和参数。

%matplotlib inline
import torch
import onnxruntime
from torch import nn
import torch.onnx as onnx
import torchvision.models as models
from torchvision import datasets
from torchvision.transforms import ToTensor
class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10),nn.ReLU())def forward(self, x):x = self.flatten(x)logits = self.line

        加载模型权重时,我们需要首先实例化模型类,因为该类定义了网络的结构。接下来,我们使用 load_state_dict() 方法加载参数。

model = NeuralNetwork()
model.load_state_dict(torch.load('data/model.pth'))
model.eval()

        注意:请务必在推理之前调用 model.eval() 方法,以将 dropout 和批量归一化层设置为评估模式。如果不这样做,将产生不一致的推理结果。

三、模型推理

        优化模型以在各种平台和编程语言上运行是很困难的。在所有不同的框架和硬件组合中最大限度地提高性能是非常耗时的。
开放式神经网络交换 (ONNX) 运行时为您提供了一种解决方案,只需训练一次,即可在任何硬件、云或边缘设备上加速推理。

        ONNX 是许多供应商支持的一种通用格式,用于共享神经网络和其他机器学习模型。您可以使用 ONNX 格式在其他编程语言和框架(如 Java、JavaScript、C# 和 ML.NET)上对模型进行推理。

input_image = torch.zeros((1,28,28))
onnx_model = 'data/model.onnx'
onnx.export(model, input_image, onnx_model)

        我们将使用测试数据集作为示例数据,以便从 ONNX 模型进行推理以进行预测。

test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)classes = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot",
]
x, y = test_data[0][0], test_data[0][1]

        我们需要使用 onnxruntime 创建一个推理会话。推理会话。为了推断 onnx 模型,我们使用 run 和 pass 输入要返回的输出列表(如果需要所有输出,请留空)和输入值映射。结果是一个输出列表:

session = onnxruntime.InferenceSession(onnx_model, None)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].nameresult = session.run([output_name], {input_name: x.numpy()})
predicted, actual = classes[result[0][0].argmax(0)], classes[y]
print(f'Predicted: "{predicted}", Actual: "{actual}"')

四、torch.utils.data.DataLoader和torch.utils.data.Dataset

        PyTorch有两个基元来处理数据:torch.utils.data.DataLoader和torch.utils.data.Dataset数据集存储样本及其相应的标签,DataLoader 围绕数据集包装一个可迭代对象。

%matplotlib inline
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt

        PyTorch提供特定于领域的库,如TorchText,TorchVision和TorchAudio,所有这些都包括数据集。在本教程中,我们将使用TorchVision数据集。

        torchvision.datasets 模块包含许多真实世界视觉数据(如 CIFAR 和 COCO)的数据集对象。在本教程中,我们将使用 FashionMNIST 数据集。每个TorchVision数据集都包含两个参数:转换target_transform分别修改样本和标签。

# Download training data from open datasets.
training_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor(),
)# Download test data from open datasets.
test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor(),
)

        我们将数据集作为参数传递给 DataLoader。这将在我们的数据集上包装一个可迭代对象,并支持自动批处理、采样、随机排序和多进程数据加载。这里我们定义一个 64 的批量大小,即 dataloader 迭代中的每个元素将返回一批 64 个特征和标签。

batch_size = 64# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)for X, y in test_dataloader:print("Shape of X [N, C, H, W]: ", X.shape)print("Shape of y: ", y.shape, y.dtype)break# Display sample data
figure = plt.figure(figsize=(10, 8))
cols, rows = 5, 5
for i in range(1, cols * rows + 1):idx = torch.randint(len(test_data), size=(1,)).item()img, label = test_data[idx]figure.add_subplot(rows, cols, i)plt.title(label)plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()
Shape of X [N, C, H, W]:  torch.Size([64, 1, 28, 28])
Shape of y:  torch.Size([64]) torch.int64

五、创建模型

        为了在 PyTorch 中定义神经网络,我们创建一个继承自 nn.Module 的类。我们在 __init__ 函数中定义网络层,并在转发函数中指定数据如何通过网络。为了加速神经网络的运算,我们将其转移到 GPU(如果可用)。

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))# Define model
class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10),nn.ReLU())def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork().to(device)
print(model)
Using cuda device
NeuralNetwork((flatten): Flatten()(linear_relu_stack): Sequential((0): Linear(in_features=784, out_features=512, bias=True)(1): ReLU()(2): Linear(in_features=512, out_features=512, bias=True)(3): ReLU()(4): Linear(in_features=512, out_features=10, bias=True)(5): ReLU())
)

六、优化模型参数

        为了训练模型,我们需要一个损失函数和一个优化器。我们将使用 nn。交叉熵损失用于损失,随机梯度下降用于优化。

loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-3
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

        在单个训练循环中,模型对训练数据集进行预测(批量馈送到它),并向后传播预测误差以调整模型的参数。

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationoptimizer.zero_grad()loss.backward()optimizer.step()if batch % 100 == 0:loss, current = loss.item(), batch * len(X)

        我们还可以对照测试数据集检查模型的性能,以确保它正在学习。

def test(dataloader, model):size = len(dataloader.dataset)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= sizecorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

        训练过程通过多次迭代(纪元)进行。在每个时期,模型学习参数以做出更好的预测。我们打印模型在每个时期的准确性和损失;我们希望看到精度随着每个时期的增加和损失的减少而减少。

epochs = 15
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model)
print("Done!")
Epoch 1
-------------------------------
loss: 2.295450  [    0/60000]
loss: 2.293073  [ 6400/60000]
loss: 2.278504  [12800/60000]
loss: 2.282501  [19200/60000]
loss: 2.273211  [25600/60000]
loss: 2.258452  [32000/60000]
loss: 2.248237  [38400/60000]
loss: 2.228594  [44800/60000]
loss: 2.240276  [51200/60000]
loss: 2.221318  [57600/60000]
Test Error: Accuracy: 51.8%, Avg loss: 0.034745 Epoch 2
-------------------------------
loss: 2.212354  [    0/60000]
loss: 2.207739  [ 6400/60000]
loss: 2.160400  [12800/60000]
loss: 2.176181  [19200/60000]
loss: 2.168270  [25600/60000]
loss: 2.146453  [32000/60000]
loss: 2.119934  [38400/60000]
loss: 2.083791  [44800/60000]
loss: 2.126453  [51200/60000]
loss: 2.077550  [57600/60000]
Test Error: Accuracy: 53.2%, Avg loss: 0.032452 Epoch 3
-------------------------------
loss: 2.082280  [    0/60000]
loss: 2.068733  [ 6400/60000]
loss: 1.965958  [12800/60000]
loss: 1.997126  [19200/60000]
loss: 2.002057  [25600/60000]
loss: 1.967370  [32000/60000]
loss: 1.910595  [38400/60000]
loss: 1.849006  [44800/60000]
loss: 1.944741  [51200/60000]
loss: 1.861265  [57600/60000]
Test Error: Accuracy: 51.6%, Avg loss: 0.028937 Epoch 4
-------------------------------
loss: 1.872628  [    0/60000]
loss: 1.844543  [ 6400/60000]
loss: 1.710179  [12800/60000]
loss: 1.779804  [19200/60000]
loss: 1.737971  [25600/60000]
loss: 1.746953  [32000/60000]
loss: 1.624768  [38400/60000]
loss: 1.575720  [44800/60000]
loss: 1.742827  [51200/60000]
loss: 1.653375  [57600/60000]
Test Error: Accuracy: 58.4%, Avg loss: 0.025570 Epoch 5
-------------------------------
loss: 1.662315  [    0/60000]
loss: 1.636235  [ 6400/60000]
loss: 1.508407  [12800/60000]
loss: 1.606842  [19200/60000]
loss: 1.560728  [25600/60000]
loss: 1.606024  [32000/60000]
loss: 1.426900  [38400/60000]
loss: 1.406240  [44800/60000]
loss: 1.619918  [51200/60000]
loss: 1.521326  [57600/60000]
Test Error: Accuracy: 61.2%, Avg loss: 0.023459 Epoch 6
-------------------------------
loss: 1.527535  [    0/60000]
loss: 1.511209  [ 6400/60000]
loss: 1.377129  [12800/60000]
loss: 1.494889  [19200/60000]
loss: 1.457990  [25600/60000]
loss: 1.502333  [32000/60000]
loss: 1.291539  [38400/60000]
loss: 1.285098  [44800/60000]
loss: 1.484891  [51200/60000]
loss: 1.414015  [57600/60000]
Test Error: Accuracy: 62.2%, Avg loss: 0.021480 Epoch 7
-------------------------------
loss: 1.376779  [    0/60000]
loss: 1.384830  [ 6400/60000]
loss: 1.230116  [12800/60000]
loss: 1.382574  [19200/60000]
loss: 1.255630  [25600/60000]
loss: 1.396211  [32000/60000]
loss: 1.157718  [38400/60000]
loss: 1.186382  [44800/60000]
loss: 1.340606  [51200/60000]
loss: 1.321607  [57600/60000]
Test Error: Accuracy: 62.8%, Avg loss: 0.019737 Epoch 8
-------------------------------
loss: 1.243344  [    0/60000]
loss: 1.279124  [ 6400/60000]
loss: 1.121769  [12800/60000]
loss: 1.293069  [19200/60000]
loss: 1.128232  [25600/60000]
loss: 1.315465  [32000/60000]
loss: 1.069528  [38400/60000]
loss: 1.123324  [44800/60000]
loss: 1.243827  [51200/60000]
loss: 1.255190  [57600/60000]
Test Error: Accuracy: 63.4%, Avg loss: 0.018518 Epoch 9
-------------------------------
loss: 1.154148  [    0/60000]
loss: 1.205280  [ 6400/60000]
loss: 1.046463  [12800/60000]
loss: 1.229866  [19200/60000]
loss: 1.048813  [25600/60000]
loss: 1.254785  [32000/60000]
loss: 1.010614  [38400/60000]
loss: 1.077114  [44800/60000]
loss: 1.176766  [51200/60000]
loss: 1.206567  [57600/60000]
Test Error: Accuracy: 64.3%, Avg loss: 0.017640 Epoch 10
-------------------------------
loss: 1.090360  [    0/60000]
loss: 1.149150  [ 6400/60000]
loss: 0.990786  [12800/60000]
loss: 1.183704  [19200/60000]
loss: 0.997114  [25600/60000]
loss: 1.207199  [32000/60000]
loss: 0.967512  [38400/60000]
loss: 1.043431  [44800/60000]
loss: 1.127000  [51200/60000]
loss: 1.169639  [57600/60000]
Test Error: Accuracy: 65.3%, Avg loss: 0.016974 Epoch 11
-------------------------------
loss: 1.041194  [    0/60000]
loss: 1.104409  [ 6400/60000]
loss: 0.947670  [12800/60000]
loss: 1.149421  [19200/60000]
loss: 0.960403  [25600/60000]
loss: 1.169899  [32000/60000]
loss: 0.935149  [38400/60000]
loss: 1.018250  [44800/60000]
loss: 1.088222  [51200/60000]
loss: 1.139813  [57600/60000]
Test Error: Accuracy: 66.2%, Avg loss: 0.016446 Epoch 12
-------------------------------
loss: 1.000646  [    0/60000]
loss: 1.067356  [ 6400/60000]
loss: 0.912046  [12800/60000]
loss: 1.122742  [19200/60000]
loss: 0.932827  [25600/60000]
loss: 1.138785  [32000/60000]
loss: 0.910242  [38400/60000]
loss: 0.999010  [44800/60000]
loss: 1.056596  [51200/60000]
loss: 1.114582  [57600/60000]
Test Error: Accuracy: 67.5%, Avg loss: 0.016011 Epoch 13
-------------------------------
loss: 0.966393  [    0/60000]
loss: 1.035691  [ 6400/60000]
loss: 0.881672  [12800/60000]
loss: 1.100845  [19200/60000]
loss: 0.910265  [25600/60000]
loss: 1.112597  [32000/60000]
loss: 0.889558  [38400/60000]
loss: 0.982751  [44800/60000]
loss: 1.029199  [51200/60000]
loss: 1.092738  [57600/60000]
Test Error: Accuracy: 68.5%, Avg loss: 0.015636 Epoch 14
-------------------------------
loss: 0.936334  [    0/60000]
loss: 1.007734  [ 6400/60000]
loss: 0.854663  [12800/60000]
loss: 1.081601  [19200/60000]
loss: 0.890581  [25600/60000]
loss: 1.089641  [32000/60000]
loss: 0.872057  [38400/60000]
loss: 0.969192  [44800/60000]
loss: 1.005193  [51200/60000]
loss: 1.073098  [57600/60000]
Test Error: Accuracy: 69.4%, Avg loss: 0.015304 Epoch 15
-------------------------------
loss: 0.908971  [    0/60000]
loss: 0.982067  [ 6400/60000]
loss: 0.830095  [12800/60000]
loss: 1.064921  [19200/60000]
loss: 0.874204  [25600/60000]
loss: 1.069008  [32000/60000]
loss: 0.856447  [38400/60000]
loss: 0.957340  [44800/60000]
loss: 0.983547  [51200/60000]
loss: 1.055251  [57600/60000]
Test Error: Accuracy: 70.3%, Avg loss: 0.015001 Done!

        准确性最初不会很好(没关系!尝试运行循环以获取更多纪元或将learning_rate调整为更大的数字。也可能是我们选择的模型配置可能不是此类问题的最佳配置。

七、保存模型

        保存模型的常用方法是序列化内部状态字典(包含模型参数)。

torch.save(model.state_dict(), "data/model.pth")
print("Saved PyTorch Model State to model.pth")

八、负载模型

        加载模型的过程包括重新创建模型结构并将状态字典加载到其中。

model = NeuralNetwork()
model.load_state_dict(torch.load("data/model.pth"))

        此模型现在可用于进行预测。

classes = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot",
]model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():pred = model(x)predicted, actual = classes[pred[0].argmax(0)], classes[y]print(f'Predicted: "{predicted}", Actual: "{actual}"')
Predicted: "Ankle boot", Actual: "Ankle boot"

        祝贺!您已经完成了 PyTorch 初学者教程!我们希望本教程能帮助您在 PyTorch 上开始深度学习。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/85187.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode 刷题记录——从零开始记录自己一些不会的(二)

20. 替换后的最长重复字符 题意 给你一个字符串 s 和一个整数 k 。你可以选择字符串中的任一字符,并将其更改为任何其他大写英文字符。该操作最多可执行 k 次。 在执行上述操作后,返回包含相同字母的最长子字符串的长度。 思路 代码 class Solution…

DOMBOM

DOM Document Object Model:文档对象模型 DOM树 文档:一个页面就是一个文档; 节点:网页中的所有内容,在文档树中都是节点,使用node表示; DOM操作节点实现网页特效的步骤: 获取ht…

软件需求文档、设计文档、开发文档、运维文档大全

在软件开发过程中,文档扮演着至关重要的角色。它不仅记录了项目的需求、设计和开发过程,还为项目的维护和管理提供了便利。本文将详细介绍软件开发文档的重要性和作用,以及需求分析、软件设计、开发过程、运维管理和项目管理等方面的文档要求…

DA5 网站用户没有补全的信息

目录 1.题目描述 2.输入描述 3.输出描述 4.题目分析 5.通过代码 1.题目描述 现有一个Nowcoder.csv文件,它记录了牛客网的部分用户数据,包含如下字段(字段与字段之间以逗号间隔): Nowcoder_ID:用户ID …

Vosviewer的安装与使用

Vosviewer的安装与使用 1 安装2 使用参考: 关于vosviewer我就不过多介绍了。 vosviewer与citespace有什么区别?在这里可以引用一下知乎的文章简要说明一下: 1.操作难易VOSviewer很简单,在官网下载的时候会附带一个英文手册,稍微…

Linux命令基础

一、linux目录结构。 Linux没有windows的盘的概念,是一个树形的结构。唯一的根目录为/,所有的文件都在他下面。 描述方式也与windows有所不同 二、命令基础格式。 command [-options] [parameter]([ ]表示可选的) command:必…

以太坊代币标准ERC20、ERC721

两个概念 ERC(Ethereum Request for Comment) 以太坊意见征集稿EIP(Ethereum Improvement Proposals)以太坊改进提案 ERC和EIP用于使得以太坊更加完善;在ERC中提出了很多标准,用的最多的标准就是它的Token标准; 有哪些标准详细见https://eips.ethereum…

【lesson7】yum的介绍及使用

文章目录 预备工作yum的基本过程yum的操作**yum源问题:****yum三板斧:**yum listyum searchyum list | grepyum installyum install -yyum removeyum remove -y 预备工作 首先有三个问题: 问题解答: 这里我们联想到了手机 问题…

论文阅读:AugGAN: Cross Domain Adaptation with GAN-based Data Augmentation

Abstract 基于GAN的图像转换方法存在两个缺陷:保留图像目标和保持图像转换前后的一致性,这导致不能用它生成大量不同域的训练数据。论文提出了一种结构感知(Structure-aware)的图像转换网络(image-to-image translation network)。 Proposed Framework…

5.5V-65V Vin同步降压控制器,具有线路前馈SCT82630DHKR

描述: SCT82630是一款65V电压模式控制同步降压控制器,具有线路前馈。40ns受控高压侧MOSFET的最小导通时间支持高转换比,实现从48V输入到低压轨的直接降压转换,降低了系统复杂性和解决方案成本。如果需要,在低至6V的输…

Java“牵手”义乌购商品详情数据,义乌购商品详情接口,义乌购API接口申请指南

义乌购隶属浙江义乌购电子商务有限公司旗下网站。该平台定位为依托实体市场,服务实体市场,以诚信为根本,将7万网上商铺与实体商铺一一对应绑定,为采购商和经营户提供可控、可信、可溯源的交易保障。 义乌购平台现有商铺商品、市场…

OpenHarmony Meetup常州站招募令

OpenHarmony Meetup 常州站正火热招募中! 诚邀充满激情的开发者参与线下盛会~ 探索OpenHarmony前沿科技,畅谈未来前景, 感受OpenHarmony生态构建之路的魅力! 线下参与,名额有限,仅限20位幸运者&#xff01…

计算机网络常见问题

1.谈一谈对OSI七层模型和TCP/IP四层模型的理解? 1.1.为什么要分层? 在计算机中网络是个复杂的系统,不同的网络与网络之间由于协议,设备,软件等各种原因在协调和通讯时容易产生各种各样的问题。例如:各物流…

【MySQL】 MySQL的增删改查(进阶)--贰

文章目录 🛫新增🛬查询🌴聚合查询🚩聚合函数🎈GROUP BY子句📌HAVING 🎋联合查询⚾内连接⚽外连接🧭自连接🏀子查询🎡合并查询 🎨MySQL的增删改查(…

web前端之float布局与flex布局

float布局 <style>.nav {overflow: hidden;background-color: #6adfd0; /* 导航栏背景颜色 */}.nav a {float: left;display: block;text-align: center;padding: 14px 16px;text-decoration: none;color: #000000; /* 导航栏文字颜色 */}.nav a:hover {background-col…

测试基础知识

测试的基础知识 软件&#xff1a;控制硬件工作的工具。 软件测试&#xff1a;使用技术手段验证软件是否满足使用需求。 软件测试目的&#xff1a;减少软件缺陷&#xff0c;保证软件质量。 测试主流技能&#xff1a;功能测试、自动化测试、接口测试、性能测试 功能测试&#xff…

三相组合式过电压保护器试验

三相组合式过电压保护器试验 试验目的 三相组合式过电压保护器主要分为有带串联间隙过压保护器和无间隙过压保护器两大类&#xff0c;其试验项目内容要求分别使用高压工频交流和高压直流电源。 三相组合式过电压保护器试验&#xff0c;主要是为了及早发现设备内部绝缘受潮及…

解锁学习新方式——助您迈向成功之路

近年来&#xff0c;随着吉林开放大学广播电视大学的崛起&#xff0c;越来越多的学子选择这所优秀的学府来实现自己的梦想。而作为一名学者&#xff0c;我有幸见证了电大搜题微信公众号的诞生&#xff0c;为广大学子提供了一个全新的学习支持平台。 电大搜题微信公众号&#xff…

《深度学习工业缺陷检测》专栏介绍 CSDN独家改进实战

&#x1f4a1;&#x1f4a1;&#x1f4a1;深度学习工业缺陷检测 1&#xff09;提供工业小缺陷检测性能提升方案&#xff0c;满足部署条件&#xff1b; 2&#xff09;针对缺陷样品少等难点&#xff0c;引入无监督检测&#xff1b; 3&#xff09;深度学习 C、C#部署方案&#…

2023工博会强势回归!智微工业携八大系列重磅亮相

中国国际工业博览会&#xff08;简称"中国工博会"&#xff09;自1999年创办以来&#xff0c;历经二十余年发展创新&#xff0c;通过专业化、市场化、国际化、品牌化运作&#xff0c;已发展成为通过国际展览业协会&#xff08;UFI&#xff09;认证、中国工业领域规模最…