全连接神经网络:分类与回归示例

分类

创建测试数据

import random
import torch
import torch.utils.datadef get_rectangle():"""随机得到矩形的宽和高,值域0-1之间的小数,判断这是否是一个"胖"的矩形:return:"""width = random.random()height = random.random()fat = int(width >= height)return width, height, fat
width, height, fat=get_rectangle()
print(width, height, fat)

定义数据集(torch.utils.data.Dataset)

定义数据集一般是创建一个class继承torch.utils.data.Dataset,在这个class里面要定义三个函数,分别是init、len、getitem。init一般用于数据集的初始化,预处理等操作;len函数要输出这个数据集有多少条数据,按理来说我们这个测试数据是动态生成的,理论上来说有无穷多条,但在这样还是要给pytorch一个明确的数量;getitem函数是要根据序号i来获取一条数据。

class Dataset(torch.utils.data.Dataset):#正常应该在这里执行数据的加载,处理等工作def __init__(self):pass#定义数据的条数def __len__(self):return 500#根据序号i,获取一条数据def __getitem__(self, i):#获取一个矩形的数据width, height, fat = get_rectangle()#定义宽高为x,定义是否胖为yx = torch.FloatTensor([width, height])y = fatreturn x, y
dataset = Dataset()print(len(dataset), dataset[0])

500 (tensor([0.0132, 0.6463]), 0)

这里我们查看了dataset的数量,并查看了第0条数据。

数据遍历工具loader

这个loader是一个数据的加载器,我们把数据集传给dataset,并每八条数据为一个批次,然后我们打乱数据集当中的顺序,先前我们定义了500条数据,并非是8的整数倍,drop_last不足时直接丢弃。

loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=8,shuffle=True,drop_last=True)
print(len(loader), next(iter(loader)))

62 [tensor([[0.4461, 0.1130],
        [0.6130, 0.8681],
        [0.5334, 0.5767],
        [0.9663, 0.4436],
        [0.6687, 0.5886],
        [0.5669, 0.7870],
        [0.9415, 0.3396],
        [0.2015, 0.5745]]), tensor([1, 0, 0, 1, 1, 0, 1, 0])]

定义神经网络模型

定义的方法也是创建class继承torch.nn.module,一般在这个class下有两个函数分别是init和forward,分别是用于模型初始化和神经网络的计算过程,先来看初始化部分,这里调用了一个sequential这样一个类,用于把多层神经网络给组合在一起,也就是前后串连的关系,算完一层再算完下一层

#全连接神经网络
class Model(torch.nn.Module):#模型初始化部分def __init__(self):super().__init__()#定义神经网络结构self.fc = torch.nn.Sequential(torch.nn.Linear(in_features=2, out_features=32),torch.nn.ReLU(),torch.nn.Linear(in_features=32, out_features=32),torch.nn.ReLU(),torch.nn.Linear(in_features=32, out_features=2),torch.nn.Softmax(dim=1),)#定义神经网络计算过程def forward(self, x):return self.fc(x)
model = Model()
print(model(torch.randn(8,2)).shape)

输入层是两个维度输入,分别就是宽和高,输出是32维的向量,激活函数ReLU将所有的负数归零,中间层就是32x32密集的网络,从这层可以很好的抽取数据当中的特征,输出还是一个Linear,输入是32维的向量,输出是两维的向量,符合我们二分类的条件,最后一层假如了softmax,这层的功能是让两个神经元输出为1,因为我们是一个二分类问题,希望其相加的结果为1。

训练模型

在训练模型部分,首先来初始化一个优化器,代码中使用的是Adam,learning rate为1e-4。因为是分类这里使用的celoss,然后我们调用train函数,对全量的数据遍历100轮,从loader中取到一批批的数据,然后我们把x放到模型中去计算,将模型计算的结果与真实的y进行求误差,也就是调用的loss函数,如果模型计算的结果根y是完全相同的情况下,它的loss应当是0,但一般是不可能的。有了loss再算梯度,调整模型当中的参数,调整完后使梯度归零,所有的pytorch都会经过这里的三个步骤。

def train():#优化器,根据梯度调整模型参数optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)#计算loss的函数loss_fun = torch.nn.CrossEntropyLoss()#让model进入train模式,开启dropout等功能model.train()#全量数据遍历100轮for epoch in range(100):#按批次遍历loader中的数据for i, (x, y) in enumerate(loader):#模型计算out = model(x)#根据计算结果和y的差,计算loss,在计算结果完全正确的情况下,loss为0loss = loss_fun(out, y)#根据loss计算模型的梯度loss.backward()#根据梯度调整模型的参数optimizer.step()#梯度归零,准备下一轮的计算optimizer.zero_grad()if epoch % 20 == 0:#计算正确率acc = (out.argmax(dim=1) == y).sum().item() / len(y)print(epoch, loss.item(), acc)#保存模型到磁盘torch.save(model, 'model/3.model')

关于使用哪一个工具类来计算loss,一般来说回归采用MSEloss,分类使用CEloss。

测试

代码中添加了一个注解,意思是在这个函数中不需要计算模型的梯度,因为在这个函数中执行的是测试,而非执行训练,所以不需要更新参数,所以也就不需要计算模型的梯度。首先要将训练好的模型给它加载进来。让模型进入测试模式,这样可以关闭模型当中的一些dropout之类的功能。从loader中获取一批数据,然后计算模型的正确率。

#测试
#注释的表明不计算模型梯度,节省计算资源
@torch.no_grad()
def test():#从磁盘加载模型model = torch.load('model/3.model')#模型进入测试模式,关闭dropout等功能model.eval()#获取一批数据x, y = next(iter(loader))#模型计算结果out = model(x).argmax(dim=1)print(out, y)print(out == y)

回归

创建测试数据

#生成矩形数据的函数
def get_rectangle():import random#随机得到矩形的宽和高,值域0-1之间的小数width = random.random()height = random.random()#计算面积s = width * heightreturn width, height, s

定义数据集

在这里初始化不需要进行任何操作,数据的条数理论上有无穷多条,但在pytorch中还是要明确指出,每次生成一条数据,将类型转为tensor。

import torch#定义数据集
class Dataset(torch.utils.data.Dataset):#正常应该在这里执行数据的加载,处理等工作def __init__(self):pass#定义数据的条数def __len__(self):return 500#根据序号i,获取一条数据def __getitem__(self, i):#获取一个矩形的数据width, height, s = get_rectangle()#定义宽高为x,定义面积为yx = torch.FloatTensor([width, height])y = torch.FloatTensor([s])return x, ydataset = Dataset()print(len(dataset), dataset[0])

数据遍历工具loader

#数据集加载器,每8条数据为一个批次,打乱顺序,不足8条时丢弃尾数
loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=8,shuffle=True,drop_last=True)print(len(loader), next(iter(loader)))

loader的定义与上面的分类相同。

定义神经网络模型

class Model(torch.nn.Module):#模型初始化部分def __init__(self):super().__init__()#定义神经网络结构self.fc = torch.nn.Sequential(torch.nn.Linear(in_features=2, out_features=32),torch.nn.ReLU(),torch.nn.Linear(in_features=32, out_features=32),torch.nn.ReLU(),torch.nn.Linear(in_features=32, out_features=1),)#定义神经网络计算过程def forward(self, x):return self.fc(x)model = Model()print(model(torch.randn(8, 2)).shape)

与上一个任务不一样的是,最后一层网络它是一个全连接神经网络,并且输出值为一个神经元

训练模型

def train():#优化器,根据梯度调整模型参数optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)#计算loss的函数loss_fun = torch.nn.MSELoss()#让model进入train模式,开启dropout等功能model.train()#全量数据遍历100轮for epoch in range(100):#按批次遍历loader中的数据for i, (x, y) in enumerate(loader):#模型计算out = model(x)#根据计算结果和y的差,计算loss,在计算结果完全正确的情况下,loss为0loss = loss_fun(out, y)#根据loss计算模型的梯度loss.backward()#根据梯度调整模型的参数optimizer.step()#梯度归零,准备下一轮的计算optimizer.zero_grad()if epoch % 20 == 0:print(epoch, loss.item())#保存模型到磁盘torch.save(model, 'model/4.model')

测试

#测试
#注释的表明不计算模型梯度,节省计算资源
@torch.no_grad()
def test():#从磁盘加载模型model = torch.load('model/4.model')#模型进入测试模式,关闭dropout等功能model.eval()#获取一批数据x, y = next(iter(loader))#模型计算结果out = model(x)print(torch.cat([out, y], dim=1))

将刚刚训练好的模型从磁盘上加载进来,获取一批数据,让模型进行计算,并查看模型的一个计算结果是否与真实的y之间是否接近。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/4739.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

修改接口参数名和在Swagger中的展示名

背景 我们有一个接口要支持后端排序,所以需要在请求对象里面增加两个参数:排序字段名、排序方式(asc、desc)。 正好基础jar包中有一个类可以直接拿来用。 Data public class OrderByItem {private String column;private Strin…

创建型设计模式-1.单例设计模式

创建型设计模式-1.单例设计模式 创建型设计模式:核心目的就是给我们提供了一系列全新的创建对象的方式方法 一、简介 1.简述 单例设计模式(Singleton Design Pattern),一个类只允许创建一个对象(或实例&#xff09…

nginx+lua+redis环境搭建(文末赋上脚本)

目录 需求背景 环境搭建后nginx和redis版本 系统环境 搭建步骤 配置服务器DNS 安装ntpdate同步一下系统时间 安装网络工具、编译工具及依赖库 创建软件包下载目录、nginx和redis安装目录 下载配置安装lua解释器LuaJIT 下载nginx NDK(ngx_devel_kit&#xff09…

ceph安装部署

Ceph 简介 存储基础 单机存储设备 单机存储的问题 分布式存储的类型 分布式存储(软件定义的存储 SDS) Ceph 架构 Ceph 核心组件 ​编辑 Pool中数据保存方式支持两种类型 OSD 存储后端 Ceph 数据的存储过程 Ceph 集群部署 基于 ceph-deploy …

网络运维能转型到系统运维吗?

很多网工处于刚起步的初级阶段,各大公司有此专职,但重视或重要程度不高,可替代性强;小公司更多是由其它岗位来兼顾做这一块工作,没有专职,也不可能做得深入。 现在开始学习入门会有一些困难,不…

Hyperledger Fabric测试网络运行官方Java链码[简约版]

文章目录 启动测试网络使用peer CLI测试链码调用链码 启动测试网络 cd fabric-samples/test-networknetwork.sh的脚本语法是&#xff1a;network.sh <mode> [flag] ./network.sh up./network.sh createChannel在java源码路径下 chmod 744 gradlew vim gradlew :set ffu…

[SSM]GoF之工厂模式

目录 六、GoF之工厂模式 6.1工厂模式的三种形态 6.2简单工厂模式 6.3工厂方法模式 6.4抽象工厂模式&#xff08;了解&#xff09; 六、GoF之工厂模式 设计模式&#xff1a;一种可以被重复利用的解决方案 GoF&#xff08;Gang of Four)&#xff0c;中文名——四人组。 该书…

阿里云服务器 用docker部署mysql

阿里云服务器上使用Docker部署MySQL 当您在阿里云服务器上使用Docker部署MySQL时&#xff0c;步骤如下&#xff1a; 登录到阿里云服务器&#xff1a;使用SSH工具登录到您的阿里云服务器。您可以使用命令行工具&#xff08;如OpenSSH&#xff09;或可视化工具&#xff08;如PuT…

银河麒麟高级服务器操作系统V10安装mysql数据库

一、安装前 1.检查是否已经安装mysql rpm -qa | grep mysql2.将查询出的包卸载掉 rpm -e --nodeps 文件名3.将/usr/lib64/libLLVM-7.so删除 rm -rf /usr/lib64/libLLVM-7.so4.检查删除结果 rpm -qa | grep mysql5.搜索残余文件 whereis mysql6.删除残余文件 rm -rf /usr/b…

利用JavaScript判断页面宽度的响应式布局方法

首先前端中&#xff0c;样式尺寸单位全部用rem&#xff0c;比如&#xff1a; .content{ width: 8rem; border-radius: 0.15rem; font-size: 0.95rem; letter-spacing: 0.15rem; &#xff5d; 接着页面中的html给个默认的font-size样式&#xff0c;比如&#xff1a; <html …

MYSQL表操作(DML,DDL)

建表并插入数据&#xff1a; mysql> create table worker(-> dept_id int(11) not null,-> emp_id int (11) not null,-> work_time date not null,-> salary float(8,2) not null,-> poli_face varchar(10) not null default 群众,-> name varchar(20) …

一种对不同类型齐格勒-尼科尔斯 P-I-D 控制器调谐算法研究(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

MySQL数据库 【增删改查】

目录 一、新增 指定列插入 一次插入多个数据 二、查询 1、全列查询 2、指定列查询 3、查询字段为表达式 4、查询的时候给列名/表达式 指定别名 5、查询时去重 6、排序查询 7、条件查询 8、模糊查询 9、空值查询 10、分页查询 三、修改 四、删除 SQL 最核心…

11、JSON.parse 数据不完整

一、问题描述 使用 JSON.parse 反序列化&#xff0c;出现数据丢失现象。 字符串json数据&#xff1a; {"varImageList": [{"variationValue": "Black ","imageList": [{"variationValue": "Black ","imag…

JS-26 认识防抖和节流函数;自定义防抖、节流函数;自定义深拷贝、事件总线函数

目录 1_防抖和节流1.1_认识防抖和节流函数1.2_认识防抖debounce函数1.3_防抖函数的案例1.4_认识节流throttle函数 2_Underscore实现防抖和节流2.1_Underscore实现防抖和节流2.2_自定义防抖函数2.3_自定义节流函数 3_自定义深拷贝函数4_自定义事件总线 1_防抖和节流 1.1_认识防…

你是不是一个好的测试工程师?

如何评价一个程序员是否优秀一直是一个很有争议的话题。 先说一个真实事件&#xff0c;国际化项目&#xff0c;最开始都是由产品经理在excel中管理翻译&#xff0c;迭代过程中如有增删改&#xff0c;就把增删改的部分标记出来&#xff0c;提供给开发&#xff0c;开发再对应更新…

关于gateway中lb失效

在通过gateway将请求发送到对应的服务模块时&#xff0c;出现了503的报错&#xff0c;也就是gateway时可以正常启动&#xff0c;但是页面上在发送请求获取数据的时候&#xff0c;却不是相应的请求地址。 解决方法&#xff1a; 1.首先你得保证前端项目里面访问网关地址都是正确…

【Netty】NIO基础(三大组件、文件编程)

文章目录 三大组件Channel & BufferSelector ByteBufferByteBuffer 正确使用姿势ByteBuffer 内部结构ByteBuffer 常见方法分配空间向 buffer 写入数据从 buffer 读取数据mark 和 reset 字符串与 ByteBuffer 互转Scattering ReadsGathering Writes粘包、半包分析 文件编程Fi…

vue 当新增样式无法生效的情况下如何处理

使用scoped属性时&#xff0c;会遇到样式问题。需要使用样式穿透解决 <style lang"scss" scoped> </style> 可以使用以下方法 &#xff1a;deep css 使用 >>> less 使用 /deep/ scss 使用 ::v-deep 代码写法如下: .a :deep(.b) { } .…

v-show和v-if的区别以及显示隐藏不生效的奇怪现象以及点击索引错位问题的解释

基本概念没什么好讲的。有时候会遇到莫名其妙不显示的问题&#xff0c;这都是因为对这两个概念理解不透彻造成的。 v-show的本质 v-show的本质就是通过调用css的display:none来实现的&#xff0c;这点非常重要&#xff0c;出问题可以在浏览器调试页面手动设置display:none来验…