PyTorch官网demo解读——第一个神经网络(1)

神经网络如此神奇,feel the magic

今天分享一下学习PyTorch官网demo的心得,原来实现一个神经网络可以如此简单/简洁/高效,同时也感慨PyTorch如此强大。

这个demo的目的是训练一个识别手写数字的模型!

先上源码:
from pathlib import Path
import requests   # http请求库
import pickle
import gzipfrom matplotlib import pyplot   # 显示图像库import math
import numpy as np
import torch###########下载训练/验证数据######################################################
# 这里加载的是mnist数据集
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
PATH.mkdir(parents=True, exist_ok=True)URL = "https://github.com/pytorch/tutorials/raw/main/_static/"
FILENAME = "mnist.pkl.gz"if not (PATH / FILENAME).exists():content = requests.get(URL + FILENAME).content(PATH / FILENAME).open("wb").write(content)###########解压并加载训练数据######################################################
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")# 通过pyplot显示数据集中的第一张图片
# 显示过程会中断运行,看到效果之后可以屏蔽掉,让调试更顺畅
#print("x_train[0]: ", x_train[0])
#pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
#pyplot.show()# 将加载的数据转成tensor
x_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape   # n是函数,c是列数
print("x_train.shape: ", x_train.shape)
print("y_train.min: {0}, y_train.max: {1}".format(y_train.min(), y_train.max()))# 初始化权重和偏差值,权重是随机出来的784*10的矩阵,偏差初始化为0
weights = torch.randn(784, 10) / math.sqrt(784)
weights.requires_grad_()
bias = torch.zeros(10, requires_grad=True)# 激活函数
def log_softmax(x):return x - x.exp().sum(-1).log().unsqueeze(-1)# 定义模型:y = wx + b
# 实际上就是单层的Linear模型
def model(xb):return log_softmax(xb @ weights + bias)# 丢失函数 loss function
def nll(input, target):return -input[range(target.shape[0]), target].mean()
loss_func = nll# 计算精度函数
def accuracy(out, yb):preds = torch.argmax(out, dim=1)return (preds == yb).float().mean()###########开始训练##################################################################
bs = 64  # 每一批数据的大小
lr = 0.5  # 学习率
epochs = 2  # how many epochs to train forfor epoch in range(epochs):for i in range((n - 1) // bs + 1):start_i = i * bsend_i = start_i + bsxb = x_train[start_i:end_i]yb = y_train[start_i:end_i]pred = model(xb) # 通过模型预测loss = loss_func(pred, yb) # 通过与实际结果比对,计算丢失值loss.backward() # 反向传播with torch.no_grad():weights -= weights.grad * lr  # 调整权重值bias -= bias.grad * lr  # 调整偏差值weights.grad.zero_()bias.grad.zero_()##########对比一下预测结果############################################################
xb = x_train[0:bs]  # 加载一批数据,这里用的是训练的数据,在实际应用中最好使用没训练过的数据来验证
yb = y_train[0:bs]  # 训练数据对应的正确结果
preds = model(xb)  # 使用训练之后的模型进行预测
print("################## after training ###################")
print("accuracy: ", accuracy(preds, yb))   # 打印出训练之后的精度
# print(preds[0])
print("pred value: ", torch.argmax(preds, dim=1))   # 打印预测的数字
print("real value: ", yb)   # 实际正确的数据,可以直观地和上一行打印地数据进行对比
运行结果:

可以看到训练后模型地预测精度达到了0.9531,已经不错了,毕竟只使用了一个单层地Linear模型;从输出地对比数据中可以看出有三个地方预测错了(红框标记地数字)

ok,今天先到这里,下一篇再来解读代码中地细节

附:

PyTorch官方源码:https://github.com/pytorch/tutorials/blob/main/beginner_source/nn_tutorial.py

天地一逆旅,同悲万古愁!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/227350.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Composer 安装与使用

Composer 是 PHP 的一个依赖管理工具。我们可以在项目中声明所依赖的外部工具库,Composer 会帮你安装这些依赖的库文件,有了它,我们就可以很轻松的使用一个命令将其他人的优秀代码引用到我们的项目中来。 Composer 默认情况下不是全局安装&a…

vue3 element-plus 日期选择器 el-date-picker 汉化

vue3 项目中,element-plus 的日期选择器 el-date-picker 默认是英文版的,如下: 页面引入: //引入汉化语言包 import locale from "element-plus/lib/locale/lang/zh-cn" import { ElDatePicker, ElButton, ElConfigP…

西南科技大学数据库实验二(表数据插入、修改和删除)

一、实验目的 (1)学会用SQL语句对数据库进行插入、修改和删除数据操作 (2)掌握insert、update、delete命令实现对表数据插入、修改和删除等更新操作。 二、实验任务 创建数据库,并创建Employees表、Departments表和…

微服务网关Gateway

springcloud官方提供的网关组件spring-cloud-starter-gateway,看pom.xml文件,引入了webflux做响应式编程,请求转发用到了netty的reactor模型,支持的请求数在1W~1.5W左右。hystrix停止维护后,官方推荐resilience4j做服务熔断,网关这里也能看到依赖。 对于网关提供的功能…

Unity 使用AddTorque方法给刚体施加力矩详解

给刚体施加力,除了使用AddForce方法,我们还可以使用AddTorque方法。该方法是通过施加力矩给刚体以力。AddTorque方法从形式上跟AddForce差不多,它也有4个重载方法: 1、AddTorque(Vector3 torque);使用Vector3类型参数…

在Node.js中MongoDB查询分页的方法

本文主要介绍在Node.js中MongoDB查询分页的方法。 目录 Node.js中MongoDB查询分页使用原生的mongodb驱动程序查询分页使用Mongoose库进行查询分页注意项 Node.js中MongoDB查询分页 在Node.js中使用MongoDB进行查询分页,可以使用原生的mongodb驱动程序或者Mongoose库…

【web安全】密码爆破讲解,以及burp的爆破功能使用方法

前言 菜某总结,欢迎指正错误进行补充 密码暴力破解原理 暴力破解实际就是疯狂的输入密码进行尝试登录,针对有的人喜欢用一些个人信息当做密码,有的人喜欢用一些很简单的低强度密码,我们就可以针对性的生成一个字典,…

【Linux】文件系统、文件系统结构、虚拟文件系统

一、文件系统概述 1. 什么是文件系统?2. 文件系统(文件管理系统的方法)的种类有哪些?3. 什么是分区?4. 什么是文件系统目录结构?5. 什么虚拟文件系统Virtual File System ?6. 虚拟文件系统有什…

OpenAI开源超级对齐方法:用GPT-2,监督、微调GPT-4

12月15日,OpenAI在官网公布了最新研究论文和开源项目——如何用小模型监督大模型,实现更好的新型对齐方法。 目前,大模型的主流对齐方法是RLHF(人类反馈强化学习)。但随着大模型朝着多模态、AGI发展,神经元…

Spring Boot SOAP Web 服务端和客户端

一. 服务端 1. 技术栈 JDK 1.8,Eclipse,Maven – 开发环境SpringBoot – 基础应用程序框架wsdl4j – 为我们的服务发布 WSDLSOAP-UI – 用于测试我们的服务JAXB maven 插件 – 用于代码生成 2.创建 Spring Boot 项目 添加 Wsdl4j 依赖关系 编辑pom…

cesium 自定义贴图,shadertoy移植教程。

1.前言 cesium中提供了一些高级的api,可以自己写一些shader来制作炫酷的效果。 ShaderToy 是一个可以在线编写、测试和分享图形渲染着色器的网站。它提供了一个图形化的编辑器,可以让用户编写基于 WebGL 的 GLSL 着色器代码,并实时预览渲染结…

006 Windows共享

一、共享要求 一般是局域网内使用 1、物理上处于统一局域网 同一公司的网络同一家庭的网络连接同一手机热点的主机 2、逻辑上处于同一局域网 直接可以ping对方主机(能够直接访问到) 二、共享权限 1、共享权限 一般设置为everyone完全控制 2、NTF…

基于3D-CGAN的跨模态MR脑肿瘤分割图像合成

3D CGAN BASED CROSS-MODALITY MR IMAGE SYNTHESIS FOR BRAIN TUMOR SEGMENTATION 基于3D-CGAN的跨模态MR脑肿瘤分割图像合成背景贡献实验方法Subject-specific local adaptive fusion(针对特定主题的局部自适应融合)Brain tumor segmentation model 损失…

K8s投射数据卷

目录 一.Secret 1.secret介绍 2.secret的类型 3.创建secret 4.使用secret 环境变量的形式 volume数据卷挂载 二ConfigMap 1.创建ConfigMap的方式 2.使用ConfigMap 2.1作为volume挂载使用 2.2.作为环境变量 三.Downward API 1.以环境变量的方式实现 2.Volume挂载 一.S…

深入解析 Spring 和 Spring Boot 的区别

目录 引言 1. 设计理念 1.1 Spring 框架的设计理念 1.2 Spring Boot 的设计理念 2. 项目配置 2.1 Spring 框架的项目配置 2.2 Spring Boot 的项目配置 3. 自动配置 3.1 Spring 框架的自动配置 3.2 Spring Boot 的自动配置 4. 微服务支持 4.1 Spring 框架的微服务支持…

OceanBase 4.2.1社区版 最小资源需求安装方式

OceanBase 4.2.1社区版 最小资源需求安装方式 资源需求 资源需求分析 observer Memory 控制参数: memory_limit_percentage 默认80% memory_limit 直接设定observer Memory 大小 System memory 可设为1G 租户内存:sys租户内存设为1G,OCP需要的租户oc…

在Windows上通过cmake-gui及VS2019来 编译OpenCV-4.5.3源码

文章目录 下载OpenCV-4.5.3源码下载opencv_contrib-4.5.3源码打开cmake-gui选择生成器 通过 Visual Studio 2019 打开构建好的.sln工程文件执行编译操作执行安装操作 下载OpenCV-4.5.3源码 可通过github上下载,网上很多,找到tag标签,选择 Op…

OSG中几何体的绘制(二)

5. 几何体操作 在本章的前言中就讲到,场景都是由基本的绘图基元构成的,基本的绘图基元构成简单的几何体,简单的几何体构成复杂的几何体,复杂的几何体最终构造成复杂的场景。当多个几何体组合时,可能存在多种降低场景渲染效率的原因…

AlexNet(pytorch)

AlexNet是2012年ISLVRC 2012(ImageNet Large Scale Visual Recognition Challenge)竞赛的冠军网络,分类准确率由传统的 70%提升到 80% 该网络的亮点在于: (1)首次利用 GPU 进行网络加速训练。 &#xff…

Idea中操作Git使用cherry pick

Idea中操作Git使用cherry pick 使用场景使用功能步骤 使用场景 代码开发中,新功能还未开发完,但是master分支需要使用带新功能中的一次提交的代码,就可以使用cherry pack(优选). 使用功能步骤 切换到master分支选中dev分支双击选择需要使用的提交右键,如果有冲突就会弹窗解…