Pytorch 的实际应用 学习笔记

一.  模型的下载

weights为false时则为没有提前经过训练的模型,为true时则经过了提前训练

vgg16_false = torchvision.models.vgg16(weights=False)
vgg16_true = torchvision.models.vgg16(weights=True)

打印

二. 模型的修改

(1)添加操作

分为两种,一种是在classifier的外部添加,一种是在内部添加

外部添加,例如添加了一个线性层

vgg16_true.add_module("add_linear", nn.Linear(1000, 10))

打印,最下方添加了线性层

内部添加

vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))

打印,在classifier里面添加

(2)修改操作

例如,我修改索引为6的操作

vgg16_false.classifier[6] = nn.Linear(4096, 5)

打印

三. 模型的保存与读取

在PyTorch中,可以使用torch.save函数来保存模型的状态字典或整个模型。同时,可以使用torch.load函数来加载保存的模型。

1. 保存模型的状态字典:

# 保存模型的状态字典
torch.save(model.state_dict(), 'model.pth')# 加载模型的状态字典
model.load_state_dict(torch.load('model.pth'))

2. 保存整个模型:

# 保存整个模型
torch.save(model, 'model.pth')# 加载整个模型
model = torch.load('model.pth')

需要注意的是,如果要加载模型,需要确保模型的定义和保存时一致。如果要加载模型到GPU上,需要在torch.load函数中传入map_location参数来指定加载到哪个设备上。

四. 训练套路实例

训练流程

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom ch2.model import Yktrain_set = torchvision.datasets.CIFAR10("./data", transform=torchvision.transforms.ToTensor(), train=True,download=True)
test_set = torchvision.datasets.CIFAR10("./data", transform=torchvision.transforms.ToTensor(), train=False,download=True)
train_loader = DataLoader(dataset=train_set, batch_size=64, drop_last=True)
test_loader = DataLoader(dataset=test_set, batch_size=64, drop_last=True)loss_fn = nn.CrossEntropyLoss()learning_rate = 1e-2
yk = Yk()
opt = torch.optim.SGD(yk.model1.parameters(), learning_rate)total_train_step = 0
total_test_step = 0
epoch = 10
writer = SummaryWriter("../logs")
for i in range(epoch):print("————第{}次训练开始————".format(i))yk.train()for data in train_loader:images, targets = dataoutput = yk(images)loss = loss_fn(output, targets)opt.zero_grad()loss.backward()opt.step()total_train_step += 1if total_train_step % 100 == 0:print("训练次数:{}, Loss:{}".format(total_train_step, loss))writer.add_scalar("train_loss", loss.item(), total_train_step)# 测试步骤开始yk.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_loader:images, targets = dataoutput = yk(images)loss = loss_fn(output, targets)total_test_loss = total_test_loss + lossaccuracy = (output.argmax(1) == targets).sum()total_accuracy=total_accuracy+accuracyprint("整体测试集上的loss:{}".format(total_test_step))print("整体测试集上的正确率{}".format(total_accuracy/len(test_set)))total_test_step += 1writer.add_scalar("test_loss", total_test_loss.item(), total_test_step)writer.add_scalar("test_accuracy", total_accuracy.item(), total_test_step)torch.save(yk, "yk_{}".format(i))print("模型已保存")writer.close()

训练模型

from torch import nnclass Yk(nn.Module):def __init__(self):super(Yk, self).__init__()self.model1 = nn.Sequential(nn.Conv2d(3, 32, (5, 5), padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 32, (5, 5), padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 64, (5, 5), padding=2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(4 * 4 * 64, 64),nn.Linear(64, 10))def forward(self, x):x = self.model1(x)return xif __name__ == '__main__':yk = Yk()

打印

五. 使用GPU训练

1. 使用cuda

原本代码

import timeimport torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom ch2.model import Yktrain_set = torchvision.datasets.CIFAR10("./data", transform=torchvision.transforms.ToTensor(), train=True,download=True)
test_set = torchvision.datasets.CIFAR10("./data", transform=torchvision.transforms.ToTensor(), train=False,download=True)
train_loader = DataLoader(dataset=train_set, batch_size=64, drop_last=True)
test_loader = DataLoader(dataset=test_set, batch_size=64, drop_last=True)start_time = time.time()loss_fn = nn.CrossEntropyLoss()
# loss_fn = loss_fn.cuda()
learning_rate = 1e-2
yk = Yk()
# yk = yk.cuda()
opt = torch.optim.SGD(yk.model1.parameters(), learning_rate)total_train_step = 0
total_test_step = 0
epoch = 10
writer = SummaryWriter("../logs")
for i in range(epoch):print("————第{}次训练开始————".format(i))yk.train()for data in train_loader:images, targets = data# images = images.cuda()# targets = targets.cuda()output = yk(images)loss = loss_fn(output, targets)opt.zero_grad()loss.backward()opt.step()total_train_step += 1if total_train_step % 100 == 0:end_time = time.time()print(end_time-start_time)print("训练次数:{}, Loss:{}".format(total_train_step, loss))writer.add_scalar("train_loss", loss.item(), total_train_step)# 测试步骤开始yk.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_loader:images, targets = data# images = images.cuda()# targets = targets.cuda()output = yk(images)loss = loss_fn(output, targets)total_test_loss = total_test_loss + lossaccuracy = (output.argmax(1) == targets).sum()total_accuracy = total_accuracy + accuracyprint("整体测试集上的loss:{}".format(total_test_step))print("整体测试集上的正确率{}".format(total_accuracy / len(test_set)))total_test_step += 1writer.add_scalar("test_loss", total_test_loss.item(), total_test_step)writer.add_scalar("test_accuracy", total_accuracy.item(), total_test_step)torch.save(yk, "yk_{}".format(i))print("模型已保存")writer.close()

间隔时间如下

使用cuda,用gpu后,间隔时间明显极大缩短了

2. 定义设备

在里面定义使用cpu,还是cuda,有多个gpu,可以选用第几个(我选用的第一个)

device = torch.device("cuda:0")

原本需要使用*.cuda的地方,修改为*.to(device):

yk = yk.to(device)

全部代码如下:

import timeimport torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom ch2.model import Ykdevice = torch.device("cuda:0")train_set = torchvision.datasets.CIFAR10("./data", transform=torchvision.transforms.ToTensor(), train=True,download=True)
test_set = torchvision.datasets.CIFAR10("./data", transform=torchvision.transforms.ToTensor(), train=False,download=True)
train_loader = DataLoader(dataset=train_set, batch_size=64, drop_last=True)
test_loader = DataLoader(dataset=test_set, batch_size=64, drop_last=True)start_time = time.time()loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)
learning_rate = 1e-2
yk = Yk()
yk = yk.to(device)
opt = torch.optim.SGD(yk.model1.parameters(), learning_rate)total_train_step = 0
total_test_step = 0
epoch = 10
writer = SummaryWriter("../logs")
for i in range(epoch):print("————第{}次训练开始————".format(i))yk.train()for data in train_loader:images, targets = dataimages = images.to(device)targets = targets.to(device)output = yk(images)loss = loss_fn(output, targets)opt.zero_grad()loss.backward()opt.step()total_train_step += 1if total_train_step % 100 == 0:end_time = time.time()print(end_time - start_time)print("训练次数:{}, Loss:{}".format(total_train_step, loss))writer.add_scalar("train_loss", loss.item(), total_train_step)# 测试步骤开始yk.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_loader:images, targets = dataimages = images.to(device)targets = targets.to(device)output = yk(images)loss = loss_fn(output, targets)total_test_loss = total_test_loss + lossaccuracy = (output.argmax(1) == targets).sum()total_accuracy = total_accuracy + accuracyprint("整体测试集上的loss:{}".format(total_test_step))print("整体测试集上的正确率{}".format(total_accuracy / len(test_set)))total_test_step += 1writer.add_scalar("test_loss", total_test_loss.item(), total_test_step)writer.add_scalar("test_accuracy", total_accuracy.item(), total_test_step)torch.save(yk, "yk_{}".format(i))print("模型已保存")writer.close()

运行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/3887.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

RabbitMQ中的交换机类型

交换机类型 可以看到,在订阅模型中,多了一个exchange角色,而且过程略有变化: Publisher:生产者,不再发送消息到队列中,而是发给交换机 Exchange:交换机,一方面&#xff…

【小浩算法 DFS cpp题解】构造二叉树、打印二叉树、递归求树深度、实现DFS

目录 前言实现思路一些疑问的解决 我的代码 前言 今天第一次做一下cpp的树的算法,还是遇到了不少问题的,对树的学习其实在考研期间会比链表和数组少很多,其中最关键的要学会用递归。求深度也好,DSF也好,构造树也好&am…

欧科云链:为什么减半对比特币生态的影响正在逐步“减弱”?

出品|OKG Research 作者|Jason Jiang 欧科云链OKLink数据显示,比特币于区块高度840000(北京时间2024年4月20日8:09)成功完成第四次减半,比特币挖矿奖励正式由6.25BTC减少至3.125BTC。此次减半之后&#x…

Spring MVC系列之九大核心组件

概述 Spring MVC是面试必问知识点其一,Spring MVC知识体系庞杂,有以下九大核心组件: HandlerMappingHandlerAdapterHandlerExceptionResolverViewResolverRequestToViewNameTranslatorLocaleResolverThemeResolverMultipartResolverFlashMa…

中电金信:深度解析|数字化营销运营体系搭建

如何更好更快地梳理好体系搭建思路,稳步实现落地?下文将为大家明确搭建的推进步骤、执行要点,帮助商业银行理顺数字化营销运营体系的“点”“线”“面”~ 与所有转型的曲折、阵痛等特征一样,商业银行构建数字化营销运营体系过程中…

OceanBase 分布式数据库【信创/国产化】- OceanBase V4.3 里程碑版本

本心、输入输出、结果 文章目录 OceanBase 分布式数据库【信创/国产化】- OceanBase V4.3 里程碑版本前言OceanBase 数据更新架构4.3.0 版本是 OceanBase 迈向实时分析 AP 场景的重要里程特性:TP & AP一体化的产品形态新向量化引擎物化视图MySQL 模式下 Online DDL 扩充全…

C#算法之快速排序

算法释义:朋友们,我们在上文中说到,归并算法是一种分治算法,同样的,快速排序也是一种分治算法。所谓分治算法,原理上来说,是将规模为N的问题分解为若干个规模为较小的M的问题,这些子…

URL路由基础与Django处理请求的过程分析

1. URL路由基础 对于高质量的Web应用来讲,使用简洁、优雅的URL设计模式非常有必要。Django框架允许设计人员自由地设计URL模式,而不用受到框架本身的约束。对于URL路由来讲,其主要实现了Web服务的入口。用户通过浏览器发送过来的任何请求&am…

PyQt5中的QTablewidget

环境 PyQt5 VSCode Qt Designer生成界面 在VSCode的资源管理器中,右键选择 PYQT:New Form,打开Qt Designer 选择新建Dialog without Buttons,点击 创建 在左侧的Item Widgets中将 Table Widget拖入Dialog窗体中。 得到界面 将文件保存…

CH4INRULZ-v1靶机练习实践报告

CH4INRULZ-v1靶机练习实践报告 1 安装靶机 靶机是.ova文件,需要用VirtualBox打开,但我习惯于使用VMWare,因此修改靶机文件,使其适用于VMWare打开。 解压ova文件,得到.ovf文件和.vmdk文件。直接用VMWare打开.ovf文件即可。 2 夺…

Go语言 Interface(接口)

基本介绍 Go 语言提供了另外一种数据类型即接口,它把所有的具有共性的方法定义在一起,任何其他类型只要实现了这些方法就是实现了这个接口。接口可以让我们将不同的类型绑定到一组公共的方法上,从而实现多态和灵活的设计。Go 语言中的接口是…

Oceanbase体验之(一)运维管理工具OCP部署(社区版4.2.2)

资源规划建议 ocp主机1台 内存:64G CPU1:2C及以上 硬盘大于500G observer服务器3台 内存32G CPU:4C以上 硬盘大于1T 建议存储硬盘与操作系统硬盘隔开实现IO隔离 一、OBD、OCP安装包准备 [rootobserver /]# chown -R admin:admin /software/ [rootobserver /]# …

四:物联网ARM开发

一:ARM体系结构概述 1:控制外设led灯还有一些按键这些就要用到gpio,采集传感器的数据需要adc进行转化数据格式,特殊的外设和传感器是通过特殊的协议接口去进行连接的比如一些轴传感器和主控器的连接是通过spi,IIC 控制…

UE_反射系统(虚幻编译系统)

UE_反射系统(虚幻编译系统) UCLASS、UFUNCTION、UPROPERTY UCLASS 宏的有效关键字 https://docs.unrealengine.com/4.27/en-US/ProgrammingAndScripting/GameplayArchitecture/Classes/Specifiers/ When declaring classes, Class Specifiers can be added to the declar…

更新!!!Unity移动端游戏性能优化简谱

UWA官方出品,结合多年优化经验撰写了《Unity移动端游戏性能优化简谱》,文章从Unity移动端游戏优化的一些基础讨论出发,例举和分析了近几年基于Unity开发的移动端游戏项目中最为常见的部分性能问题,并展示了如何使用UWA的性能检测工…

mxnet gluon GRU 文档

mxnet.gluon.rnn.GRU官方文档 以下是一个使用的简单用例,详细信息前往官网 # hidden_size 100 num_layer 3 layer mx.gluon.rnn.GRU(100, 3) layer.initialize() # seq_len 5 batch_size 3 input_size 10 input mx.nd.random.uniform(shape(5, 3, 10)) # by…

(MSFT.O)微软2024财年Q3营收619亿美元

在科技的浩渺宇宙中,一颗璀璨星辰再度闪耀其光芒——(MSFT.O)微软公司于2024财政年度第三季展现出惊人的财务表现,实现总营业收入达到令人咋舌的6190亿美元。这一辉煌成就不仅突显了微软作为全球技术领导者之一的地位,更引发了业界内外对这家…

AIX7环境上一次艰难的Oracle打补丁经历

系统环境 AIX :7200-05-03-2148 Oracle:11.2.0.4 PSU: 11.2.0.4.201020(31718723) perl:5.28 问题一:AUTO patch #/u01/app/11.2.0/grid/OPatch/opatch auto /tmp/31718723 错误信息如下:匹配mos 2516761.1…

C语言读数据+遍历行数程序|Visual studio 2022

读数据遍历行数程序 记录一个度数遍历行数的程序 FILE* file2; int row2 0; file2 fopen("D://sins_mat2.txt", "r"); // file1 fopen("D://ga_mat2.txt", "r"); if (file2 NULL) {printf("open file1 failed.\n");re…

Kafka 3.x.x 入门到精通(05)——对标尚硅谷Kafka教程

Kafka 3.x.x 入门到精通(05)——对标尚硅谷Kafka教程 2. Kafka基础2.1 集群部署2.2 集群启动2.3 创建主题2.4 生产消息2.5 存储消息2.6 消费消息2.6.1 消费消息的基本步骤2.6.2 消费消息的基本代码2.6.3 消费消息的基本原理2.6.3.1消费者组2.6.3.1.1 消费…