pytorch -- CIFAR10 完整的模型训练套路

  1. 网络结构
    在这里插入图片描述
  2. 代码
# CIFAR 10
'''
完整的模型训练套路:'''
import torch.optim
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom model import *# 1. 准备数据集
train_data = torchvision.datasets.CIFAR10('data',train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10('data',train=False,transform=torchvision.transforms.ToTensor(),download=True)
# 数据集大小
train_data_size = len(train_data)
test_data_size = len(test_data)
print('训练数据集的长度为{}'.format(train_data_size))
print('测试数据集的长度为{}'.format(test_data_size))# 2 利用DataLoader加载数据集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)# 3 搭建神经网络
# 4 创建网络模型
tudui = Tudui()# 5 损失函数
loss_fn = nn.CrossEntropyLoss()# 6 优化器 1e-2=1x10^(-2)
learning_rate = 0.01
optimizer = torch.optim.SGD(tudui.parameters(),lr=learning_rate)# 7 设置训练网络的一些参数
total_train_step = 0 # 记录训练次数
total_test_step = 0 # 记录测试次数
epoch = 10 #训练轮数
# 添加tensorboard
writer = SummaryWriter('logs_model')for i in range(epoch):print('-----------第{}轮训练开始-----------'.format(i+1))# 训练开始# 训练步骤开始 dropout batchNorm仅对某些层次有作用tudui.train()for data in train_dataloader:imgs, targets = dataoutput = tudui(imgs) #训练模型的预测输出loss = loss_fn(output,targets)# 优化器优化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step += 1if total_train_step % 100 == 0:print('训练次数是{}时,loss是{}'.format(total_train_step,loss.item()))# 加了item() tensor变成了数字writer.add_scalar('train_loss',loss.item(),total_train_step)# 训练完一轮,看是否训练好,有没有达到想要的需求,测试数据集中跑一篇看准确率或者损失# 测试步骤开始tudui.eval()total_test_loss = 0total_accuracy = 0# 测试不需要对梯度进行调整with torch.no_grad():for data in test_dataloader:imgs,targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs,targets)total_test_loss += loss.item()# accuracy 正确预测的样本数量accuracy = (outputs.argmax(1) == targets).sum()total_accuracy += accuracyprint('整体测试集上的loss是{}'.format(total_test_loss))print('整体测试集上的正确率是{}'.format(total_accuracy/test_data_size))writer.add_scalar('test_loss',total_test_loss,total_test_step)writer.add_scalar('test_accuracy', total_accuracy, total_test_step)total_test_step+=1torch.save(tudui,'tudui_{}.pth'.format(i))print('模型已保存')writer.close()
# model.py
import torch
from torch import nn# 3 搭建神经网络
class Tudui(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Conv2d(3,32,5,1,2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(1024,64),nn.Linear(64, 10))def forward(self,x):x = self.model(x)return xif __name__ == '__main__':tudui = Tudui()# 验证一下输入输出尺寸input = torch.ones((64,3,32,32))output = tudui(input)print(output.shape)

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/703727.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【数据库】mybatis使用总结

文章目录 1. 批量插入、检索<foreach> 1. 批量插入、检索<foreach> <insert id"insertSystemService" >insert into SYSTEMINFO_SERVICE (system_code, service_id, add_user, add_time)values<foreach collection"serviceList" ite…

SSM框架学习笔记07 | Spring MVC入门

文章目录 1. HTTP协议2. Spring MVC2.1. 三层架构2.2. MVC&#xff08;解决表现层的问题&#xff09;2.3. 核心组件 3. Thymeleaf3.1. 模板引擎3.2. Thymeleaf3.3. 常用语法 代码 1. HTTP协议 网址&#xff1a;https://www.ietf.org/ &#xff08;官网网址&#xff09; https:…

C++标准头文件汇总及功能说明

文章目录 algorithmbitsetcctypecerrnoclocalecmathcstdioctimedequeiostreamexceptionfstreamfunctionallimitslistmapiosiosfwdsetsstreamstackstdexceptstreambufcstringutilityvectorcwcharcwctype algorithm algorithm头文件是C的标准算法库&#xff0c;它主要用在容器上。…

dolphinscheduler单机版部署教程

文章目录 前言一、安装准备1. 安装条件2. 安装jdk3. 安装MySQL 二、安装dolphinscheduler1. 下载并解压dolphinscheduler2. 修改配置文件2.1 修改 dolphinscheduler_env.sh 文件2.2 修改 application.yaml 文件 3. 配置mysql数据源3.1 修改MySQL安全策略3.2 查看数据库3.3 创建…

算法训练 安慰奶牛

问题描述&#xff08;题目链接&#xff09; Farmer John变得非常懒&#xff0c;他不想再继续维护供奶牛之间供通行的道路。道路被用来连接N个牧场&#xff0c;牧场被连续地编号为1到N。每一个牧场都是一个奶牛的家。FJ计划除去P条道路中尽可能多的道路&#xff0c;但是还要保持…

使用Docker部署MinIO并结合内网穿透实现远程访问本地数据

文章目录 前言1. Docker 部署MinIO2. 本地访问MinIO3. Linux安装Cpolar4. 配置MinIO公网地址5. 远程访问MinIO管理界面6. 固定MinIO公网地址 前言 MinIO是一个开源的对象存储服务器&#xff0c;可以在各种环境中运行&#xff0c;例如本地、Docker容器、Kubernetes集群等。它兼…

Pytest教程:一种利用 Python Pytest Hook 机制的软件自动化测试网络数据抓包方法

随着计算机技术的发展&#xff0c;使得网络应用的数量不断增加&#xff0c;因此网络数据抓包成为了网络应用开发和测试中非常重要的一部分。目前&#xff0c;已有许多网络数据抓包工具可供使用&#xff0c;例如 Wireshark、Tcpdump、Fiddler 等&#xff0c;但这些工具需要手动配…

快速排序 quicksort

参考视频&#xff1a; 快速排序算法_哔哩哔哩_bilibili #include <stdio.h>void QuickSort(int *arr,int L,int R); int main() {int arr[3] {1000,2,3};QuickSort(arr,0,2);for(int i 0 ; i < 3 ; i){printf("%d ",arr[i]);}return 0; } void QuickSor…

进程的学习

进程基本概念: 1.进程: 程序&#xff1a;存放在外存中的一段数据组成的文件 进程&#xff1a;是一个程序动态执行的过程,包括进程的创建、进程的调度、进程的消亡 2.进程相关命令: 1.top 动态查看当前系统中的所有进程信息&#xff08;根据CPU占用率排序&#xf…

PHY6222系统级SOC蓝牙芯片低功耗高性能蓝牙MESH组网智能家居

简介 PHY6222是一款支持BLE 5.2功能和IEEE 802.15.4通信协议的系统级芯片&#xff08;SoC&#xff09;&#xff0c;集成了超低功耗的高性能多模射频收发机&#xff0c;搭载32-bit ARM?Cortex?-M0处理器&#xff0c;提供64K retention SRAM、可选128K-8M Flash、96KB ROM以及2…

基于单片机的多关节机械臂抓取系统

摘 要:在农业发展过程中,果实采摘是极度耗费人力的工作。为了减少农业生产过程中的人工成本,将人工智能应用于农业领域将是一种有效手段。基于单片机的控制设计出一款智能抓取系统,拥有六关节高自由度机械臂;爪子采用柔性材料,在加强爪子和果实贴合度的情况下减少对果实的…

微信小程序-人脸检测

微信小程序的人脸检测功能&#xff0c;配合蓝牙&#xff0c;配合ESP32 可以实现一些有趣的玩具 本文先只说微信小程序的人脸检测功能 1、人脸检测使用了摄像头&#xff0c;就必须在用户隐私权限里面声明。 修改用户隐私声明后&#xff0c;还需要等待审核&#xff0c;大概一天 …

十、线性代数二-线性相关

目录 1、线性相关的概念&#xff1a; 2、线性相关的代数表示&#xff1a; 3、线性相关的判断方法&#xff1a; 理解&#xff1a;线性相关指的是 向量组&#xff08;α1&#xff0c;α2&#xff0c;α3&#xff0c;...&#xff09;的 秩是 小于 k 的元数的&#xff0c;即齐次…

重磅福利!攻击面管理平台免费试用

活动时间&#xff1a;2024 年 2 月 26 日至 2024 年 6 月 1 日 活动内容&#xff1a;所有新注册的长亭云图极速版用户&#xff0c;即可享受 1 个月专业版试用&#xff0c;价值 2000 元&#xff01; 活动详情&#xff1a; ● 专业版试用期间&#xff0c;用户可享受以下权益&…

第二节:Vben Admin 登录逻辑梳理和对接后端准备

系列文章目录 上一节&#xff1a;第一节&#xff1a;Vben Admin介绍和初次运行 文章目录 系列文章目录前言项目路径的概述一、登录逻辑梳理loginApi接口查看Mock 二、后端程序对接准备关闭Mock 总结 前言 第一节&#xff0c;我们已经配置了前端环境&#xff0c;运行起来了我们…

文献速递:深度学习--深度学习方法用于帕金森病的脑电图诊断

文献速递&#xff1a;深度学习–深度学习方法用于帕金森病的脑电图诊断 01 文献速递介绍 人类大脑在出生时含有最多的神经细胞&#xff0c;也称为神经元。这些神经细胞无法像我们身体的其他细胞那样自我修复。随着年龄的增长&#xff0c;神经元逐渐死亡&#xff0c;因此变得…

el-form 表单文本标签label增加tooltip提示图标

需求&#xff1a;在el-form表单中&#xff0c;el-form-item的文本标签处增加提示语&#xff1b; 标签&#xff1a;el-form、el-form-item、el-tooltip&#xff1b; 实现&#xff1a; <el-form-item prop"basicScore"><span slot"label"><…

C 标准库 - <stdio.h> 详解

在 C 语言中&#xff0c;stdio.h 是一个非常重要的头文件&#xff0c;定义了一系列用于输入和输出的函数、变量和宏。本文将逐一介绍 stdio.h 中定义的函数&#xff0c;并提供每个函数的完整示例。 变量类型 在 stdio.h 中定义了三个变量类型&#xff1a; size_t&#xff1a…

nginx之状态页 日志分割 自定义图表 证书

5.1 网页的状态页 基于nginx 模块 ngx_http_stub_status_module 实现&#xff0c;在编译安装nginx的时候需要添加编译参数 --with-http_stub_status_module&#xff0c;否则配置完成之后监测会是提示语法错误注意: 状态页显示的是整个服务器的状态,而非虚拟主机的状态 server{…

【Git】Git命令的学习与总结

本文实践于 Learn Git Branching 这个有趣的 Git 学习网站。在该网站&#xff0c;可以使用 show command 命令展示所有可用命令。你也可以直接访问网站的sandbox&#xff0c;自由发挥。 一、本地篇 基础篇 git commit git commit将暂存区&#xff08;staging area&#xff…