PyTorch官网demo解读——第一个神经网络(1)

神经网络如此神奇,feel the magic

今天分享一下学习PyTorch官网demo的心得,原来实现一个神经网络可以如此简单/简洁/高效,同时也感慨PyTorch如此强大。

这个demo的目的是训练一个识别手写数字的模型!

先上源码:
from pathlib import Path
import requests   # http请求库
import pickle
import gzipfrom matplotlib import pyplot   # 显示图像库import math
import numpy as np
import torch###########下载训练/验证数据######################################################
# 这里加载的是mnist数据集
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
PATH.mkdir(parents=True, exist_ok=True)URL = "https://github.com/pytorch/tutorials/raw/main/_static/"
FILENAME = "mnist.pkl.gz"if not (PATH / FILENAME).exists():content = requests.get(URL + FILENAME).content(PATH / FILENAME).open("wb").write(content)###########解压并加载训练数据######################################################
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")# 通过pyplot显示数据集中的第一张图片
# 显示过程会中断运行,看到效果之后可以屏蔽掉,让调试更顺畅
#print("x_train[0]: ", x_train[0])
#pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
#pyplot.show()# 将加载的数据转成tensor
x_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape   # n是函数,c是列数
print("x_train.shape: ", x_train.shape)
print("y_train.min: {0}, y_train.max: {1}".format(y_train.min(), y_train.max()))# 初始化权重和偏差值,权重是随机出来的784*10的矩阵,偏差初始化为0
weights = torch.randn(784, 10) / math.sqrt(784)
weights.requires_grad_()
bias = torch.zeros(10, requires_grad=True)# 激活函数
def log_softmax(x):return x - x.exp().sum(-1).log().unsqueeze(-1)# 定义模型:y = wx + b
# 实际上就是单层的Linear模型
def model(xb):return log_softmax(xb @ weights + bias)# 丢失函数 loss function
def nll(input, target):return -input[range(target.shape[0]), target].mean()
loss_func = nll# 计算精度函数
def accuracy(out, yb):preds = torch.argmax(out, dim=1)return (preds == yb).float().mean()###########开始训练##################################################################
bs = 64  # 每一批数据的大小
lr = 0.5  # 学习率
epochs = 2  # how many epochs to train forfor epoch in range(epochs):for i in range((n - 1) // bs + 1):start_i = i * bsend_i = start_i + bsxb = x_train[start_i:end_i]yb = y_train[start_i:end_i]pred = model(xb) # 通过模型预测loss = loss_func(pred, yb) # 通过与实际结果比对,计算丢失值loss.backward() # 反向传播with torch.no_grad():weights -= weights.grad * lr  # 调整权重值bias -= bias.grad * lr  # 调整偏差值weights.grad.zero_()bias.grad.zero_()##########对比一下预测结果############################################################
xb = x_train[0:bs]  # 加载一批数据,这里用的是训练的数据,在实际应用中最好使用没训练过的数据来验证
yb = y_train[0:bs]  # 训练数据对应的正确结果
preds = model(xb)  # 使用训练之后的模型进行预测
print("################## after training ###################")
print("accuracy: ", accuracy(preds, yb))   # 打印出训练之后的精度
# print(preds[0])
print("pred value: ", torch.argmax(preds, dim=1))   # 打印预测的数字
print("real value: ", yb)   # 实际正确的数据,可以直观地和上一行打印地数据进行对比
运行结果:

可以看到训练后模型地预测精度达到了0.9531,已经不错了,毕竟只使用了一个单层地Linear模型;从输出地对比数据中可以看出有三个地方预测错了(红框标记地数字)

ok,今天先到这里,下一篇再来解读代码中地细节

附:

PyTorch官方源码:https://github.com/pytorch/tutorials/blob/main/beginner_source/nn_tutorial.py

天地一逆旅,同悲万古愁!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/227350.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

华为云CodeArts Pipeline常见问答汇总

1.【Pipeline】CodeArts Pipeline流水线如何传递参数至CodeArts Build编译构建任务 答参考文档 https://support.huaweicloud.com/pipeline_faq/pipeline_faq_0004.html https://support.huaweicloud.com/usermanual-pipeline/pipeline_10_0005.html https://support.hu…

Composer 安装与使用

Composer 是 PHP 的一个依赖管理工具。我们可以在项目中声明所依赖的外部工具库,Composer 会帮你安装这些依赖的库文件,有了它,我们就可以很轻松的使用一个命令将其他人的优秀代码引用到我们的项目中来。 Composer 默认情况下不是全局安装&a…

vue3 element-plus 日期选择器 el-date-picker 汉化

vue3 项目中,element-plus 的日期选择器 el-date-picker 默认是英文版的,如下: 页面引入: //引入汉化语言包 import locale from "element-plus/lib/locale/lang/zh-cn" import { ElDatePicker, ElButton, ElConfigP…

西南科技大学数据库实验二(表数据插入、修改和删除)

一、实验目的 (1)学会用SQL语句对数据库进行插入、修改和删除数据操作 (2)掌握insert、update、delete命令实现对表数据插入、修改和删除等更新操作。 二、实验任务 创建数据库,并创建Employees表、Departments表和…

Python学习笔记第七十五天(OpenCV图像应用)

Python学习笔记第七十五天 OpenCV图像应用读取图片显示图像写入图像保存图像 后记 OpenCV图像应用 读取图片 使用OpenCV读取图片非常简单,可以使用cv2.imread()函数。该函数接受两个参数:图像路径和标志。标志指定了读取图像的方式,包括是否…

MySQL5.x和8.0

区别1. 性能:MySQL 8.0 的速度要比 MySQL 5.7 快 2 倍 MySQL 8.0 在以下方面带来了更好的性能:读/写工作负载、IO 密集型工作负载、以及高竞争("hot spot"热点竞争问题)工作负载2. NoSQL:MySQL 从 5.7 版本开…

微服务网关Gateway

springcloud官方提供的网关组件spring-cloud-starter-gateway,看pom.xml文件,引入了webflux做响应式编程,请求转发用到了netty的reactor模型,支持的请求数在1W~1.5W左右。hystrix停止维护后,官方推荐resilience4j做服务熔断,网关这里也能看到依赖。 对于网关提供的功能…

什么是CI/CD?如何在PHP项目中实施CI/CD?

CI/CD(持续集成/持续交付或持续部署)是一种软件开发和交付方法,它旨在通过自动化和持续集成来提高开发速度和交付质量。以下是CI/CD的基本概念和如何在PHP项目中实施它的一般步骤: 持续集成(Continuous Integration -…

Unity 使用AddTorque方法给刚体施加力矩详解

给刚体施加力,除了使用AddForce方法,我们还可以使用AddTorque方法。该方法是通过施加力矩给刚体以力。AddTorque方法从形式上跟AddForce差不多,它也有4个重载方法: 1、AddTorque(Vector3 torque);使用Vector3类型参数…

YOLO v8 目标检测识别翻栏

一、行人翻栏识别背景介绍 1.1跨越围栏是人类活动中一个普遍但需要引起警惕的行为。它不仅可能导致各种意外事故,甚至可能对个人的生命安全构成威胁。在交通领域,跨越围栏可能导致严重的交通事故,造成人员伤亡。在公共场所,如公园…

华为、新华三、锐捷常用命令总结

华为、新华三、锐捷常用命令总结 一、华为交换机基础配置命令二、H3C交换机的基本配置三、锐捷交换机基础命令配置 一、华为交换机基础配置命令 1、创建vlan&#xff1a; <Quidway> //用户视图&#xff0c;也就是在Quidway模式下运行命令。 <Quidway>system-view…

B+树和索引

B树概念 是一种平衡多路搜索树&#xff08;Balanced Multiway Search Tree&#xff09;&#xff0c;常用于数据库和文件系统的索引结构。相比于其他的树型数据结构&#xff0c;如二叉搜索树和B树&#xff0c;B树在大数据量下的性能表现更优秀。 B树的基本特性&#xff1a; 多…

Ansible如何处理play错误的?Ansible角色?

Ansible如何处理play错误的&#xff1a;Ansible审查每个任务的返回代码&#xff0c;以确定任务是否成功或失败。默认情况下&#xff0c;当一个任务失败时&#xff0c;Ansible会立即中止该主机上的其他操作&#xff0c;并跳过所有后续任务。 实际生产中&#xff0c;若希望即使任…

在Node.js中MongoDB查询分页的方法

本文主要介绍在Node.js中MongoDB查询分页的方法。 目录 Node.js中MongoDB查询分页使用原生的mongodb驱动程序查询分页使用Mongoose库进行查询分页注意项 Node.js中MongoDB查询分页 在Node.js中使用MongoDB进行查询分页&#xff0c;可以使用原生的mongodb驱动程序或者Mongoose库…

【web安全】密码爆破讲解,以及burp的爆破功能使用方法

前言 菜某总结&#xff0c;欢迎指正错误进行补充 密码暴力破解原理 暴力破解实际就是疯狂的输入密码进行尝试登录&#xff0c;针对有的人喜欢用一些个人信息当做密码&#xff0c;有的人喜欢用一些很简单的低强度密码&#xff0c;我们就可以针对性的生成一个字典&#xff0c;…

【Linux】文件系统、文件系统结构、虚拟文件系统

一、文件系统概述 1. 什么是文件系统&#xff1f;2. 文件系统&#xff08;文件管理系统的方法&#xff09;的种类有哪些&#xff1f;3. 什么是分区&#xff1f;4. 什么是文件系统目录结构&#xff1f;5. 什么虚拟文件系统Virtual File System &#xff1f;6. 虚拟文件系统有什…

【华为数据之道学习笔记】5-6非结构化数据入湖

1. 非结构化数据管理的范围 非结构化数据包括无格式的文本、各类格式的文档、图像、音频、视频等多样异构的格式文件。相较于结构化数据&#xff0c;非结构化数据更难以标准化和理解&#xff0c;因而非结构化数据的管理不仅包括文件本身&#xff0c;而且包括对文件的描述属性&a…

OpenAI开源超级对齐方法:用GPT-2,监督、微调GPT-4

12月15日&#xff0c;OpenAI在官网公布了最新研究论文和开源项目——如何用小模型监督大模型&#xff0c;实现更好的新型对齐方法。 目前&#xff0c;大模型的主流对齐方法是RLHF&#xff08;人类反馈强化学习&#xff09;。但随着大模型朝着多模态、AGI发展&#xff0c;神经元…

Julia调用Matlab, Python以及R的微分方程求解器

文章目录 从其他语言翻译来的求解器重新封装版本 SciML教程系列&#xff1a; Julia求解常微分方程解Lorentz方程求解简谐振动的微分方程求解单摆 从其他语言翻译来的求解器 对于熟悉MATLAB/Python/R的程序员&#xff0c;可先使用下表中的求解器&#xff0c;因为这些求解器是…

Spring Boot SOAP Web 服务端和客户端

一. 服务端 1. 技术栈 JDK 1.8&#xff0c;Eclipse&#xff0c;Maven – 开发环境SpringBoot – 基础应用程序框架wsdl4j – 为我们的服务发布 WSDLSOAP-UI – 用于测试我们的服务JAXB maven 插件 – 用于代码生成 2.创建 Spring Boot 项目 添加 Wsdl4j 依赖关系 编辑pom…