【PyTorch 攻略】(6-7/7)

一、说明

本篇介绍模型模型的参数,模型推理和使用,保存加载。

二、训练参数和模型

        在本单元中,我们将了解如何加载模型及其持久参数状态和推理模型预测。为了加载模型,我们将定义模型类,其中包含用于训练模型的神经网络的状态和参数。

%matplotlib inline
import torch
import onnxruntime
from torch import nn
import torch.onnx as onnx
import torchvision.models as models
from torchvision import datasets
from torchvision.transforms import ToTensor
class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10),nn.ReLU())def forward(self, x):x = self.flatten(x)logits = self.line

        加载模型权重时,我们需要首先实例化模型类,因为该类定义了网络的结构。接下来,我们使用 load_state_dict() 方法加载参数。

model = NeuralNetwork()
model.load_state_dict(torch.load('data/model.pth'))
model.eval()

        注意:请务必在推理之前调用 model.eval() 方法,以将 dropout 和批量归一化层设置为评估模式。如果不这样做,将产生不一致的推理结果。

三、模型推理

        优化模型以在各种平台和编程语言上运行是很困难的。在所有不同的框架和硬件组合中最大限度地提高性能是非常耗时的。
开放式神经网络交换 (ONNX) 运行时为您提供了一种解决方案,只需训练一次,即可在任何硬件、云或边缘设备上加速推理。

        ONNX 是许多供应商支持的一种通用格式,用于共享神经网络和其他机器学习模型。您可以使用 ONNX 格式在其他编程语言和框架(如 Java、JavaScript、C# 和 ML.NET)上对模型进行推理。

input_image = torch.zeros((1,28,28))
onnx_model = 'data/model.onnx'
onnx.export(model, input_image, onnx_model)

        我们将使用测试数据集作为示例数据,以便从 ONNX 模型进行推理以进行预测。

test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)classes = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot",
]
x, y = test_data[0][0], test_data[0][1]

        我们需要使用 onnxruntime 创建一个推理会话。推理会话。为了推断 onnx 模型,我们使用 run 和 pass 输入要返回的输出列表(如果需要所有输出,请留空)和输入值映射。结果是一个输出列表:

session = onnxruntime.InferenceSession(onnx_model, None)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].nameresult = session.run([output_name], {input_name: x.numpy()})
predicted, actual = classes[result[0][0].argmax(0)], classes[y]
print(f'Predicted: "{predicted}", Actual: "{actual}"')

四、torch.utils.data.DataLoader和torch.utils.data.Dataset

        PyTorch有两个基元来处理数据:torch.utils.data.DataLoader和torch.utils.data.Dataset数据集存储样本及其相应的标签,DataLoader 围绕数据集包装一个可迭代对象。

%matplotlib inline
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt

        PyTorch提供特定于领域的库,如TorchText,TorchVision和TorchAudio,所有这些都包括数据集。在本教程中,我们将使用TorchVision数据集。

        torchvision.datasets 模块包含许多真实世界视觉数据(如 CIFAR 和 COCO)的数据集对象。在本教程中,我们将使用 FashionMNIST 数据集。每个TorchVision数据集都包含两个参数:转换target_transform分别修改样本和标签。

# Download training data from open datasets.
training_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor(),
)# Download test data from open datasets.
test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor(),
)

        我们将数据集作为参数传递给 DataLoader。这将在我们的数据集上包装一个可迭代对象,并支持自动批处理、采样、随机排序和多进程数据加载。这里我们定义一个 64 的批量大小,即 dataloader 迭代中的每个元素将返回一批 64 个特征和标签。

batch_size = 64# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)for X, y in test_dataloader:print("Shape of X [N, C, H, W]: ", X.shape)print("Shape of y: ", y.shape, y.dtype)break# Display sample data
figure = plt.figure(figsize=(10, 8))
cols, rows = 5, 5
for i in range(1, cols * rows + 1):idx = torch.randint(len(test_data), size=(1,)).item()img, label = test_data[idx]figure.add_subplot(rows, cols, i)plt.title(label)plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()
Shape of X [N, C, H, W]:  torch.Size([64, 1, 28, 28])
Shape of y:  torch.Size([64]) torch.int64

五、创建模型

        为了在 PyTorch 中定义神经网络,我们创建一个继承自 nn.Module 的类。我们在 __init__ 函数中定义网络层,并在转发函数中指定数据如何通过网络。为了加速神经网络的运算,我们将其转移到 GPU(如果可用)。

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))# Define model
class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10),nn.ReLU())def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork().to(device)
print(model)
Using cuda device
NeuralNetwork((flatten): Flatten()(linear_relu_stack): Sequential((0): Linear(in_features=784, out_features=512, bias=True)(1): ReLU()(2): Linear(in_features=512, out_features=512, bias=True)(3): ReLU()(4): Linear(in_features=512, out_features=10, bias=True)(5): ReLU())
)

六、优化模型参数

        为了训练模型,我们需要一个损失函数和一个优化器。我们将使用 nn。交叉熵损失用于损失,随机梯度下降用于优化。

loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-3
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

        在单个训练循环中,模型对训练数据集进行预测(批量馈送到它),并向后传播预测误差以调整模型的参数。

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationoptimizer.zero_grad()loss.backward()optimizer.step()if batch % 100 == 0:loss, current = loss.item(), batch * len(X)

        我们还可以对照测试数据集检查模型的性能,以确保它正在学习。

def test(dataloader, model):size = len(dataloader.dataset)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= sizecorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

        训练过程通过多次迭代(纪元)进行。在每个时期,模型学习参数以做出更好的预测。我们打印模型在每个时期的准确性和损失;我们希望看到精度随着每个时期的增加和损失的减少而减少。

epochs = 15
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model)
print("Done!")
Epoch 1
-------------------------------
loss: 2.295450  [    0/60000]
loss: 2.293073  [ 6400/60000]
loss: 2.278504  [12800/60000]
loss: 2.282501  [19200/60000]
loss: 2.273211  [25600/60000]
loss: 2.258452  [32000/60000]
loss: 2.248237  [38400/60000]
loss: 2.228594  [44800/60000]
loss: 2.240276  [51200/60000]
loss: 2.221318  [57600/60000]
Test Error: Accuracy: 51.8%, Avg loss: 0.034745 Epoch 2
-------------------------------
loss: 2.212354  [    0/60000]
loss: 2.207739  [ 6400/60000]
loss: 2.160400  [12800/60000]
loss: 2.176181  [19200/60000]
loss: 2.168270  [25600/60000]
loss: 2.146453  [32000/60000]
loss: 2.119934  [38400/60000]
loss: 2.083791  [44800/60000]
loss: 2.126453  [51200/60000]
loss: 2.077550  [57600/60000]
Test Error: Accuracy: 53.2%, Avg loss: 0.032452 Epoch 3
-------------------------------
loss: 2.082280  [    0/60000]
loss: 2.068733  [ 6400/60000]
loss: 1.965958  [12800/60000]
loss: 1.997126  [19200/60000]
loss: 2.002057  [25600/60000]
loss: 1.967370  [32000/60000]
loss: 1.910595  [38400/60000]
loss: 1.849006  [44800/60000]
loss: 1.944741  [51200/60000]
loss: 1.861265  [57600/60000]
Test Error: Accuracy: 51.6%, Avg loss: 0.028937 Epoch 4
-------------------------------
loss: 1.872628  [    0/60000]
loss: 1.844543  [ 6400/60000]
loss: 1.710179  [12800/60000]
loss: 1.779804  [19200/60000]
loss: 1.737971  [25600/60000]
loss: 1.746953  [32000/60000]
loss: 1.624768  [38400/60000]
loss: 1.575720  [44800/60000]
loss: 1.742827  [51200/60000]
loss: 1.653375  [57600/60000]
Test Error: Accuracy: 58.4%, Avg loss: 0.025570 Epoch 5
-------------------------------
loss: 1.662315  [    0/60000]
loss: 1.636235  [ 6400/60000]
loss: 1.508407  [12800/60000]
loss: 1.606842  [19200/60000]
loss: 1.560728  [25600/60000]
loss: 1.606024  [32000/60000]
loss: 1.426900  [38400/60000]
loss: 1.406240  [44800/60000]
loss: 1.619918  [51200/60000]
loss: 1.521326  [57600/60000]
Test Error: Accuracy: 61.2%, Avg loss: 0.023459 Epoch 6
-------------------------------
loss: 1.527535  [    0/60000]
loss: 1.511209  [ 6400/60000]
loss: 1.377129  [12800/60000]
loss: 1.494889  [19200/60000]
loss: 1.457990  [25600/60000]
loss: 1.502333  [32000/60000]
loss: 1.291539  [38400/60000]
loss: 1.285098  [44800/60000]
loss: 1.484891  [51200/60000]
loss: 1.414015  [57600/60000]
Test Error: Accuracy: 62.2%, Avg loss: 0.021480 Epoch 7
-------------------------------
loss: 1.376779  [    0/60000]
loss: 1.384830  [ 6400/60000]
loss: 1.230116  [12800/60000]
loss: 1.382574  [19200/60000]
loss: 1.255630  [25600/60000]
loss: 1.396211  [32000/60000]
loss: 1.157718  [38400/60000]
loss: 1.186382  [44800/60000]
loss: 1.340606  [51200/60000]
loss: 1.321607  [57600/60000]
Test Error: Accuracy: 62.8%, Avg loss: 0.019737 Epoch 8
-------------------------------
loss: 1.243344  [    0/60000]
loss: 1.279124  [ 6400/60000]
loss: 1.121769  [12800/60000]
loss: 1.293069  [19200/60000]
loss: 1.128232  [25600/60000]
loss: 1.315465  [32000/60000]
loss: 1.069528  [38400/60000]
loss: 1.123324  [44800/60000]
loss: 1.243827  [51200/60000]
loss: 1.255190  [57600/60000]
Test Error: Accuracy: 63.4%, Avg loss: 0.018518 Epoch 9
-------------------------------
loss: 1.154148  [    0/60000]
loss: 1.205280  [ 6400/60000]
loss: 1.046463  [12800/60000]
loss: 1.229866  [19200/60000]
loss: 1.048813  [25600/60000]
loss: 1.254785  [32000/60000]
loss: 1.010614  [38400/60000]
loss: 1.077114  [44800/60000]
loss: 1.176766  [51200/60000]
loss: 1.206567  [57600/60000]
Test Error: Accuracy: 64.3%, Avg loss: 0.017640 Epoch 10
-------------------------------
loss: 1.090360  [    0/60000]
loss: 1.149150  [ 6400/60000]
loss: 0.990786  [12800/60000]
loss: 1.183704  [19200/60000]
loss: 0.997114  [25600/60000]
loss: 1.207199  [32000/60000]
loss: 0.967512  [38400/60000]
loss: 1.043431  [44800/60000]
loss: 1.127000  [51200/60000]
loss: 1.169639  [57600/60000]
Test Error: Accuracy: 65.3%, Avg loss: 0.016974 Epoch 11
-------------------------------
loss: 1.041194  [    0/60000]
loss: 1.104409  [ 6400/60000]
loss: 0.947670  [12800/60000]
loss: 1.149421  [19200/60000]
loss: 0.960403  [25600/60000]
loss: 1.169899  [32000/60000]
loss: 0.935149  [38400/60000]
loss: 1.018250  [44800/60000]
loss: 1.088222  [51200/60000]
loss: 1.139813  [57600/60000]
Test Error: Accuracy: 66.2%, Avg loss: 0.016446 Epoch 12
-------------------------------
loss: 1.000646  [    0/60000]
loss: 1.067356  [ 6400/60000]
loss: 0.912046  [12800/60000]
loss: 1.122742  [19200/60000]
loss: 0.932827  [25600/60000]
loss: 1.138785  [32000/60000]
loss: 0.910242  [38400/60000]
loss: 0.999010  [44800/60000]
loss: 1.056596  [51200/60000]
loss: 1.114582  [57600/60000]
Test Error: Accuracy: 67.5%, Avg loss: 0.016011 Epoch 13
-------------------------------
loss: 0.966393  [    0/60000]
loss: 1.035691  [ 6400/60000]
loss: 0.881672  [12800/60000]
loss: 1.100845  [19200/60000]
loss: 0.910265  [25600/60000]
loss: 1.112597  [32000/60000]
loss: 0.889558  [38400/60000]
loss: 0.982751  [44800/60000]
loss: 1.029199  [51200/60000]
loss: 1.092738  [57600/60000]
Test Error: Accuracy: 68.5%, Avg loss: 0.015636 Epoch 14
-------------------------------
loss: 0.936334  [    0/60000]
loss: 1.007734  [ 6400/60000]
loss: 0.854663  [12800/60000]
loss: 1.081601  [19200/60000]
loss: 0.890581  [25600/60000]
loss: 1.089641  [32000/60000]
loss: 0.872057  [38400/60000]
loss: 0.969192  [44800/60000]
loss: 1.005193  [51200/60000]
loss: 1.073098  [57600/60000]
Test Error: Accuracy: 69.4%, Avg loss: 0.015304 Epoch 15
-------------------------------
loss: 0.908971  [    0/60000]
loss: 0.982067  [ 6400/60000]
loss: 0.830095  [12800/60000]
loss: 1.064921  [19200/60000]
loss: 0.874204  [25600/60000]
loss: 1.069008  [32000/60000]
loss: 0.856447  [38400/60000]
loss: 0.957340  [44800/60000]
loss: 0.983547  [51200/60000]
loss: 1.055251  [57600/60000]
Test Error: Accuracy: 70.3%, Avg loss: 0.015001 Done!

        准确性最初不会很好(没关系!尝试运行循环以获取更多纪元或将learning_rate调整为更大的数字。也可能是我们选择的模型配置可能不是此类问题的最佳配置。

七、保存模型

        保存模型的常用方法是序列化内部状态字典(包含模型参数)。

torch.save(model.state_dict(), "data/model.pth")
print("Saved PyTorch Model State to model.pth")

八、负载模型

        加载模型的过程包括重新创建模型结构并将状态字典加载到其中。

model = NeuralNetwork()
model.load_state_dict(torch.load("data/model.pth"))

        此模型现在可用于进行预测。

classes = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot",
]model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():pred = model(x)predicted, actual = classes[pred[0].argmax(0)], classes[y]print(f'Predicted: "{predicted}", Actual: "{actual}"')
Predicted: "Ankle boot", Actual: "Ankle boot"

        祝贺!您已经完成了 PyTorch 初学者教程!我们希望本教程能帮助您在 PyTorch 上开始深度学习。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/85187.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode 刷题记录——从零开始记录自己一些不会的(二)

20. 替换后的最长重复字符 题意 给你一个字符串 s 和一个整数 k 。你可以选择字符串中的任一字符,并将其更改为任何其他大写英文字符。该操作最多可执行 k 次。 在执行上述操作后,返回包含相同字母的最长子字符串的长度。 思路 代码 class Solution…

多重视窗管理程序 screen

当我们在使用 MobaXterm/XShell 进行远程访问服务器时,进行远程访问的界面往往不能关掉,否则,程序将不再运行。而且,程序在运行的过程中,还必须时刻保证网络的通常,这些条件都很难得到满足。 为了解决上述…

【个人笔记本】本地化部署详细流程 LLaMA中文模型:Chinese-LLaMA-Alpaca-2

不推荐小白,环境配置比较复杂 全部流程 下载原始模型:Chinese-LLaMA-Alpaca-2linux部署llamacpp环境使用llamacpp将Chinese-LLaMA-Alpaca-2模型转换为gguf模型windows部署Text generation web UI 环境使用Text generation web UI 加载模型并进行对话 准…

DOMBOM

DOM Document Object Model:文档对象模型 DOM树 文档:一个页面就是一个文档; 节点:网页中的所有内容,在文档树中都是节点,使用node表示; DOM操作节点实现网页特效的步骤: 获取ht…

软件需求文档、设计文档、开发文档、运维文档大全

在软件开发过程中,文档扮演着至关重要的角色。它不仅记录了项目的需求、设计和开发过程,还为项目的维护和管理提供了便利。本文将详细介绍软件开发文档的重要性和作用,以及需求分析、软件设计、开发过程、运维管理和项目管理等方面的文档要求…

DA5 网站用户没有补全的信息

目录 1.题目描述 2.输入描述 3.输出描述 4.题目分析 5.通过代码 1.题目描述 现有一个Nowcoder.csv文件,它记录了牛客网的部分用户数据,包含如下字段(字段与字段之间以逗号间隔): Nowcoder_ID:用户ID …

Vosviewer的安装与使用

Vosviewer的安装与使用 1 安装2 使用参考: 关于vosviewer我就不过多介绍了。 vosviewer与citespace有什么区别?在这里可以引用一下知乎的文章简要说明一下: 1.操作难易VOSviewer很简单,在官网下载的时候会附带一个英文手册,稍微…

Linux命令基础

一、linux目录结构。 Linux没有windows的盘的概念,是一个树形的结构。唯一的根目录为/,所有的文件都在他下面。 描述方式也与windows有所不同 二、命令基础格式。 command [-options] [parameter]([ ]表示可选的) command:必…

SMB:使用 Ansible 自动化配置 samba 客户端服务端

写在前面 考试顺便整理博文内容整理 使用 Ansible 部署 samba 客户端和服务端理解不足小伙伴帮忙指正 对每个人而言,真正的职责只有一个:找到自我。然后在心中坚守其一生,全心全意,永不停息。所有其它的路都是不完整的&#xff0c…

百度之星(数学基础题)

糖果促销 小度最喜欢吃糖啦!!! 这天商店糖果促销,可给小度高兴坏了。 促销规则:一颗糖果有一张糖纸,p 张糖纸可以换取一颗糖果。换出来糖果的包装纸当然也能再换糖果。 小度想吃 k 颗糖果,他…

以太坊代币标准ERC20、ERC721

两个概念 ERC(Ethereum Request for Comment) 以太坊意见征集稿EIP(Ethereum Improvement Proposals)以太坊改进提案 ERC和EIP用于使得以太坊更加完善;在ERC中提出了很多标准,用的最多的标准就是它的Token标准; 有哪些标准详细见https://eips.ethereum…

三行代码实现图像画质修复,图片清晰度修复,清晰度提升python

核心代码 # 原始文件 enhancer ImageEnhance.Sharpness(Image.open(文件路径.png)) # 增强图片 img_enhanced enhancer.enhance(增强系数float) # 输出目标文件 img_enhanced.save(文件名.png)注意,输入输出文件格式必须一致 所需依赖 # 文件选择框&#xff0c…

【lesson7】yum的介绍及使用

文章目录 预备工作yum的基本过程yum的操作**yum源问题:****yum三板斧:**yum listyum searchyum list | grepyum installyum install -yyum removeyum remove -y 预备工作 首先有三个问题: 问题解答: 这里我们联想到了手机 问题…

论文阅读:AugGAN: Cross Domain Adaptation with GAN-based Data Augmentation

Abstract 基于GAN的图像转换方法存在两个缺陷:保留图像目标和保持图像转换前后的一致性,这导致不能用它生成大量不同域的训练数据。论文提出了一种结构感知(Structure-aware)的图像转换网络(image-to-image translation network)。 Proposed Framework…

Golang 字符串

目录 1. Golang 字符串1.1. 基础概念1.2. 字符串编码1.3. 遍历字符串1.4. 类型转换1.5. 总结1.6. String Concatenation (字符串连接)1.6.1. Using the operator1.6.2. Using the operator1.6.3. Using the Join method1.6.4. Using Sprintf method1.6.5. Using Go string Bu…

K-means 聚类算法学习笔记

K-means 聚类算法 是一种无监督学习算法,用来将 n n n 个样本点分成 k k k 类,使得整个数据集的误差平方和 S S E SSE SSE 最小。在本例中,样本点是指平面直角坐标系上的点,聚类中心也是平面直角坐标系上的点,而每个…

5.5V-65V Vin同步降压控制器,具有线路前馈SCT82630DHKR

描述: SCT82630是一款65V电压模式控制同步降压控制器,具有线路前馈。40ns受控高压侧MOSFET的最小导通时间支持高转换比,实现从48V输入到低压轨的直接降压转换,降低了系统复杂性和解决方案成本。如果需要,在低至6V的输…

Java“牵手”义乌购商品详情数据,义乌购商品详情接口,义乌购API接口申请指南

义乌购隶属浙江义乌购电子商务有限公司旗下网站。该平台定位为依托实体市场,服务实体市场,以诚信为根本,将7万网上商铺与实体商铺一一对应绑定,为采购商和经营户提供可控、可信、可溯源的交易保障。 义乌购平台现有商铺商品、市场…

网络连接中的三次握手和四次挥手

三次握手和四次挥手都是TCP协议通信过程中建立和关闭连接的步骤。 三次握手的步骤如下: 客户端发送SYN包,进入SYN-SENT状态。服务器接收到SYN包,回复一个ACK包和一个SYN包,进入SYN-RECEIVED状态。客户端收到ACK包和SYN包&#x…

Fair原理篇Fair逻辑动态化架构设计与实现

本文的核心内容包括: 数据逻辑处理布局中的逻辑处理Flutter类型数据处理一、数据逻辑处理 我们接触的每一个Flutter界面,大多由布局和逻辑相关的代码组成。如Flutter初始工程的Counting Demo的代码: class _MyHomePageState extends State<MyHomePage> {// 变量 int…