神经网络模型实现(训练、测试)

目录

  • 一、神经网络骨架:
  • 二、卷积操作:
  • 三、卷积层:
  • 四、池化层:
  • 五、激活函数(以ReLU为例):
  • 六、模型搭建:
  • 七、损失函数、梯度下降:
  • 八、模型保存与加载:
  • 九、模型训练:
  • 十、模型测试:

一、神经网络骨架:

import torch
from torch import nn#神经网络
class CLH(nn.Module):def __init__(self):super().__init__()def forward(self, input):output=input+1return outputclh = CLH()
x = torch.tensor(1.0)
output = clh(x)
print(output)

二、卷积操作:

import torch
import torch.nn.functional as Finput = torch.tensor([[1,2,0,3,1],[0,1,2,3,1],[1,2,1,0,0],[5,2,3,1,1],[2,1,0,1,1]
])
#卷积核
kernel = torch.tensor([[1,2,1],[0,1,0],[2,1,0]
])
#变化为卷积规定输入格式的维度:这里二维(x,y)转四维(t,z,x,y)
input = torch.reshape(input,(1,1,5,5))
kernel = torch.reshape(kernel,(1,1,3,3))#对输入矩阵的上下左右进行分别padding扩充1列0,执行一次stride步长为1的卷积
output = F.conv2d(input, kernel, stride=1, padding=1)
print(output)

运行结果:
在这里插入图片描述

执行过程:
在这里插入图片描述

三、卷积层:

卷积在神经网络中用于提取输入数据的特征,通过与卷积核进行卷积操作来实现特征的提取和学习。

import torchvision
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as Ftest_data = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=False)
#将数据划分为batch,每个batch有64个样本
dataloader = DataLoader(test_data,batch_size=64)#神经网络
class CLH(nn.Module):def __init__(self):super(CLH,self).__init__()#神经网络中设置一个卷积层,in_channels表示输入通道数,out_channels表示输出通道数,并且卷积核尺寸为3×3的随机矩阵self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)def forward(self,input):#对输入数据执行一次二维卷积return self.conv1(input)clh = CLH()#data是一个batch
for data in dataloader:imgs,targets = dataoutput = clh(imgs)#torch.Size([64, 3, 32, 32])表示[batchs大小,每个batch的通道数,每个通道x轴像素数,每个通道y轴像素数]print(imgs.shape)#torch.Size([64, 6, 30, 30])表示[batchs大小,每个batch的通道数,每个通道x轴像素数,每个通道y轴像素数]#其中每个通道由32×32像素变为30×30像素,其余的像素点组合成该batch的其他通道print(output.shape)

在这里插入图片描述

四、池化层:

最大池化的作用是为了保留特征同时将数据量缩小。
例如:1080p图像经过最大池化层变为720p。

import torch
from torch import nn
from torch.nn import MaxPool2d
#输入像素变为tensor类型
input = torch.tensor([[1,2,0,3,1],[0,1,2,3,1],[1,2,1,0,0],[5,2,3,1,1],[2,1,0,1,1]
],dtype=torch.float32)
#变化为池化规定输入格式的维度:这里二维(x,y)转四维(t,z,x,y)
input = torch.reshape(input,(1,1,5,5))#神经网络
class CLH(nn.Module):def __init__(self):super(CLH,self).__init__()#神经网络中设置一个池化层,ceil_mode表示池化合覆盖输入数据不够时是否计算self.maxpool1 = MaxPool2d(kernel_size=3, ceil_mode=True)def forward(self,input):#对输入数据执行一次最大池化操作return self.maxpool1(input)#创建神经网络
clh = CLH()output = clh(input)
print(output)

运行结果:
在这里插入图片描述

执行过程:
1

五、激活函数(以ReLU为例):

import torch
from torch import nn
from torch.nn import MaxPool2d, ReLU#输入像素变为tensor类型
input = torch.tensor([[1,2,0,3,1],[0,1,2,3,1],[1,2,1,0,0],[5,2,3,1,1],[2,1,0,1,1]
],dtype=torch.float32)
#变化为池化规定输入格式的维度:这里二维(x,y)转四维(t,z,x,y)
input = torch.reshape(input,(1,1,5,5))#神经网络
class CLH(nn.Module):def __init__(self):super(CLH,self).__init__()#神经网络中设置一个激活函数self.relu = ReLU()def forward(self,input):#对输入数据执行一次最大池化操作return self.relu(input)#创建神经网络
clh = CLH()output = clh(input)
print(output)

运行结果:
在这里插入图片描述

六、模型搭建:

在这里插入图片描述
上面的卷积神经网络模型共有8个隐藏层(卷积层+池化层)和1个输出层。

from torch import nn
from torch.nn import Module
from torch.nn import Conv2d
from torch.nn import MaxPool2d, Flatten, Linear, Sequential
import torch# 神经网络
class CLH(nn.Module):def __init__(self):super(CLH, self).__init__()#搭建模型架构,模型会按顺序从上到下执行self.model1 = Sequential(#卷积层,提取特征Conv2d(3, 32, 5, padding=2),#最大池化层,减小数据维度,防止过拟合MaxPool2d(2),#卷积层Conv2d(32, 32, 5, padding=2),#最大池化层MaxPool2d(2),#卷积层Conv2d(32, 64, 5, padding=2),#最大池化层MaxPool2d(2),#展平层,将输入数据展平为一维Flatten(),#全连接(线性)层,输入1024长度的一维向量,学习输入数据的非线性关系,输出64的特征一维向量Linear(1024, 64),)def forward(self, x):x = self.model1(x)return xclh = CLH()
print(clh)
#构建一个全0的输入,batch大小64,三通道,宽高为32×32
input = torch.ones((64, 3, 32, 32))
output = clh(input)
print(output.shape)

输出结果:
在这里插入图片描述

七、损失函数、梯度下降:

from torch import nn
from torch.nn import Conv2d
from torch.nn import MaxPool2d,Flatten,Linear, Sequential
import torch
import torchvision
from torch.utils.data import DataLoader
#获取数据,CIFAR10是一个十分类数据集
dataset = torchvision.datasets.CIFAR10("./dataset",train= False, transform =torchvision.transforms.ToTensor(), download=False)
#切分数据为batch
dataloader = DataLoader(dataset, batch_size=1, drop_last=True)#神经网络模型搭建
class Tudui(nn.Module):def __init__(self):super(Tudui,self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32,64,5, padding=2),MaxPool2d(2),Flatten(),Linear(1024,64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xtudui = Tudui()#损失函数
loss = nn.CrossEntropyLoss()
#SGD(随机梯度下降)优化器,用于更新模型clh中的参数以最小化损失函数
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)#每个batch(data)执行如下操作
for data in dataloader:#特征和真实标签值imgs, targets = data#训练结果outputs = tudui(imgs)#计算损失result_loss = loss(outputs, targets)#每轮训练梯度清零,避免梯度累积optim.zero_grad()#反向传播计算损失函数关于模型各个参数的梯度(偏导数)result_loss.backward()#根据梯度更新模型参数,使用梯度下降算法进行参数更新optim.step()print(result_loss)

运行结果:
在这里插入图片描述

八、模型保存与加载:

保存:

import torchvision
import torch
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1--保存模型结构及模型参数
torch.save(vgg16,"vgg16_method1.pth")# 保存方式2--仅保存模型参数存为字典,不保存模型结构(官方推荐)
torch.save(vgg16.state_dict(),"vgg16_method2.pth")

读取:

import torch
import torchvision
# 保存方式1对应的加载模型结构 + 参数方式
model = torch.load("vgg16_method1.pth")
print(model)# 保存方式2对应的加载模型参数方式
model2 = torch.load("vgg16_method2.pth") #加载的是字典
print(model2)vgg16 = torchvision.models.vgg16(pretrained=False) #为方式2创建模型结构并加载参数的完整写法
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16)

九、模型训练:

在这里插入图片描述
上面的卷积神经网络模型共有8个隐藏层(卷积层+池化层)和1个输出层。

from torch.utils.tensorboard import SummaryWriterimport torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader# 定义训练的设备
device = torch.device("cuda")
# 准备训练集
train_data = torchvision.datasets.CIFAR10("./dataset", train=True, transform=torchvision.transforms.ToTensor(),download=True)
# 准备测试集
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor(),download=True)
# len()获取数据集长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))# 利用dataloader加载数据集,一个batch包含64个图片样本
train_dataloader = DataLoader(train_data, batch_size=64, drop_last=True)
test_dataloader = DataLoader(test_data, batch_size=64, drop_last=True)# 创建网络模型,这里一般将网络模型单独定义一个文件
class CLH(nn.Module):def __init__(self):super(CLH, self).__init__()self.model = nn.Sequential(#卷积层nn.Conv2d(3, 32, 5, 1, 2),#最大池化层nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),#展平为一维nn.Flatten(),  # 展平后的序列长度为 64*4*4=1024#全连接层nn.Linear(1024, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return xclh = CLH()
#设置为GPU训练
clh = clh.to(device)# 损失函数
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)# 优化器(梯度下降算法)对clh.parameters()中的参数进行更新
learning_rate = 1e-2
optimizer = torch.optim.SGD(clh.parameters(), lr=learning_rate)# 设置训练网络的一些参数
# 1.记录训练的次数
total_train_step = 0
# 2.记录测试的次数
total_test_step = 0
# 3.训练的轮数
epoch = 10# 添加tensorboard可视化结果
writer = SummaryWriter("../logs_train")#训练epoch轮
for i in range(epoch):print("-------------第 {} 轮训练开始------------".format(i + 1))# 训练步骤开始clh.train()#对于每个batch(data,包含64张图片)for data in train_dataloader:#获取batch中的样本和真实值标签imgs, targets = dataimgs = imgs.to(device)targets = targets.to(device)#该batch的训练结果output = clh(imgs)#计算该batch(data)上的损失loss = loss_fn(output, targets)# 优化器优化模型#梯度清零optimizer.zero_grad()#计算损失函数关于每个参数的梯度loss.backward()#梯度下降算法更新参数optimizer.step()total_train_step = total_train_step + 1print("训练次数:{}, Loss:{}".format(total_train_step, loss.item()))writer.add_scalar("train_loss", loss.item(), total_train_step)# 测试步骤开始(感觉更像验证,但是并没有更新一些学习率之类的参数,仅仅进行了测试)clh.eval()total_test_loss = 0total_accuracy = 0#设置禁止更新梯度,防止测试集更新模型参数with torch.no_grad():#对于测试集上的每个batch(data,包含64张图片)for data in test_dataloader:imgs, targets = dataimgs = imgs.to(device)targets = targets.to(device)#仅计算预测结果而不更新参数(torch.no_grad())outputs = clh(imgs)#计算该batch(data)上的损失loss = loss_fn(outputs, targets)#计算整个测试集上的损失total_test_loss = total_test_loss + loss.item()print("整个测试集上的Loss:{}".format(total_test_loss))writer.add_scalar("test_loss", total_test_loss, total_test_step)total_test_step = total_test_step + 1#每轮(epoch)训练完成后保存该轮的训练模型torch.save(clh, "clh_{}.pth".format(i))print("-------------第{}轮训练结束,模型已保存-------------".format(i + 1))writer.close()

部分执行结果:
在这里插入图片描述
在这里插入图片描述

十、模型测试:

import torchvision
from PIL import Image
import torch
#注意导入模型结构文件
from CLHmodule import *#获取测试样本
image_path = "./dog.png"
image = Image.open(image_path)
print(image)#将测试样本转换为模型规定的格式(和训练集样本尺寸要一样)
image = image.convert('RGB')
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),torchvision.transforms.ToTensor()])
image = transform(image)
print(image)
image = torch.reshape(image,(1,3,32,32))#加载模型
model = torch.load("clh_0.pth",map_location=torch.device("cuda"))
print(model)#执行预测
model.eval()
with torch.no_grad():image = image.to("cuda")output = model(image)
#预测结果,十分类结果为10个值的一维向量,表示各个分类上的可能性
print(output)
#输出可能性最大的结果
print(output.argmax(1))

输出结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/48380.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Oracle select for update 用法

SELECT FOR UPDATE 用法 1、SELECT…FOR UPDATE 语法 SELECT … FOR UPDATE [OF column_list][WAIT n|NOWAIT][SKIP LOCKED]; 其中:   OF 子句用于指定即将更新的列,即锁定行上的特定列。   WAIT 子句指定等待其他用户释放锁的秒数,防止…

基于RFID的课堂签到系统设计

1.简介 基于RFID的课堂签到系统设计是一种利用无线射频识别(RFID)技术实现课堂自动签到的系统。这种系统通过RFID标签(通常是学生携带的卡片或手环等)与安装在教室内的RFID读写器之间的无线电信号进行数据交换,从而实现…

移动设备安全革命:应对威胁与解决方案

移动设备已成为我们日常工作和家庭生活中不可或缺的工具,然而,对于它们安全性的关注和投资仍然远远不够。本文深入分析了移动设备安全的发展轨迹、目前面临的威胁态势,以及业界对于这些安全漏洞响应迟缓的深层原因。文中还探讨了人们在心理层…

Java跨平台的原理是什么?JDK,JRE,JVM三者的作用和区别?xxx.java和xxx.class有什么区别?看这一篇就够了

目录 1. Java跨平台相关问题 1.1 什么是跨平台(平台无关性)? 1.2 跨平台(平台无关性)的好处? 1.3 编译原理基础(Java程序编译过程) 1.4Java跨平台的是实现原理? 1.4.1 JVM(Java虚拟机) 1.4.2 Class文件 1.4.3 …

485开关量采集模块16路I/O输入输出ModbusRTU协议—DAM-3950A

品牌:阿尔泰科技 型号:DAM-3950A 概述: DAM-3950A为16路隔离数字量输入,6路C型10路A型信号继电器输出模块,RS485通讯接口,带有标准ModbusRTU协议。配备良好的人机交互界面,使用方便&#xff…

Linux 文件安装的mysql 启动

1、找到my.cnf 2、确定文件类容: 并确保这些重要的配置:basedir 、datadir、socket 文件或目录都存在 3、找到mysqld 位置 4、启动mysqld mysqld --defaults-file/etc/my.cnf --usermysql

使用Spring的 Environment 和 ConfigurableEnvironment 来在springboot应用启动过程中对属性进行修改。

1.修改yml文件的问题 我在上一篇文章中介绍了使用《java启动springboot项目前根据环境变量动态改编yaml文件的变量值》但是在FC启动后发现存在问题 springboot初始化时需要初始化数据库,这时就会存在数据库连接在yml文件修改之前,这就会导致链接数据库…

feign 报错 Connection reset executing POST

feign 连接异常: feign.RetryableException: Connection reset executing POST替换 feign的 client : Feign在默认情况下使用的是JDK原生的 URLConnection 发送HTTP请求,没有连接池。 可以尝试替换成 httpclient 或者 okhttp。调整最大连接…

c++基础(类和对象中)(类的默认成员函数)

目录 一.构造函数(类似初始化) 1.概念 2.构造函数的特点 二.析构函数(类似 销毁对象/空间) 三.拷贝构造函数(类似复制粘贴的一种 初始化 ) 1.概念: 2.拷贝构造的特点: 四.赋值运算符重载&#xff08…

level 6 day2-3 网络基础2---TCP编程

1.socket(三种套接字:认真看) 套接字就是在这个应用空间和内核空间的一个接口,如下图 原始套接字可以从应用层直接访问到网络层,跳过了传输层,比如在ubtan里面直接ping 一个ip地址,他没有经过TCP或者UDP的数…

Android 11 使用HAL层的ffmpeg库(1)

1.frameworks/av/media目录下面的修改 From edd6f1374c1f15783d9920ebda22ea915e503775 Mon Sep 17 00:00:00 2001 From: GW00219471 <zhumingxingnoboauto.com> Date: Wed, 17 Jan 2024 15:16:10 0800 Subject: [PATCH] ?UTF-8?q?[V35CUX-4542]:E7A7BBE6A48Dcux20E8…

华为OD机试(C卷,200分)- 二叉树计算

题目描述 给出一个二叉树如下图所示&#xff1a; 请由该二叉树生成一个新的二叉树&#xff0c;它满足其树中的每个节点将包含原始树中的左子树和右子树的和。 左子树表示该节点左侧叶子节点为根节点的一颗新树&#xff1b;右子树表示该节点右侧叶子节点为根节点的一颗新树。…

VS+QT 打包可执行文件.exe

切换成release版本&#xff0c;同时更改项目属性中release配置下的各个属性&#xff0c;确保匹配 重新生成解决方案&#xff0c;将生成的.exe复制到一个空白文件夹中 执行&#xff1a; cd D:\QT\5.12.10\msvc2015_64\binwindeployqt C:\Users\DELL\Desktop\serials\MainWind…

食南之徒~马伯庸

◆ 第一章 >> 老赵&#xff0c;这你就不懂了。过大于功&#xff0c;要受罚挨打&#xff0c;不合算&#xff1b;功大于过&#xff0c;下回上司有什么脏活累活&#xff0c;第一时间会想到你&#xff0c;也是麻烦多多。只有功过相抵&#xff0c;上司既挑不出你的错&#xf…

MMU(内存管理单元)

概述 MMU 即内存管理单元&#xff0c;是用硬件电路逻辑实现的一个地址转换器件&#xff0c;它负责接受虚拟地址和地址关系转换表&#xff0c;以及输出物理地址 线性地址 由于保护模式的内存模型是分段模型&#xff0c;它并不适合于 MMU 的分页模型&#xff0c;所以我们要使用…

springcloud-config客户端启用服务发现报错找不到bean EurekaHttpClient

背景 在对已有项目进行改造的时候&#xff0c;集成SpringConfigStarter&#xff0c;编写完bootstrap.yml&#xff0c;在idea 启动项中编辑并新增VM options -Dspring.cloud.config.discovery.enabledtrue&#xff0c;该版本不加spring不会从configService获取信息&#xff0c;…

三、GPIO口

我们在刚接触C语言时&#xff0c;写的第一个程序必定是hello world&#xff0c;其他的编程语言也是这样类似的代码是告诉我们进入了编程的世界&#xff0c;在单片机中也不例外&#xff0c;不过我们的传统就是点亮第一个LED灯&#xff0c;点亮电阻&#xff0c;电容的兄弟&#x…

[CSS] 浮动布局的深入理解与应用

文章目录 浮动的简介元素浮动后的特点解决浮动产生的影响浮动后的影响解决浮动产生的影响 浮动相关属性实际应用示例示例1&#xff1a;图片与文字环绕示例2&#xff1a;多列布局示例3&#xff1a;响应式布局 总结 浮动布局是CSS中一种非常强大的布局方式&#xff0c;最初设计用…

java项目(knife4j使用,静态资源未放在static资源包下,公共字段自动填充,Spring Cache与Spring Task)

Knife4j&#xff08;生成接口文档&#xff09; 使用swagger你只需要按照它的规范去定义接口及接口相关的信息&#xff0c;就可以做到生成接口文档&#xff0c;以及在线接口调试页面。官网:https://swagger.io/ Knife4j是为Java MVC框架集成Swagger生成Api文档的增强解决方案。…