pytorch-01

加载mnist数据集

one-hot编码实现

import numpy as np
import torch
x_train = np.load("../dataset/mnist/x_train.npy") # 从网站提前下载数据集,并解压缩
y_train_label = np.load("../dataset/mnist/y_train_label.npy")
x = torch.tensor(y_train_label[:5],dtype=torch.int64)  # 获取前5个样本的标签数据
# 定义一个张量输入,因为此时有 5 个数值,且最大值为9,类别数为10
# 所以我们可以得到 y 的输出结果的形状为 shape=(5,10),即5行12列
y = torch.nn.functional.one_hot(x, 10)  # 一个参数张量x,10为类别数
print(y)

对于拥有6000个样本的MNIST数据集来说,标签就是一个6000\times 10大小的矩阵张量。

多层感知机模型

#设定的多层感知机网络模型
class NeuralNetwork(torch.nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = torch.nn.Flatten()  # 拉平图像矩阵self.linear_relu_stack = torch.nn.Sequential(torch.nn.Linear(28*28,312),   # 输入大小为28*28,输出大小为312维的线性变换层torch.nn.ReLU(),   # 激活函数层torch.nn.Linear(312, 256),torch.nn.ReLU(),torch.nn.Linear(256, 10)  # 最终输出大小为10,对应one-hot标签维度)def forward(self, input):   # 构建网络x = self.flatten(input)  #拉平矩阵为1维logits = self.linear_relu_stack(x) # 多层感知机return logits

损失函数

优化函数

model = NeuralNetwork()
loss_fu = torch.nn.CrossEntropyLoss() # 交叉熵损失函数,内置了softmax函数,
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)   #设定优化函数loss = loss_fu(pred,label_batch)  # 计算损失

完整模型

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #指定GPU编
import torch
import numpy as npbatch_size = 320                        #设定每次训练的批次数
epochs = 1024                           #设定训练次数#device = "cpu"                         #Pytorch的特性,需要指定计算的硬件,如果没有GPU的存在,就使用CPU进行计算
device = "cuda"                         #在这里读者默认使用GPU,如果读者出现运行问题可以将其改成cpu模式#设定的多层感知机网络模型
class NeuralNetwork(torch.nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = torch.nn.Flatten()self.linear_relu_stack = torch.nn.Sequential(torch.nn.Linear(28*28,312),torch.nn.ReLU(),torch.nn.Linear(312, 256),torch.nn.ReLU(),torch.nn.Linear(256, 10))def forward(self, input):x = self.flatten(input)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork()
model = model.to(device)                #将计算模型传入GPU硬件等待计算
torch.save(model, './model.pth')
#model = torch.compile(model)            #Pytorch2.0的特性,加速计算速度
loss_fu = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)   #设定优化函数#载入数据
x_train = np.load("../../dataset/mnist/x_train.npy")
y_train_label = np.load("../../dataset/mnist/y_train_label.npy")train_num = len(x_train)//batch_size#开始计算
for epoch in range(20):train_loss = 0for i in range(train_num):start = i * batch_sizeend = (i + 1) * batch_sizetrain_batch = torch.tensor(x_train[start:end]).to(device)label_batch = torch.tensor(y_train_label[start:end]).to(device)pred = model(train_batch)loss = loss_fu(pred,label_batch)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()  # 记录每个批次的损失值# 计算并打印损失值train_loss /= train_numaccuracy = (pred.argmax(1) == label_batch).type(torch.float32).sum().item() / batch_sizeprint("epoch:",epoch,"train_loss:", round(train_loss,2),"accuracy:",round(accuracy,2))

可视化模型结构和参数

model = NeuralNetwork()
print(model)

是对模型具体使用的函数及其对应的参数进行打印。

格式化显示:

param = list(model.parameters())
k=0
for i in param:l = 1print('该层结构:'+str(list(i.size())))for j in i.size():l*=jprint('该层参数和:'+str(l))k = k+l
print("总参数量:"+str(k))

模型保存

model = NeuralNetwork()
torch.save(model, './model.pth')

netron可视化

安装:pip install netron

运行:命令行输入netron

打开:通过网址http://localhost:8080打开

打开保存的模型文件model.pth:

 

 点击颜色块,可以显示详细信息:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/38251.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue 全局状态管理新宠:Pinia实战指南

文章目录 前言全局状态管理基本步骤:pinia 前言 随着Vue.js项目的日益复杂,高效的状态管理变得至关重要。Pinia作为Vue.js官方推荐的新一代状态管理库,以其简洁的API和强大的功能脱颖而出。本文将带您快速上手Pinia,从安装到应用&…

SpringMVC的基本使用

SpringMVC简介 SpringMVC是Spring提供的一套建立在Servlet基础上,基于MVC模式的web解决方案 SpringMVC核心组件 DispatcherServlet:前置控制器,来自客户端的所有请求都经由DispatcherServlet进行处理和分发Handler:处理器&…

三个方法教大家学会RAR文件转换为ZIP格式

在日常工作当中,RAR和ZIP是两种常见的压缩文件格式。有时候,大家可能会遇到将RAR文件转换为ZIP格式的情况,这通常是为了方便在特定情况下打开或使用文件。下面给大家分享几个RAR文件转换为ZIP格式的方法,下面随小编一起来看看吧~ …

KVM性能优化之CPU优化

1、查看kvm虚拟机vCPU的QEMU线程 ps -eLo ruser,pid,ppid,lwp,psr,args |awk /^qemu/{print $1,$2,$3,$4,$5,$6,$8} 注:vcpu是不同的线程,而不同的线程是跑在不同的cpu上,一般情况,虚拟机在运行时自身会点用3个cpus,为保证生产环…

通过MATLAB控制TI毫米波雷达的工作状态

前言 前一章博主介绍了MATLAB上位机软件“设计视图”的制作流程,这一章节博主将介绍如何基于这些组件结合MATLAB代码来发送CFG指令控制毫米波雷达的工作状态 串口配置 首先,在我们选择的端口号输入框和端口波特率设置框内是可以手动填入数值(字符)的,也可以在点击运行后…

汇凯金业:投资交易如何才能不亏损

投资交易中永不亏损是一个理想化的目标,现实中无法完全避免亏损。然而,通过科学的方法、合理的策略和严格的风险管理,投资者可以大幅减少亏损,并提高长期盈利的概率。以下是一些关键策略和方法,帮助投资者在交易中尽量…

太阳能辐射系统加速材料老化的关键设备光照老化实验箱

光照老化实验箱概述 光照老化实验箱是一种模拟太阳光照射对材料影响的实验设备,主要用于加速材料的自然老化过程,以此来评估材料在实际使用环境中的耐久性和稳定性。该设备广泛应用于汽车、航空、建筑、塑料制品等行业,尤其在汽车领域&#…

多商户b2b2c商城系统怎么运营

B2B2C多用户商城系统支持多种运营模式,以满足不同类型和发展阶段的企业需求。以下是五大主要的运营模式: **1. 自营模式:**平台企业通过建立自营线上商城,整合自身多渠道业务。通过会员、商品、订单、财务和仓储等多用户商城管理系…

OK527N-C开发板-简单的性能测试

OK527N-C CoreMark 获取CoreMark源码 首先使用Git克隆仓库: git clone https://github.com/eembc/coremark.git cd coremark修改Makefile 首先复制文件夹 cp -rf posix ok527之后修改ok527文件夹下的core_portme.mak文件,将CC修改如下 CC aarch6…

魔行观察-探鱼·鲜青椒爽麻烤鱼-开关店监测-时间段:2013年1月 至 2024年6月

今日监测对象:探鱼鲜青椒爽麻烤鱼,监测时间段:2011年1月 至 2024年6月 本文用到数据源免费获取地址 魔行观察http://www.wmomo.com/ 品牌介绍: 探鱼建立了产、供、销一体全链条式供应链体系,并在低纬珠江口特设潮汐…

大公司图纸管理的未来趋势

随着科技的不断发展,大公司图纸管理正朝着更加智能化、自动化和协同化的方向发展。以下是大公司图纸管理的未来趋势预测。 1. 智能化管理 利用人工智能和机器学习技术,实现图纸的自动分类、标注和检索。通过智能分析算法,预测图纸的使用趋势…

NSSCTF-Web题目19(数据库注入、文件上传、php非法传参)

目录 [LitCTF 2023]这是什么?SQL !注一下 ! 1、题目 2、知识点 3、思路 [SWPUCTF 2023 秋季新生赛]Pingpingping 4、题目 5、知识点 6、思路 [LitCTF 2023]这是什么?SQL !注一下 ! 1、题目 2、知识…

基于Vue的MOBA类游戏攻略分享平台

你好呀,我是计算机学姐码农小野!如果有相关需求,可以私信联系我。 开发语言:Java 数据库:MySQL 技术:Java技术、SpringBoot框架、B/S模式、Vue.js 工具:MyEclipse、MySQL 系统展示 首页 用…

lstrip()方法——截掉字符串左边的空格或指定的字符

自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 语法参考 lstrip()方法用于截掉字符串左边的空格或指定的字符。lstrip()方法的语法格式如下: str.lstrip([chars]) 参数说明&#xff…

【算法】Merge Sort 合并排序

Merge Sort概述 分而治之算法 递归地将问题分解为多个子问题,直到它们变得简单易解 将解决方案组合起来,解决原有问题 O(n*log(n))运行时间 基于比较的算法的最佳运行时间 一般原则 合并排序: 1. 将数…

【瞎折腾日常】服务器的cpu飙高到1000%了怎么破

一、故障起因 起因是用户反馈系统很卡,我登录普罗米修斯一看,发现docker部署得集群下的一个java应用服务器cpu爆了,直接冲到了1000%以上了,接着就是各种接口超时报警等,赶紧打开对应的服务器查看进程情况,这会使用jstack和top命令定位哪个线程占用的cpu比较大,定位代码问…

椭流线法设计配光器

椭流线法设计配光器 一、设计原理 1、边光原理 边光原理是非成像光学中的一个基础原理,其内容可以表述为:来自光源边缘的光线经过若干有序正则光学曲面后依然落在投射光斑的边缘,而来自光源内部的光线也将落在光斑内部。这里的边缘包含两层…

机械拆装-基于Unity-总体设计

前言 在工业设计和制造领域,零部件的拆装技术是一个重要的应用场景,比如我们在工程训练课程中经历的摩托车发动机拆装课程,是机械类学生的必修课程。虚拟拆装系统模拟和仿真了模型的拆装过程,虽然SolidWorks等机械设计软件能够解决…

性能调优 性能监控

1.影响性能考虑点包括: 数据库、应用程序、中间件(tomcat、nginx)、网络和操作系统等方面。 首先考虑自己的应用属于 CPU密集型 还是 IO密集型 cpu密集型 计算,排序,分组查询,各种算法 IO密集型 网络传输,磁盘读…

大创项目推荐 题目:基于机器视觉opencv的手势检测 手势识别 算法 - 深度学习 卷积神经网络 opencv python

文章目录 1 简介2 传统机器视觉的手势检测2.1 轮廓检测法2.2 算法结果2.3 整体代码实现2.3.1 算法流程 3 深度学习方法做手势识别3.1 经典的卷积神经网络3.2 YOLO系列3.3 SSD3.4 实现步骤3.4.1 数据集3.4.2 图像预处理3.4.3 构建卷积神经网络结构3.4.4 实验训练过程及结果 3.5 …