10-pytorch-完整模型训练

b站小土堆pytorch教程学习笔记

一、从零开始构建自己的神经网络

1.模型构建
#准备数据集
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriterfrom model import *
from torch.utils.data import DataLoadertrain_data=torchvision.datasets.CIFAR10('dataset',train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data=torchvision.datasets.CIFAR10('dataset',train=False,transform=torchvision.transforms.ToTensor(),download=True)
#查看训练数据集和测试集大小
train_data_size=len(train_data)
test_data_size=len(test_data)
print('训练数据集长度为:{}'.format(train_data_size))#训练数据集长度为:50000
print('测试数据集长度为:{}'.format(test_data_size))#测试数据集长度为:10000#利用datalo加载数据集
train_dataloader=DataLoader(train_data,batch_size=64)
test_dataloader=DataLoader(test_data,batch_size=64)#搭建神经网络,在model文件中搭建网络,在此文件中引用
han=Han()#损失函数
loss_fn=nn.CrossEntropyLoss()#优化器
# learning_rate=0.01
learning_rate=1e-2
optimizer=torch.optim.SGD(han.parameters(),lr=learning_rate)#设置训练网络的相关参数
total_train_step = 0#记录训练的次数
total_test_step = 0#记录测试的次数
epoch=10#训练轮数#添加tensorboard
writer=SummaryWriter('logs/train')for i in range(10):print('-------第{}轮训练开始-------'.format(i+1))for data in train_dataloader:imgs,target=dataoutput=han(imgs)loss=loss_fn(output,target)#优化器优化模型optimizer.zero_grad()#梯度清零loss.backward()#反向传播计算梯度optimizer.step()#参数优化total_train_step=total_train_step+1if total_train_step % 100==0:#逢100打印print('训练次数:{},loss:{}'.format(total_train_step,loss.item()))#loss.item()取出tensor类型的数字writer.add_scalar('train_loss',loss.item(),total_train_step)#每训练完一轮将在测试集上跑一遍,评估其训练效果total_test_loss=0with torch.no_grad():for data in test_dataloader:imgs,target=dataoutput=han(imgs)loss=loss_fn(output,target)total_test_loss=total_test_loss+loss.item()print('所有测试集上的损失:{}'.format(total_test_loss))writer.add_scalar('test_loss',total_test_loss,total_test_step)total_test_step+=1#保存每一轮模型torch.save(han,'han_{}.pth'.format(i))print('模型已保存')
writer.close()
import torch
from torch import nnclass Han(nn.Module):def __init__(self):super(Han, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return xif __name__ == '__main__':han=Han()input=torch.ones(64,3,32,32)output=han(input)print(output.shape)#torch.Size([64, 10])10表示十个类别输出概率

结果如下:
在这里插入图片描述

2.使用argmax计算整体正确率
#每训练完一轮将在测试集上跑一遍,评估其训练效果total_test_loss=0total_acc=0with torch.no_grad():for data in test_dataloader:imgs,target=dataoutput=han(imgs)loss=loss_fn(output,target)total_test_loss=total_test_loss+loss.item()acc=(output.argmax(1)==target).sum()#(1)横着看total_acc+=accprint('所有测试集上的损失:{}'.format(total_test_loss))print('整体测试集上的正确率:{}'.format(total_acc/test_data_size))writer.add_scalar('test_loss',total_test_loss,total_test_step)writer.add_scalar('test_acc', total_acc/test_data_size, total_test_step)total_test_step+=1

整体测试集上的正确率:0.27480000257492065

3.当训练或测试时存在dropout层或batch normal层,则需要在训练训练和测试前加入:
#训练前
han.train()
#测试前
han.eval()

二、使用GPU

网络模型、数据(输入、标注)、损失函数调用cuda()

1.方式1
#模型
if torch.cuda.is_available():han=han.cuda()
#损失函数
loss_fn=nn.CrossEntropyLoss()
loss_fn=loss_fn.cuda()
imgs,target=data
imgs=imgs.cuda()
target=target.cuda()
2.方式2
#定义训练设备
device=torch.device('cuda')
han=han.to(device)
imgs = imgs.to(device)
target = target.to(device)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/701809.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何运用Mybatis Genertor

MyBatis Generator是一个MyBatis的代码生成器,它可以帮助我们快速生成Mapper接口以及对应的XML文件和模型类。在Java开发中,能大大提升开发效率。本文将介绍如何在IntelliJ IDEA中使用MyBatis Generator。 1. 添加MyBatis Generator依赖 我们首先需要在…

计网 - 深入理解HTTPS:加密技术的背后

文章目录 Pre发展历史Http VS HttpsHTTPS 解决了 HTTP 的哪些问题HTTPS是如何解决上述三个风险的混合加密摘要算法 数字签名数字证书 Pre PKI - 数字签名与数字证书 PKI - 借助Nginx 实现Https 服务端单向认证、服务端客户端双向认证 发展历史 HTTP(超文本传输协…

代码随想录算法训练营第二十五天补|216.组合总和III ● 17.电话号码的字母组合

组合问题:集合内元素的组合,不同集合内元素的组合 回溯模板伪代码 void backtracking(参数) {if (终止条件) {存放结果;return;}for (选择:本层集合中元素(树中节点孩子的数量就是集合的大小)) {处理节点;backtrackin…

人工智能绘画的时代下到底是谁在主导,是人类的想象力,还是AI的创造力?

#ai作画 目录 一.AI绘画的概念 1. 数据集准备: 2. 模型训练: 3. 生成绘画: 二.AI绘画的应用领域 三.AI绘画的发展 四.AI绘画背后的技术剖析 1.AI绘画的底层原理 2.主流模型的发展趋势 2.1VAE — 伊始之门 2.2GAN 2.2.1GAN相较于…

深度学习系列60: 大模型文本理解和生成概述

参考网络课程:https://www.bilibili.com/video/BV1UG411p7zv/?p98&spm_id_frompageDriver&vd_source3eeaf9c562508b013fa950114d4b0990 1. 概述 包含理解和分类两大类问题,对应的就是BERT和GPT两大类模型;而交叉领域则对应T5 2.…

【C++精简版回顾】9.static

1.static修饰成员类型 1.类外初始化&#xff0c;初始化时不需要static修饰(不能修饰)&#xff0c;要有类名限定 2.静态成员是属于类的&#xff0c;全对象公有 1.class class MM { public:MM(string name) {size;a size;this->name name;}void print() {cout << &quo…

瑞_23种设计模式_桥接模式

文章目录 1 桥接模式&#xff08;Bridge Pattern&#xff09;1.1 介绍1.2 概述1.3 桥接模式的结构 2 案例一2.1 需求2.2 代码实现 3 案例二2.1 需求2.1 代码实现 &#x1f64a; 前言&#xff1a;本文章为瑞_系列专栏之《23种设计模式》的桥接模式篇。本文中的部分图和概念等资料…

【MySQL】连接查询和自连接的学习和总结

&#x1f308;个人主页: Aileen_0v0 &#x1f525;热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法 ​&#x1f4ab;个人格言:“没有罗马,那就自己创造罗马~” #mermaid-svg-x4sPmqTXA4yupW1n {font-family:"trebuchet ms",verdana,arial,sans-serif;font-siz…

限流算法

下面对常见的限流算法进行讨论。目前&#xff0c;常用的限流算法主要有三种&#xff1a;计数器法、滑动窗口算法、漏桶算法和令牌桶算法。下面分别介绍其原理。 1. 计数器法 计数器法是通过计数对到来的请求进行选择性处理。如系统限制一秒内最多有X个请求&#xff0c;则在该…

《艾尔登法环 黄金树幽影》是什么?Mac电脑怎么玩《艾尔登法环》艾尔登法环下载

全体起立&#xff0c;《艾尔登法环 》最新DLC《黄金树幽影》将在6月21日发布&#xff0c;steam售价198元&#xff0c;现在就可以预订了。宫崎英高在接受FAMI通的采访时表示&#xff0c;新DLC的体量远超《黑暗之魂》和《血源诅咒》资料片。好家伙&#xff0c;别人是把DLC续作&am…

IO进程线程:通信

1.定义互斥锁 #include<myhead.h>int num520;//临界资源//1.创建一个互斥锁变量 pthread_mutex_t mutex;//定义任务&#xff11;函数 void *task1(void *arg) {printf("11111111111111\n");//3.获取锁资源pthread_mutex_lock(&mutex);num1314;sleep(3);pr…

EasyRecovery 16数据恢复软件功能介绍及2024 年最新easyrecover激活密钥?

EasyRecovery Photo16 for windows数据恢复软件免费版下载是一款由Kroll Ontrack公司开发的数据恢复软件&#xff0c;其主要功能是恢复已经删除或损坏的图片文件。该软件可用于恢复各种类型的图片文件&#xff0c;包括JPEG、GIF、BMP、PNG等&#xff0c;同时也支持恢复照片文件…

python-pyecharts画饼图

pyecharts饼图 from pyecharts import options as opts from pyecharts.charts import Pie# 构造数据 data [("A", 10),("B", 20),("C", 30),("D", 40),("E", 50) ]# 实例化饼图 pie Pie()# 添加数据 pie.add("&qu…

【Java多线程】对线程池的理解并模拟实现线程池

目录 1、池 1.1、线程池 2、ThreadPoolExecutor 线程池类 3、Executors 工厂类 4、模拟实现线程池 1、池 “池”这个概念见到非常多&#xff0c;例如常量池、数据库连接池、线程池、进程池、内存池。 所谓“池”的概念就是&#xff1a;&#xff08;提高效率&#xff09; 1…

C语言第三十弹---自定义类型:结构体(上)

✨个人主页&#xff1a; 熬夜学编程的小林 &#x1f497;系列专栏&#xff1a; 【C语言详解】 【数据结构详解】 结构体 1、结构体类型的声明 1.1、结构体回顾 1.1.1、结构的声明 1.1.2、结构体变量的创建和初始化 1.2、结构的特殊声明 1.3、结构的自引用 2、结构体内存…

K8S—集群调度

目录 前言 一 List-Watch 1.1 list-watch概述 1.2 list-watch工作机制 二 集群调度 2.1 调度过程 2.2 Predicate 和 Priorities 的常见算法和优先级选项 2.3 调度方式 三 亲和性 3.1 节点亲和性 3.2 Pod 亲和性 3.3 键值运算关系 3.4 Pod亲和性与反亲和性 3.5 示例…

音视频数字化(数字与模拟-电影)

针对电视屏幕,电影被称为“大荧幕”,也是娱乐行业的顶尖产业。作为一项综合艺术,从被发明至今,近200年的发展史中,无人可以替代,并始终走在时代的前列。 电影回放的原理就是“视觉残留”,也就是快速移过眼前的画面,会在人的大脑中残留短暂的时间,随着画面不断地移过,…

暑期宅家?计算机专业必看的8部电影!一定要安利给你们!

代码编程看上去枯燥乏味&#xff0c;但也是艺术的&#xff0c;感性的&#xff0c;计算机编程的许多概念被应用于电影中&#xff0c;其中有些非常之酷炫&#xff0c;它们甚至能帮助开发人员理解一些编程概念。 所以今天学姐来给大家推荐几部心中top级的编程人必看电影&#xff0…

nginx(二)

nginx的验证模块 输入用户名和密码 第一步先下载httpd 这个安装包 第二步编辑子配置文件 然后去网页访问192.168.68.3/admin/ 连接之后&#xff0c;会出现404&#xff0c;404出现是因为没给网页写页面 如果要写页面&#xff0c;则在/opt/html&#xff0c;建立一个admin&#x…

max_element和min_element使用

头文件 #include<alorithm> 作用 用于返回数组或容器中最值元素(最小值、最大值)&#xff0c;值和下标。 使用举例 #include<iostream> #include<vector> #include<algorithm> using namespace std; int main() {/*数组初始化*/vector<int>…