PyTorch之完整的神经网络模型训练

简单的示例:

在PyTorch中,可以使用nn.Module类来定义神经网络模型。以下是一个示例的神经网络模型定义的代码:

import torch
import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()# 定义神经网络的层和参数self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.relu = nn.ReLU()self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)self.fc1 = nn.Linear(32 * 14 * 14, 128)self.fc2 = nn.Linear(128, 10)self.softmax = nn.Softmax(dim=1)def forward(self, x):x = self.conv1(x)x = self.relu(x)x = self.maxpool(x)x = x.view(x.size(0), -1)x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.softmax(x)return x

在上面的示例中,定义了一个名为MyModel的神经网络模型,继承自nn.Module类。在__init__方法中,我们定义了模型的层和参数。具体来说:

  • 代码定义了一个卷积层,输入通道数为1,输出通道数为32,卷积核大小为3x3,步长为1,填充为1。
  • 定义了一个ReLU激活函数,用于在卷积层之后引入非线性性质。
  • 定义了一个最大池化层,池化核大小为2x2,步长为2。
  • 定义了一个全连接层,输入大小为32x14x14(经过卷积和池化后的特征图大小),输出大小为128。
  • 定义了另一个全连接层,输入大小为128,输出大小为10。
  • 定义了一个softmax函数,用于将模型的输出转换为概率分布。

forward方法中,定义了模型的前向传播过程。具体来说:

  • x = self.conv1(x): 将输入张量传递给卷积层进行卷积操作。
  • x = self.relu(x): 将卷积层的输出通过ReLU激活函数进行非线性变换。
  • x = self.maxpool(x): 将ReLU激活后的特征图进行最大池化操作。
  • x = x.view(x.size(0), -1): 将池化后的特征图展平为一维,以适应全连接层的输入要求。
  • x = self.fc1(x): 将展平后的特征向量传递给第一个全连接层。
  • x = self.relu(x): 将第一个全连接层的输出通过ReLU激活函数进行非线性变换。
  • x = self.fc2(x): 将第一个全连接层的输出传递给第二个全连接层。
  • x = self.softmax(x): 将第二个全连接层的输出通过softmax函数进行归一化,得到每个类别的概率分布。

这个示例展示了一个简单的卷积神经网络模型,适用于处理单通道的图像数据,并输出10个类别的分类结果。可以根据自己的需求和数据特点来定义和修改神经网络模型。

接下来将用于实际的数据集进行训练:

以下是基于CIFAR10数据集的神经网络训练模型:

import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from nn_mode import *#准备数据集
train_data=torchvision.datasets.CIFAR10(root='../chap4_Dataset_transforms/dataset',train=True,transform=torchvision.transforms.ToTensor())
test_data=torchvision.datasets.CIFAR10(root='../chap4_Dataset_transforms/dataset',train=False,transform=torchvision.transforms.ToTensor())
#输出数据集的长度
train_data_size=len(train_data)
test_data_size=len(test_data)
print(train_data_size)
print(test_data_size)
#加载数据集
train_loader=DataLoader(dataset=train_data,batch_size=64)
test_loader=DataLoader(dataset=test_data,batch_size=64)
#创建神经网络
sjnet=Sjnet()#损失函数
loss_fn=nn.CrossEntropyLoss()
#优化器
learn_lr=0.01#便于修改
YHQ=torch.optim.SGD(sjnet.parameters(),lr=learn_lr)#设置训练网络的参数
train_step=0#训练次数
test_step=0#测试次数
epoch=10#训练轮数writer=SummaryWriter('wanzheng_logs')for i in range(epoch):print("第{}轮训练".format(i+1))#开始训练for data in train_loader:imgs,targets=dataoutputs=sjnet(imgs)loss=loss_fn(outputs,targets)#优化器YHQ.zero_grad()  # 将神经网络的梯度置零,以准备进行反向传播loss.backward()  # 执行反向传播,计算神经网络中各个参数的梯度YHQ.step()  # 调用优化器的step()方法,根据计算得到的梯度更新神经网络的参数,完成一次参数更新train_step =train_step+1if train_step%100==0:print('训练次数为:{},loss为:{}'.format(train_step,loss))writer.add_scalar('train_loss',loss,train_step)#开始测试total_loss=0with torch.no_grad():#上下文管理器,用于指示在接下来的代码块中不计算梯度。for data in test_loader:imgs,targets=dataoutputs = sjnet(imgs)loss = loss_fn(outputs, targets)#使用损失函数 loss_fn 计算预测输出与目标之间的损失。total_loss=total_loss+loss#将当前样本的损失加到总损失上,用于累积所有样本的损失。print('整体测试集上的loss:{}'.format(total_loss))writer.add_scalar('test_loss', total_loss, test_step)test_step = test_step+1torch.save(sjnet,'sjnet_{}.pth'.format(i))print("模型已保存!")writer.close()

 其神经网络训练以及测试时的损失值使用TensorBoard进行展示,如图所示:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/732774.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

运维打工人,兼职跑外卖的第二个周末

北京,晴,西南风1级。 前序 今天天气还行,赶紧起来,把衣服都洗洗,准备准备,去田老师吃饭早饭了。 一个甜饼、一个茶叶蛋、3元自助粥花费7.5。5个5挺吉利的。 跑外卖的意义 两个字减肥,记录刚入…

基于最小二乘递推算法的系统参数辨识matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 基于最小二乘递推算法的系统参数辨识。对系统的参数a1,b1,a2,b2分别进行估计,计算估计误差以及估计收敛曲线&#…

如何在Windows中对硬盘进行分区?这里有详细步骤

本文介绍如何在Windows11、10、8、7、Vista和XP中对硬盘进行分区 如果这个过程听起来比你想象的要复杂一点,不要担心,因为事实并非如此。在Windows中对硬盘进行分区一点也不难,通常只需要几分钟。以下是操作方法。 注意:这些说明适用于Windows 11、Windows 10、Windows 8…

腾讯云轻量应用服务器流量用完了怎么办?

腾讯云轻量服务器流量用完了怎么办?超额流量另外支付流量费,流量价格为0.8元/GB,会自动扣你的腾讯云余额,如果你的腾讯云账号余额不足,那么你的轻量应用服务器会面临停机,停机后外网无法访问,继…

js【详解】Promise

为什么需要使用 Promise ? 传统回调函数的代码层层嵌套,形成回调地狱,难以阅读和维护,为了解决回调地狱的问题,诞生了 Promise 什么是 Promise ? Promise 是一种异步编程的解决方案,本身是一个构…

自然语言处理之语言模型(LM)介绍

自然语言处理(Natural Language Processing,NLP)是人工智能(Artificial Intelligence,AI)的一个重要分支,它旨在使计算机能够理解、解释和生成人类语言。在自然语言处理中,语言模型&…

阿珊详解Vue Router的守卫机制

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

【漏洞复现】Salia PLCC cPH2 远程命令执行漏洞(CVE-2023-46359)

0x01 漏洞概述 Salia PLCC cPH2 v1.87.0 及更早版本中存在一个操作系统命令注入漏洞,该漏洞可能允许未经身份验证的远程攻击者通过传递给连接检查功能的特制参数在系统上执行任意命令。 0x02 测绘语句 fofa:"Salia PLCC" 0x03 漏洞复现 ​…

video视频播放

1.列表页面 <template><div><ul><li class"item" v-for"(item,index) in list" :key"index" click"turnPlay(item.videoUrl)"><img :src"item.img" alt""><div class"btn…

套接字编程 --- 一

目录 1. 预备知识 1.1. 端口号 1.2. 认识TCP协议 1.3. 认识UDP协议 1.4. 网络字节序 2. socket 2.1. socket 常见系统调用 2.1.1. socket 系统调用 2.1.2. bind 系统调用 2.1.3. recvfrom 系统调用 2.1.4. sendto系统调用 2.3. 其他相关接口 2.3.1. bzero 2.3.2…

力扣:17. 电话号码的字母组合

力扣&#xff1a;17. 电话号码的字母组合 描述 给定一个仅包含数字 2-9 的字符串&#xff0c;返回所有它能表示的字母组合。答案可以按 任意顺序 返回。 给出数字到字母的映射如下&#xff08;与电话按键相同&#xff09;。注意 1 不对应任何字母。 示例 1&#xff1a; 输…

Linux——文件重定向

目录 前言 一、重定向 二、重定向的运用 三、dup2 四、命令行中的重定向 五、为什么要有标准错误 前言 在之前我们学习了文件标识符&#xff0c;直到close可以使用文件标识符进行关闭&#xff0c;但是当我们关闭1号&#xff08;stdout&#xff09;时&#xff0c;无法往显…

00在linux环境下搭建stm32开发环境

文章目录 前言一、环境搭建1.arm-none-eabi-gcc2.openocd 三、创建stm32标准库工程1.创建工程目录2.修改stm32_flash.ld文件3.写makefile文件4.修改core_cm3.c5.写main函数并下载到板子上 最后 前言 我在那天终于说服自己将系统换成了linux系统了&#xff0c;当换成了linux系统…

UE5.1_使用技巧(常更)

UE5.1_使用技巧&#xff08;常更&#xff09; 1. 清除所有断点 运行时忘记蓝图中的断点可能会出现运行错误的可能&#xff0c;务必运行是排除一切断点&#xff0c;逐个排查也是办法&#xff0c;但是在事件函数多的情况下会很复杂且慢节奏&#xff0c;学会一次性清除所有很有必…

JavaWeb--Mybatis

一&#xff1a;Mybatis概述 1.Mybatis概念 MyBatis 是一款优秀的 持久层框架 &#xff0c;用于简化 JDBC 开发&#xff1b; MyBatis 本是 Apache 的一个开源项目 iBatis, 2010 年这个项目由 apache software foundation 迁移到了 google code&#xff0c;并且改名为 MyB…

OpenTenBase 开发环境搭建及Debug设置

最近有个 OpenTenBase开源核心贡献挑战赛 领导建议大家都去试试&#xff0c;我也去凑了下热闹&#xff0c;发现能力有限一时半会是搞不明白了&#xff0c;最多也就是能搞搞文档翻译&#xff0c;或者写点操作手册啥的。 不过不管怎么样&#xff0c;先把开发环境搭上&#xff0c;…

R语言的数据类型与数据结构:向量、列表、矩阵、数据框及操作方法

R语言的数据类型与数据结构&#xff1a;向量、列表、矩阵、数据框及操作方法 介绍向量列表矩阵数据框 介绍 R语言拥有丰富的数据类型和数据结构&#xff0c;以满足各类数据处理和分析的需求。本文将分享R语言中的数据类型&#xff0c;包括向量、列表、矩阵、数据框等&#xff…

vue组件之间通信方式汇总

方式1&#xff1a;props和$emit props和$emit仅仅限制在父子组件中使用 1.props&#xff1a;父组件向子组件传递数据 1.1 代码展示 <template><div><!-- 这是父组件 --><div>父组件中的基本数据类型age的值是:{{this.age}}</div><div>…

giffgaff怎么充值?giffgaff怎么续费?

-性价比高&#xff1a;0月租&#xff0c;免费接收短信&#xff0c;充值一次&#xff0c;接码可以用20年以上&#xff08;仅需半年保号一次&#xff09;&#xff0c;可能是国内性价比最高的接码实体卡&#xff01;-安全&#xff1a;实体卡无须担心因号码被风控&#xff0c;还可以…

面试经典150题【61-70】

文章目录 面试经典150题【61-70】61.旋转链表86.分隔链表104. 二叉树的最大深度100.相同的树226.翻转二叉树101.对称二叉树105.从前序与中序遍历序列构造二叉树106.从后序和中序遍历序列构造二叉树117.填充每个节点的下一个右侧节点指针II114.二叉树展开为链表 面试经典150题【…