【AI】Pytorch神经网络分类初探

Pytorch神经网络分类初探

1.数据准备

环境采用之前创建的Anaconda虚拟环境pytorch,为了方便查看每一步的返回值,可以使用Jupyter Notebook来进行开发。首先把需要的包导入进来

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

torch框架的数据输入依赖两个基类:torch.utils.data.DataLoader和torch.utils.data.Dataset,Dataset 存储样本及其相应的标签,DataLoader 将 Dataset 封装为迭代器。

为了方便使用数据,我们采用Mnist数据集

%matplotlib inline
from pathlib import Path
import requestsDATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"PATH.mkdir(parents=True, exist_ok=True)URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"if not (PATH / FILENAME).exists():content = requests.get(URL + FILENAME).content(PATH / FILENAME).open("wb").write(content)

等待数据下载完毕,然后将数据读入进来。

import pickle
import gzipwith gzip.open((PATH / FILENAME).as_posix(), "rb") as f:((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

读入进来的数据并不是tensor格式的,需要将其转化成Tensor格式

import torchx_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid)
)

最重要的一步,将其转换成dataset和dataloader

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoadertrain_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)

这样就完成了数据准备的工作

2.定义模型

这边直接引用官网教程的模型

# Get cpu, gpu or mps device for training.
device = ("cuda"if torch.cuda.is_available()else "mps"if torch.backends.mps.is_available()else "cpu"
)
print(f"Using {device} device")# Define model
class NeuralNetwork(nn.Module):def __init__(self):super().__init__()#self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):#x = self.flatten(x)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork().to(device)
print(model)

将打印的结果放在下面,可以查看一下

Using cuda device
NeuralNetwork((flatten): Flatten(start_dim=1, end_dim=-1)(linear_relu_stack): Sequential((0): Linear(in_features=784, out_features=512, bias=True)(1): ReLU()(2): Linear(in_features=512, out_features=512, bias=True)(3): ReLU()(4): Linear(in_features=512, out_features=10, bias=True))
)

3.定义模型损失函数和优化器

这里我们依旧使用官网教程中的直接来

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

这里的SGD是最基础的优化器,采用的是梯度递减的方式,其收敛的会比较慢,如果希望收敛快些,可以使用Adam方式。

4. 定义训练和测试函数

训练函数

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationloss.backward()optimizer.step()optimizer.zero_grad()if batch % 100 == 0:loss, current = loss.item(), (batch + 1) * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

测试函数

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

5.开始训练

epochs = 5
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dl, model, loss_fn, optimizer)test(valid_dl, model, loss_fn)
print("Done!")

6.模型保存

torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

7.模型加载和使用模型预测

model = NeuralNetwork().to(device)
model.load_state_dict(torch.load("model.pth"))

模型预测

classes = ["0","1","2","3","4","5","6","7","8","9",
]model.eval()
x, y = train_ds[2][0], train_ds[2][1]
with torch.no_grad():x = x.to(device)pred = model(x)print(pred)predicted, actual = classes[pred.argmax(0)], classes[y]print(f'Predicted: "{predicted}", Actual: "{actual}"')

使用SGD优化器训练,训练5次的最高精度为76%,而使用Adam优化器第一个epoch的精度就已经达到了97%

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/206781.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【RHCE】openlab搭建web网站

网站需求: 1、基于域名 www.openlab.com 可以访问网站内容为 welcome to openlab!!! 增加映射 [rootlocalhost ~]# vim /etc/hosts 创建网页 [rootlocalhost ~]# mkdir -p /www/openlab [rootlocalhost ~]# echo welcome to openlab > /www/openlab/index.h…

利用法线贴图渲染逼真的3D老虎模型

在线工具推荐: 3D数字孪生场景编辑器 - GLTF/GLB材质纹理编辑器 - 3D模型在线转换 - Three.js AI自动纹理开发包 - YOLO 虚幻合成数据生成器 - 三维模型预览图生成器 - 3D模型语义搜索引擎 当谈到游戏角色的3D模型风格时,有几种不同的风格&#xf…

傅里叶变换在图像中的应用

1.图像增强与图像去噪 绝大部分噪音都是图像的高频分量,通过低通滤波器来滤除高频——噪声; 边缘也是图像的高频分量,可以通过添加高频分量来增强原始图像的边缘; 2.图像分割之边缘检测 提取图像高频分量 3.图像特征提取: 形状特…

3-Mybatis

文章目录 Mybatis概述什么是Mybatis?Mybatis导入知识补充数据持久化持久层 第一个Mybatis程序:数据的增删改查查创建环境编写代码1、目录结构2、核心配置文件:resources/mybatis-config.xml3、mybatis工具类:Utils/MybatisUtils4、…

ALNS的MDP模型| 还没整理完12-08

有好几篇论文已经这样做了,先摆出一篇,然后再慢慢更新 第一篇 该篇论文提出了一种称为深增强ALNS(DR-ALNS)的方法,它利用DRL选择最有效的破坏和修复运营商,配置破坏严重性参数施加在破坏算子上&#xff0c…

请简要介绍一下HTML的发展史?

问题:什么是池化思想? 回答: 池化思想是一种资源管理的策略,通过事先创建并维护一组已经初始化好的资源对象池,以便在需要时快速获取资源并在用完后归还给池,以减少资源的创建和销毁开销,提高资…

第二十一章网络通信总结

21.1 网络程序设计基础 Java网络程序设计基础涉及使用Java编程语言创建网络应用程序。这通常涉及到使用Java的网络API,如java.net包,以建立客户端和服务器之间的通信。 基本步骤包括: 1.创建服务器: 使用ServerSocket类创建服务…

常见的中间件--消息队列中间件测试点

最近刷题,看到了有问中间件的题目,于是整理了一些中间件的知识,大多是在小破站上的笔记,仅供大家参考~ 主要分为七个部分来分享: 一、常见的中间件 二、什么是队列? 三、常见消息队列MQ的比较 四、队列…

用户管理 --汇总

一、第一节课 1.1 本人写的 前端: 鱼皮 --> 用户中心 第1节课-CSDN博客 中期: 一、用户管理 第1节课中间-CSDN博客 后端: 一、用户管理-CSDN博客 其他的链接 亿图脑图MindMaster 1.2 优秀球友,推荐 Docs 另…

12_企业架构之Tomcat部署使用

Tomcat 学习目标和内容 1、能够描述Tomcat的使用场景 2、能够简单描述Tomcat的工作原理 3、能够实现部署安装Tomcat 4、能够实现配置Tomcat的service服务和自启动 5、能够实现Tomcat的Host的配置 6、能够实现Nginx反向代理Tomcat 7、能够实现Nginx负载均衡到Tomcat 一、Tomcat介…

Abaqus许可证配置文件问题

在使用Abaqus工程设计和仿真软件时,您可能会遇到许可证配置文件问题。这些问题可能会影响软件的正常运行和工作效率。为了帮助您解决这些问题,我们特别撰写了这篇文章,以提供全面、有效的解决方案。 一、Abaqus许可证配置文件问题及原因 许…

力扣labuladong一刷day32天二叉树

力扣labuladong一刷day32天二叉树 一、297. 二叉树的序列化与反序列化 题目链接:https://leetcode.cn/problems/serialize-and-deserialize-binary-tree/ 思路:关于序列化与反序列化,题目不要求序列化的方式,只要求树经过序列化…

linux的定时任务Corntab

安装crontab # yum安装crontab yum install -y crontab# 开机自启crond服务并现在启动 systemctl enable --now crondcron系统任务调度 系统任务调度: 系统周期性所要执行的工作,比如写缓存数据到硬盘、日志清理等。 在/etc/crontab文件,这…

机器学习之全面了解回归学习器

我们将和大家一起探讨机器学习与数据科学的主题。 本文主要讨论大家针对回归学习器提出的问题。我将概要介绍,然后探讨以下五个问题: 1. 能否将回归学习器用于时序数据? 2. 该如何缩短训练时间? 3. 该如何解释不同模型的结果和…

No suitable driver found for jdbc:mysql://localhost:3306(2023/12/7更新)

有两种情况: 压根没安装下载了但没设为库或方法不对 大多数为第一种情况: 一. 下载jdbc 打开网址选择一个版本进行下载 https://nowjava.com/jar/version/mysql/mysql-connector-java.html 二.安装jdbc 在项目里建一个lib文件夹 在把之前下载的jar文…

优化 SQL 日志记录的方法

为什么 SQL 日志记录是必不可少的 SQL 日志记录在数据库安全和审计中起着至关重要的作用,它涉及跟踪在数据库上执行的所有 SQL 语句,从而实现审计、故障排除和取证分析。SQL 日志记录可以提供有关数据库如何访问和使用的宝贵见解,使其成为确…

JNPF低代码平台详解 -- 系统架构

目录 一、技术介绍 技术架构 二、设计原理 三、界面展示 1.代码生成器 2.工作流程 3.门户设计 4.大屏设计 5.报表设计 6.第三方登录 7.多租户实现 8.分布式调度 9.消息中心 四、功能框架 JNPF低代码是一款新奇、实用、高效的企业级软件开发工具,支持企…

Qt/C++音视频开发58-逐帧播放/上一帧下一帧/切换播放进度/实时解码

一、前言 逐帧播放是近期增加的功能,之前也一直思考过这个功能该如何实现,对于mdk/qtav等内核组件,可以直接用该组件提供的接口实现即可,而对于ffmpeg,需要自己处理,如果有缓存的数据的话,可以…

Rust的eBFP框架Aya(一) - Linux内核网络基础

前言 在我的Rust入门及实战系列文章中已经说明, Rust是一门内存安全的高性能编程语言,从它的这些优秀特性来看,就是一门专为系统开发而诞生的语言。至于很多使用Rust来进行web开发的行为,不能说它们不好,只能说是杀鸡…

2017下半年软工(桥接模式)

题目——桥接模式(抽象调用实现部分) package org.example.桥接模式;/*** 桥接模式的核心思想是将抽象部分与它的实现部分分离,使它们可以独立变化,就是说你在实现部分:WinImp、LinuxImp基础上还能加上RedHatImp&#…