pytorch-01

加载mnist数据集

one-hot编码实现

import numpy as np
import torch
x_train = np.load("../dataset/mnist/x_train.npy") # 从网站提前下载数据集,并解压缩
y_train_label = np.load("../dataset/mnist/y_train_label.npy")
x = torch.tensor(y_train_label[:5],dtype=torch.int64)  # 获取前5个样本的标签数据
# 定义一个张量输入,因为此时有 5 个数值,且最大值为9,类别数为10
# 所以我们可以得到 y 的输出结果的形状为 shape=(5,10),即5行12列
y = torch.nn.functional.one_hot(x, 10)  # 一个参数张量x,10为类别数
print(y)

对于拥有6000个样本的MNIST数据集来说,标签就是一个6000\times 10大小的矩阵张量。

多层感知机模型

#设定的多层感知机网络模型
class NeuralNetwork(torch.nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = torch.nn.Flatten()  # 拉平图像矩阵self.linear_relu_stack = torch.nn.Sequential(torch.nn.Linear(28*28,312),   # 输入大小为28*28,输出大小为312维的线性变换层torch.nn.ReLU(),   # 激活函数层torch.nn.Linear(312, 256),torch.nn.ReLU(),torch.nn.Linear(256, 10)  # 最终输出大小为10,对应one-hot标签维度)def forward(self, input):   # 构建网络x = self.flatten(input)  #拉平矩阵为1维logits = self.linear_relu_stack(x) # 多层感知机return logits

损失函数

优化函数

model = NeuralNetwork()
loss_fu = torch.nn.CrossEntropyLoss() # 交叉熵损失函数,内置了softmax函数,
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)   #设定优化函数loss = loss_fu(pred,label_batch)  # 计算损失

完整模型

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #指定GPU编
import torch
import numpy as npbatch_size = 320                        #设定每次训练的批次数
epochs = 1024                           #设定训练次数#device = "cpu"                         #Pytorch的特性,需要指定计算的硬件,如果没有GPU的存在,就使用CPU进行计算
device = "cuda"                         #在这里读者默认使用GPU,如果读者出现运行问题可以将其改成cpu模式#设定的多层感知机网络模型
class NeuralNetwork(torch.nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = torch.nn.Flatten()self.linear_relu_stack = torch.nn.Sequential(torch.nn.Linear(28*28,312),torch.nn.ReLU(),torch.nn.Linear(312, 256),torch.nn.ReLU(),torch.nn.Linear(256, 10))def forward(self, input):x = self.flatten(input)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork()
model = model.to(device)                #将计算模型传入GPU硬件等待计算
torch.save(model, './model.pth')
#model = torch.compile(model)            #Pytorch2.0的特性,加速计算速度
loss_fu = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)   #设定优化函数#载入数据
x_train = np.load("../../dataset/mnist/x_train.npy")
y_train_label = np.load("../../dataset/mnist/y_train_label.npy")train_num = len(x_train)//batch_size#开始计算
for epoch in range(20):train_loss = 0for i in range(train_num):start = i * batch_sizeend = (i + 1) * batch_sizetrain_batch = torch.tensor(x_train[start:end]).to(device)label_batch = torch.tensor(y_train_label[start:end]).to(device)pred = model(train_batch)loss = loss_fu(pred,label_batch)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()  # 记录每个批次的损失值# 计算并打印损失值train_loss /= train_numaccuracy = (pred.argmax(1) == label_batch).type(torch.float32).sum().item() / batch_sizeprint("epoch:",epoch,"train_loss:", round(train_loss,2),"accuracy:",round(accuracy,2))

可视化模型结构和参数

model = NeuralNetwork()
print(model)

是对模型具体使用的函数及其对应的参数进行打印。

格式化显示:

param = list(model.parameters())
k=0
for i in param:l = 1print('该层结构:'+str(list(i.size())))for j in i.size():l*=jprint('该层参数和:'+str(l))k = k+l
print("总参数量:"+str(k))

模型保存

model = NeuralNetwork()
torch.save(model, './model.pth')

netron可视化

安装:pip install netron

运行:命令行输入netron

打开:通过网址http://localhost:8080打开

打开保存的模型文件model.pth:

 

 点击颜色块,可以显示详细信息:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/38251.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue 全局状态管理新宠:Pinia实战指南

文章目录 前言全局状态管理基本步骤:pinia 前言 随着Vue.js项目的日益复杂,高效的状态管理变得至关重要。Pinia作为Vue.js官方推荐的新一代状态管理库,以其简洁的API和强大的功能脱颖而出。本文将带您快速上手Pinia,从安装到应用&…

uniapp如何根据不同角色自定义不同的tabbar

思路: 1.第一种是根据登录时获取的不同角色信息,来进行 跳转到不同的页面,在这些页面中使用自定义tabbar 2.第二种思路是封装一个自定义tabbar组件,然后在所有要展示tabbar的页面中引入使用 1.根据手机号码一键登录&#xff0c…

SpringMVC的基本使用

SpringMVC简介 SpringMVC是Spring提供的一套建立在Servlet基础上,基于MVC模式的web解决方案 SpringMVC核心组件 DispatcherServlet:前置控制器,来自客户端的所有请求都经由DispatcherServlet进行处理和分发Handler:处理器&…

三个方法教大家学会RAR文件转换为ZIP格式

在日常工作当中,RAR和ZIP是两种常见的压缩文件格式。有时候,大家可能会遇到将RAR文件转换为ZIP格式的情况,这通常是为了方便在特定情况下打开或使用文件。下面给大家分享几个RAR文件转换为ZIP格式的方法,下面随小编一起来看看吧~ …

在mfc程序中,如何用c++找到exe文件所在的路径

在 MFC&#xff08;Microsoft Foundation Class&#xff09;程序中&#xff0c;你可以使用 GetModuleFileName 函数来获取当前运行的可执行文件&#xff08;.exe&#xff09;的路径。 以下是一个示例代码&#xff1a; #include <afxwin.h> #include <iostream>in…

KVM性能优化之CPU优化

1、查看kvm虚拟机vCPU的QEMU线程 ps -eLo ruser,pid,ppid,lwp,psr,args |awk /^qemu/{print $1,$2,$3,$4,$5,$6,$8} 注:vcpu是不同的线程&#xff0c;而不同的线程是跑在不同的cpu上&#xff0c;一般情况&#xff0c;虚拟机在运行时自身会点用3个cpus&#xff0c;为保证生产环…

通过MATLAB控制TI毫米波雷达的工作状态

前言 前一章博主介绍了MATLAB上位机软件“设计视图”的制作流程,这一章节博主将介绍如何基于这些组件结合MATLAB代码来发送CFG指令控制毫米波雷达的工作状态 串口配置 首先,在我们选择的端口号输入框和端口波特率设置框内是可以手动填入数值(字符)的,也可以在点击运行后…

汇凯金业:投资交易如何才能不亏损

投资交易中永不亏损是一个理想化的目标&#xff0c;现实中无法完全避免亏损。然而&#xff0c;通过科学的方法、合理的策略和严格的风险管理&#xff0c;投资者可以大幅减少亏损&#xff0c;并提高长期盈利的概率。以下是一些关键策略和方法&#xff0c;帮助投资者在交易中尽量…

【CSRF】

CSRF 原理&#xff1a;诱导用户在访问第三方site时&#xff0c;访问攻击者构造的site,攻击者site会对原site进行恶意操作。 burp模拟攻击&#xff1a; 对一个博客系统点击发布文章时&#xff0c;Burp Suite抓包&#xff0c;右键CSRF PoC功能 -> Engagament tools -> Gen…

洛谷 P3954 [NOIP2017 普及组] 成绩

本文由Jzwalliser原创&#xff0c;发布在CSDN平台上&#xff0c;遵循CC 4.0 BY-SA协议。 因此&#xff0c;若需转载/引用本文&#xff0c;请注明作者并附原文链接&#xff0c;且禁止删除/修改本段文字。 违者必究&#xff0c;谢谢配合。 个人主页&#xff1a;blog.csdn.net/jzw…

太阳能辐射系统加速材料老化的关键设备光照老化实验箱

光照老化实验箱概述 光照老化实验箱是一种模拟太阳光照射对材料影响的实验设备&#xff0c;主要用于加速材料的自然老化过程&#xff0c;以此来评估材料在实际使用环境中的耐久性和稳定性。该设备广泛应用于汽车、航空、建筑、塑料制品等行业&#xff0c;尤其在汽车领域&#…

多商户b2b2c商城系统怎么运营

B2B2C多用户商城系统支持多种运营模式&#xff0c;以满足不同类型和发展阶段的企业需求。以下是五大主要的运营模式&#xff1a; **1. 自营模式&#xff1a;**平台企业通过建立自营线上商城&#xff0c;整合自身多渠道业务。通过会员、商品、订单、财务和仓储等多用户商城管理系…

OK527N-C开发板-简单的性能测试

OK527N-C CoreMark 获取CoreMark源码 首先使用Git克隆仓库&#xff1a; git clone https://github.com/eembc/coremark.git cd coremark修改Makefile 首先复制文件夹 cp -rf posix ok527之后修改ok527文件夹下的core_portme.mak文件&#xff0c;将CC修改如下 CC aarch6…

CPU占用率飙升至100%:是攻击还是正常现象?

在运维和开发的日常工作中&#xff0c;CPU占用率突然飙升至100%往往是一个令人紧张的信号。这可能意味着服务器正在遭受攻击&#xff0c;但也可能是由于某些正常的、但资源密集型的任务或进程造成的。本文将探讨如何识别和应对服务器的异常CPU占用情况&#xff0c;并通过Python…

魔行观察-探鱼·鲜青椒爽麻烤鱼-开关店监测-时间段:2013年1月 至 2024年6月

今日监测对象&#xff1a;探鱼鲜青椒爽麻烤鱼&#xff0c;监测时间段&#xff1a;2011年1月 至 2024年6月 本文用到数据源免费获取地址 魔行观察http://www.wmomo.com/ 品牌介绍&#xff1a; 探鱼建立了产、供、销一体全链条式供应链体系&#xff0c;并在低纬珠江口特设潮汐…

大公司图纸管理的未来趋势

随着科技的不断发展&#xff0c;大公司图纸管理正朝着更加智能化、自动化和协同化的方向发展。以下是大公司图纸管理的未来趋势预测。 1. 智能化管理 利用人工智能和机器学习技术&#xff0c;实现图纸的自动分类、标注和检索。通过智能分析算法&#xff0c;预测图纸的使用趋势…

NSSCTF-Web题目19(数据库注入、文件上传、php非法传参)

目录 [LitCTF 2023]这是什么&#xff1f;SQL &#xff01;注一下 &#xff01; 1、题目 2、知识点 3、思路 [SWPUCTF 2023 秋季新生赛]Pingpingping 4、题目 5、知识点 6、思路 [LitCTF 2023]这是什么&#xff1f;SQL &#xff01;注一下 &#xff01; 1、题目 2、知识…

基于Vue的MOBA类游戏攻略分享平台

你好呀&#xff0c;我是计算机学姐码农小野&#xff01;如果有相关需求&#xff0c;可以私信联系我。 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;Java技术、SpringBoot框架、B/S模式、Vue.js 工具&#xff1a;MyEclipse、MySQL 系统展示 首页 用…

在 Windows 上,使用 icacls 命令让apache 用户有权访问

调试免费云服务器&#xff0c;三丰云&#xff0c;用户权限过程。 在 Windows 上&#xff0c;icacls 命令是一个非常强大的工具&#xff0c;用于修改文件和目录的权限。然而&#xff0c;需要注意的是&#xff0c;Windows 默认的 Web 服务器&#xff08;如 IIS&#xff09;通常运…

lstrip()方法——截掉字符串左边的空格或指定的字符

自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 语法参考 lstrip()方法用于截掉字符串左边的空格或指定的字符。lstrip()方法的语法格式如下&#xff1a; str.lstrip([chars]) 参数说明&#xff…