Pytorch 的实际应用 学习笔记

一.  模型的下载

weights为false时则为没有提前经过训练的模型,为true时则经过了提前训练

vgg16_false = torchvision.models.vgg16(weights=False)
vgg16_true = torchvision.models.vgg16(weights=True)

打印

二. 模型的修改

(1)添加操作

分为两种,一种是在classifier的外部添加,一种是在内部添加

外部添加,例如添加了一个线性层

vgg16_true.add_module("add_linear", nn.Linear(1000, 10))

打印,最下方添加了线性层

内部添加

vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))

打印,在classifier里面添加

(2)修改操作

例如,我修改索引为6的操作

vgg16_false.classifier[6] = nn.Linear(4096, 5)

打印

三. 模型的保存与读取

在PyTorch中,可以使用torch.save函数来保存模型的状态字典或整个模型。同时,可以使用torch.load函数来加载保存的模型。

1. 保存模型的状态字典:

# 保存模型的状态字典
torch.save(model.state_dict(), 'model.pth')# 加载模型的状态字典
model.load_state_dict(torch.load('model.pth'))

2. 保存整个模型:

# 保存整个模型
torch.save(model, 'model.pth')# 加载整个模型
model = torch.load('model.pth')

需要注意的是,如果要加载模型,需要确保模型的定义和保存时一致。如果要加载模型到GPU上,需要在torch.load函数中传入map_location参数来指定加载到哪个设备上。

四. 训练套路实例

训练流程

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom ch2.model import Yktrain_set = torchvision.datasets.CIFAR10("./data", transform=torchvision.transforms.ToTensor(), train=True,download=True)
test_set = torchvision.datasets.CIFAR10("./data", transform=torchvision.transforms.ToTensor(), train=False,download=True)
train_loader = DataLoader(dataset=train_set, batch_size=64, drop_last=True)
test_loader = DataLoader(dataset=test_set, batch_size=64, drop_last=True)loss_fn = nn.CrossEntropyLoss()learning_rate = 1e-2
yk = Yk()
opt = torch.optim.SGD(yk.model1.parameters(), learning_rate)total_train_step = 0
total_test_step = 0
epoch = 10
writer = SummaryWriter("../logs")
for i in range(epoch):print("————第{}次训练开始————".format(i))yk.train()for data in train_loader:images, targets = dataoutput = yk(images)loss = loss_fn(output, targets)opt.zero_grad()loss.backward()opt.step()total_train_step += 1if total_train_step % 100 == 0:print("训练次数:{}, Loss:{}".format(total_train_step, loss))writer.add_scalar("train_loss", loss.item(), total_train_step)# 测试步骤开始yk.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_loader:images, targets = dataoutput = yk(images)loss = loss_fn(output, targets)total_test_loss = total_test_loss + lossaccuracy = (output.argmax(1) == targets).sum()total_accuracy=total_accuracy+accuracyprint("整体测试集上的loss:{}".format(total_test_step))print("整体测试集上的正确率{}".format(total_accuracy/len(test_set)))total_test_step += 1writer.add_scalar("test_loss", total_test_loss.item(), total_test_step)writer.add_scalar("test_accuracy", total_accuracy.item(), total_test_step)torch.save(yk, "yk_{}".format(i))print("模型已保存")writer.close()

训练模型

from torch import nnclass Yk(nn.Module):def __init__(self):super(Yk, self).__init__()self.model1 = nn.Sequential(nn.Conv2d(3, 32, (5, 5), padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 32, (5, 5), padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 64, (5, 5), padding=2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(4 * 4 * 64, 64),nn.Linear(64, 10))def forward(self, x):x = self.model1(x)return xif __name__ == '__main__':yk = Yk()

打印

五. 使用GPU训练

1. 使用cuda

原本代码

import timeimport torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom ch2.model import Yktrain_set = torchvision.datasets.CIFAR10("./data", transform=torchvision.transforms.ToTensor(), train=True,download=True)
test_set = torchvision.datasets.CIFAR10("./data", transform=torchvision.transforms.ToTensor(), train=False,download=True)
train_loader = DataLoader(dataset=train_set, batch_size=64, drop_last=True)
test_loader = DataLoader(dataset=test_set, batch_size=64, drop_last=True)start_time = time.time()loss_fn = nn.CrossEntropyLoss()
# loss_fn = loss_fn.cuda()
learning_rate = 1e-2
yk = Yk()
# yk = yk.cuda()
opt = torch.optim.SGD(yk.model1.parameters(), learning_rate)total_train_step = 0
total_test_step = 0
epoch = 10
writer = SummaryWriter("../logs")
for i in range(epoch):print("————第{}次训练开始————".format(i))yk.train()for data in train_loader:images, targets = data# images = images.cuda()# targets = targets.cuda()output = yk(images)loss = loss_fn(output, targets)opt.zero_grad()loss.backward()opt.step()total_train_step += 1if total_train_step % 100 == 0:end_time = time.time()print(end_time-start_time)print("训练次数:{}, Loss:{}".format(total_train_step, loss))writer.add_scalar("train_loss", loss.item(), total_train_step)# 测试步骤开始yk.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_loader:images, targets = data# images = images.cuda()# targets = targets.cuda()output = yk(images)loss = loss_fn(output, targets)total_test_loss = total_test_loss + lossaccuracy = (output.argmax(1) == targets).sum()total_accuracy = total_accuracy + accuracyprint("整体测试集上的loss:{}".format(total_test_step))print("整体测试集上的正确率{}".format(total_accuracy / len(test_set)))total_test_step += 1writer.add_scalar("test_loss", total_test_loss.item(), total_test_step)writer.add_scalar("test_accuracy", total_accuracy.item(), total_test_step)torch.save(yk, "yk_{}".format(i))print("模型已保存")writer.close()

间隔时间如下

使用cuda,用gpu后,间隔时间明显极大缩短了

2. 定义设备

在里面定义使用cpu,还是cuda,有多个gpu,可以选用第几个(我选用的第一个)

device = torch.device("cuda:0")

原本需要使用*.cuda的地方,修改为*.to(device):

yk = yk.to(device)

全部代码如下:

import timeimport torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom ch2.model import Ykdevice = torch.device("cuda:0")train_set = torchvision.datasets.CIFAR10("./data", transform=torchvision.transforms.ToTensor(), train=True,download=True)
test_set = torchvision.datasets.CIFAR10("./data", transform=torchvision.transforms.ToTensor(), train=False,download=True)
train_loader = DataLoader(dataset=train_set, batch_size=64, drop_last=True)
test_loader = DataLoader(dataset=test_set, batch_size=64, drop_last=True)start_time = time.time()loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)
learning_rate = 1e-2
yk = Yk()
yk = yk.to(device)
opt = torch.optim.SGD(yk.model1.parameters(), learning_rate)total_train_step = 0
total_test_step = 0
epoch = 10
writer = SummaryWriter("../logs")
for i in range(epoch):print("————第{}次训练开始————".format(i))yk.train()for data in train_loader:images, targets = dataimages = images.to(device)targets = targets.to(device)output = yk(images)loss = loss_fn(output, targets)opt.zero_grad()loss.backward()opt.step()total_train_step += 1if total_train_step % 100 == 0:end_time = time.time()print(end_time - start_time)print("训练次数:{}, Loss:{}".format(total_train_step, loss))writer.add_scalar("train_loss", loss.item(), total_train_step)# 测试步骤开始yk.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_loader:images, targets = dataimages = images.to(device)targets = targets.to(device)output = yk(images)loss = loss_fn(output, targets)total_test_loss = total_test_loss + lossaccuracy = (output.argmax(1) == targets).sum()total_accuracy = total_accuracy + accuracyprint("整体测试集上的loss:{}".format(total_test_step))print("整体测试集上的正确率{}".format(total_accuracy / len(test_set)))total_test_step += 1writer.add_scalar("test_loss", total_test_loss.item(), total_test_step)writer.add_scalar("test_accuracy", total_accuracy.item(), total_test_step)torch.save(yk, "yk_{}".format(i))print("模型已保存")writer.close()

运行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/3887.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

RabbitMQ中的交换机类型

交换机类型 可以看到,在订阅模型中,多了一个exchange角色,而且过程略有变化: Publisher:生产者,不再发送消息到队列中,而是发给交换机 Exchange:交换机,一方面&#xff…

欧科云链:为什么减半对比特币生态的影响正在逐步“减弱”?

出品|OKG Research 作者|Jason Jiang 欧科云链OKLink数据显示,比特币于区块高度840000(北京时间2024年4月20日8:09)成功完成第四次减半,比特币挖矿奖励正式由6.25BTC减少至3.125BTC。此次减半之后&#x…

Spring MVC系列之九大核心组件

概述 Spring MVC是面试必问知识点其一,Spring MVC知识体系庞杂,有以下九大核心组件: HandlerMappingHandlerAdapterHandlerExceptionResolverViewResolverRequestToViewNameTranslatorLocaleResolverThemeResolverMultipartResolverFlashMa…

中电金信:深度解析|数字化营销运营体系搭建

如何更好更快地梳理好体系搭建思路,稳步实现落地?下文将为大家明确搭建的推进步骤、执行要点,帮助商业银行理顺数字化营销运营体系的“点”“线”“面”~ 与所有转型的曲折、阵痛等特征一样,商业银行构建数字化营销运营体系过程中…

URL路由基础与Django处理请求的过程分析

1. URL路由基础 对于高质量的Web应用来讲,使用简洁、优雅的URL设计模式非常有必要。Django框架允许设计人员自由地设计URL模式,而不用受到框架本身的约束。对于URL路由来讲,其主要实现了Web服务的入口。用户通过浏览器发送过来的任何请求&am…

PyQt5中的QTablewidget

环境 PyQt5 VSCode Qt Designer生成界面 在VSCode的资源管理器中,右键选择 PYQT:New Form,打开Qt Designer 选择新建Dialog without Buttons,点击 创建 在左侧的Item Widgets中将 Table Widget拖入Dialog窗体中。 得到界面 将文件保存…

CH4INRULZ-v1靶机练习实践报告

CH4INRULZ-v1靶机练习实践报告 1 安装靶机 靶机是.ova文件,需要用VirtualBox打开,但我习惯于使用VMWare,因此修改靶机文件,使其适用于VMWare打开。 解压ova文件,得到.ovf文件和.vmdk文件。直接用VMWare打开.ovf文件即可。 2 夺…

Oceanbase体验之(一)运维管理工具OCP部署(社区版4.2.2)

资源规划建议 ocp主机1台 内存:64G CPU1:2C及以上 硬盘大于500G observer服务器3台 内存32G CPU:4C以上 硬盘大于1T 建议存储硬盘与操作系统硬盘隔开实现IO隔离 一、OBD、OCP安装包准备 [rootobserver /]# chown -R admin:admin /software/ [rootobserver /]# …

四:物联网ARM开发

一:ARM体系结构概述 1:控制外设led灯还有一些按键这些就要用到gpio,采集传感器的数据需要adc进行转化数据格式,特殊的外设和传感器是通过特殊的协议接口去进行连接的比如一些轴传感器和主控器的连接是通过spi,IIC 控制…

更新!!!Unity移动端游戏性能优化简谱

UWA官方出品,结合多年优化经验撰写了《Unity移动端游戏性能优化简谱》,文章从Unity移动端游戏优化的一些基础讨论出发,例举和分析了近几年基于Unity开发的移动端游戏项目中最为常见的部分性能问题,并展示了如何使用UWA的性能检测工…

(MSFT.O)微软2024财年Q3营收619亿美元

在科技的浩渺宇宙中,一颗璀璨星辰再度闪耀其光芒——(MSFT.O)微软公司于2024财政年度第三季展现出惊人的财务表现,实现总营业收入达到令人咋舌的6190亿美元。这一辉煌成就不仅突显了微软作为全球技术领导者之一的地位,更引发了业界内外对这家…

AIX7环境上一次艰难的Oracle打补丁经历

系统环境 AIX :7200-05-03-2148 Oracle:11.2.0.4 PSU: 11.2.0.4.201020(31718723) perl:5.28 问题一:AUTO patch #/u01/app/11.2.0/grid/OPatch/opatch auto /tmp/31718723 错误信息如下:匹配mos 2516761.1…

Kafka 3.x.x 入门到精通(05)——对标尚硅谷Kafka教程

Kafka 3.x.x 入门到精通(05)——对标尚硅谷Kafka教程 2. Kafka基础2.1 集群部署2.2 集群启动2.3 创建主题2.4 生产消息2.5 存储消息2.6 消费消息2.6.1 消费消息的基本步骤2.6.2 消费消息的基本代码2.6.3 消费消息的基本原理2.6.3.1消费者组2.6.3.1.1 消费…

美国洛杉矶站群服务器如何提高网站排名?

美国洛杉矶站群服务器怎么样?美国洛杉矶站群服务器如何提高网站排名?Rak部落小编为您整理发布美国洛杉矶站群服务器如何提高网站排名? 美国洛杉矶站群服务器可以通过以下几种方式帮助提高网站排名: - **提升网站性能**:美国站群服务器通常配备高速CPU…

LLM学习笔记-5

目录 1.多层神经网络的实现2. 训练轮次示例3. 保存并加载模型4. 使用GPU加速训练5. 使用上面所教,进行一次训练 摘要:今天想整理一下Pytorch常用操作,以便以后进行预习(不是) 1.多层神经网络的实现 这是常用的操作&a…

Elcomsoft iOS Forensics Toolkit: iPhone/iPad/iPod 设备取证工具包

天津鸿萌科贸发展有限公司是 ElcomSoft 系列取证软件的授权代理商。 Elcomsoft iOS Forensics Toolkit 软件工具包适用于取证工作,对 iPhone、iPad 和 iPod Touch 设备执行完整文件系统和逻辑数据采集。对设备文件系统制作镜像,提取设备机密&#xff08…

阿斯达年代记三强争霸服务器没反应 安装中发生错误的解决方法

阿斯达年代记三强争霸服务器没反应 安装中发生错误的解决方法 最近刚上线的由影视剧改编的游戏《阿斯达年代记三强争霸》可谓是在游戏圈内引起了轩然大波,这是一款由网石集团与龙工作室联合开发的MMORPG游戏,游戏背景设定在一个名为阿斯大陆的区域&…

vue 实现项目进度甘特图

项目需求: 实现以1天、7天、30天为周期(周期根据筛选条件选择),展示每个项目不同里程碑任务进度。 项目在Vue-Gantt-chart: 使用Vue做数据控制的Gantt图表基础上进行了改造。 有需要的小伙伴也可以直接引入插件,自己…

用Scrapy编写第一个入门项目(基础四件套:spider,pipeline,setting,items)

简介:scrapy是一个用于爬取网页并提取数据的应用框架,也可用于提取API数据 写在前面:只想看scrapy的童鞋子请跳过5-7直接step8) step5,6是xpath和css入门,用于提取数据; step7是文件储存方式&…

国产麒麟系统下打包electron+vue项目(AppImage、deb)

需要用到的一些依赖包、安装包以及更详细的打包方法word以及麒麟官网给出的文档都已放网盘,链接在文章最后!!!!!!!!!!!!&a…