【PyTorch】多层感知机

文章目录

  • 1. 模型和代码实现
    • 1.1. 模型
      • 1.1.1. 背景
      • 1.1.2. 多层感知机
      • 1.1.3. 激活函数
    • 1.2. 代码实现
      • 1.2.1. 完整代码
      • 1.2.2. 输出结果
  • 2. Q&A

1. 模型和代码实现

1.1. 模型

1.1.1. 背景

许多问题要使用线性模型,但无法简单地通过预处理来实现。此时我们可以通过在网络中加入一个或多个隐藏层来克服线性模型的限制, 使其能处理更普遍的函数关系类型。

1.1.2. 多层感知机

将许多全连接层堆叠在一起。 每一层都输出到上面的层,直到生成最后的输出,我们可以把前层看作表示,把最后一层看作线性预测器。 这种架构通常称为多层感知机,通常缩写为MLP。
多层感知机

1.1.3. 激活函数

我们需要在仿射变换之后对每个隐藏单元应用非线性的激活函数,这样就不可能再将我们的多层感知机退化成线性模型,使得模型具有更强的表达能力。
激活函数是通过计算加权和并加上偏置来确定神经元是否应该被激活, 并将输入信号转换为输出的可微运算的函数。

  • ReLU函数

    • 修正线性单元(Rectified linear unit,ReLU)。

    • 最受欢迎的激活函数。

    • 定义: R e L U ( x ) = m a x ( 0 , x ) \mathrm{ReLU}(x)=\mathrm{max}(0,x) ReLU(x)=max(0,x)
      relu

    • 当输入接近0时,sigmoid函数接近线性变换。
      gradofrelu

    • 当输入值精确等于0时,ReLU函数不可导。 在此时,我们默认使用左侧的导数,即当输入为0时导数为0。 我们可以忽略这种情况,因为输入可能永远都不会是0。

    • 变体:参数化的ReLU(Parameterized ReLU,pReLU),允许即使参数是负的,某些信息依然可以通过,其定义如下: p R e L U ( x ) = m a x ( 0 , x ) + α m i n ( 0 , x ) \mathrm{pReLU}(x)=\mathrm{max}(0,x)+\alpha\mathrm{min}(0,x) pReLU(x)=max(0,x)+αmin(0,x)等等。

  • sigmoid函数

    • 将输入变换为区间(0, 1)上的输出。

    • 在隐藏层中已经较少使用, 它在大部分时候被更简单、更容易训练的ReLU所取代。

    • 定义: s i g m o i d ( x ) = 1 1 + e x p ( − x ) \mathrm{sigmoid}(x)=\frac{1}{1+\mathrm{exp}(-x)} sigmoid(x)=1+exp(x)1
      sigmoid

    • 导数: d d x s i g m o i d ( x ) = s i g m o i d ( x ) ( 1 − s i g m o i d ( x ) ) \frac{\mathrm{d}}{\mathrm{d}x}\mathrm{sigmoid}(x)=\mathrm{sigmoid}(x)(1-\mathrm{sigmoid}(x)) dxdsigmoid(x)=sigmoid(x)(1sigmoid(x))
      gradofsigmoid

  • tanh函数

    • 将其输入压缩转换到区间(-1, 1)上。

    • 定义: t a n h ( x ) = 1 − e x p ( − 2 x ) 1 + e x p ( − 2 x ) \mathrm{tanh}(x)=\frac{1-\mathrm{exp}(-2x)}{1+\mathrm{exp}(-2x)} tanh(x)=1+exp(2x)1exp(2x)
      tanh

    • 当输入接近0时,tanh函数接近线性变换。

    • 导数: d d x t a n h ( x ) = 1 − t a n h 2 ( x ) \frac{\mathrm{d}}{\mathrm{d}x}\mathrm{tanh}(x)=1-\mathrm{tanh}^2(x) dxdtanh(x)=1tanh2(x)
      gradoftanh

1.2. 代码实现

1.2.1. 完整代码

import torch
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader
from torch import nn
from tensorboardX import SummaryWriterdef load_dataset(batch_size, num_workers):"""加载数据集"""root = "./dataset"transform = transforms.Compose([transforms.ToTensor()])mnist_train = FashionMNIST(root=root, train=True, transform=transform, download=True)mnist_test = FashionMNIST(root=root, train=False, transform=transform, download=True)dataloader_train = DataLoader(mnist_train, batch_size, shuffle=True, num_workers=num_workers)dataloader_test = DataLoader(mnist_test, batch_size, shuffle=False,num_workers=num_workers)return dataloader_train, dataloader_testdef init_network(net):"""初始化模型参数"""def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, mean=0, std=0.01)nn.init.constant_(m.bias, val=0)if isinstance(net, nn.Module):net.apply(init_weights)class Accumulator:"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]def accuracy(y_hat, y):"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())def train(net, dataloader_train, criterion, optimizer, device):"""训练模型"""if isinstance(net, nn.Module):net.train()train_metrics = Accumulator(3)  # 训练损失总和、训练准确度总和、样本数for X, y in dataloader_train:X, y = X.to(device), y.to(device)y_hat = net(X)loss = criterion(y_hat, y)optimizer.zero_grad()loss.mean().backward()optimizer.step()train_metrics.add(float(loss.sum()), accuracy(y_hat, y), y.numel())train_loss = train_metrics[0] / train_metrics[2]train_acc = train_metrics[1] / train_metrics[2]return train_loss, train_accdef test(net, dataloader_test, device):"""测试模型"""if isinstance(net, nn.Module):net.eval()with torch.no_grad():    test_metrics = Accumulator(2)   # 测试准确度总和、样本数for X, y in dataloader_test:X, y = X.to(device), y.to(device)y_hat = net(X)test_metrics.add(accuracy(y_hat, y), y.numel())test_acc = test_metrics[0] / test_metrics[1]return test_accif __name__ == "__main__":# 全局参数设置batch_size = 256num_workers = 4num_epochs = 10learning_rate = 0.1device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 创建记录器writer = SummaryWriter()# 加载数据集dataloader_train, dataloader_test = load_dataset(batch_size, num_workers)# 定义神经网络net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),nn.Linear(256, 10)).to(device)# 初始化神经网络init_network(net)# 定义损失函数criterion = nn.CrossEntropyLoss(reduction='none')# 定义优化器optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)for epoch in range(num_epochs):train_loss, train_acc = train(net, dataloader_train, criterion, optimizer, device)test_acc = test(net, dataloader_test, device)writer.add_scalars("metrics", {'train_loss': train_loss, 'train_acc': train_acc, 'test_acc': test_acc}, epoch)writer.close()

1.2.2. 输出结果

多层感知机

2. Q&A

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/196097.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

智能联动第三方告警中心,完美实现故障响应全闭环

前言 我们曾讨论完善的告警策略是整个数据监控系统的重要组成部分(参见《机智的告警策略,完善监控系统的重要一环》),介绍了如何配置告警通知以及场景示例,帮助用户及时更多潜在的故障和问题,有效地保障系…

Redis 之 ZSET 实战应用场景,持续更新!

前言 大白话介绍 Redis 五大基本数据类型之一的 ZSET 开发中常见的应用场景 ZSET 介绍 ZSET 与 SET 相同点:都是是 String类型元素的集合,且不允许重复的成员ZSET 与 SET 不同点:ZSET 每个元素都会关联一个 Double 类型的分数,Re…

Latex去掉参考文献后面的参考文献所在页(去掉参考文献的反向超链接)

如下: 在使用latex插入参考文献的时候,最后面总是会出现这种代号。这是表明的是这条参考文献所在的页码,并且点击之后可以跳转到该页。正式来讲,这个叫超链接的BACKREF。若要去掉,只需要在引用hyperref的时候去掉page…

技术or管理?浅谈软件测试人员的未来职业发展,值得借鉴

我们在工作了一段时间之后,势必会感觉到自己已经积累了一些工作经验了,会开始考虑下一阶段的职业生涯会如何发展。测试人员在职业生涯中的不确定因素还是不少的,由于其入门门槛不高,不用学习太多技术性知识即可入行,所…

Net8 EFCore Mysql 连接

一、安装插件 Pomelo.EntityFrameworkCore.MySq (这里要选8.0.0以上版本低版本不支持.net8) 二、配置数据库连接串 appsettings.json 中配置数据库连接串 "ConnectionStrings": {"Connection": "server172.18.2.183;port3306;databasestudents;uid…

使用opencv将8位图像raw数据转成bmp文件的方法

作者&#xff1a;朱金灿 来源&#xff1a;clever101的专栏 为什么大多数人学不会人工智能编程&#xff1f;>>> 这里说的图像raw数据是只包含图像数据的缓存。主要使用了cv::imencode接口将 cv::Mat转化为图像缓存。 #include <opencv2/opencv.hpp>/* 生成一幅…

【若依框架实现上传文件组件】

若依框架中只有个人中心有上传图片组件&#xff0c;但是这个组件不适用于el-dialog中的el-form表单页面 于是通过elementui重新写了一个上传组件&#xff0c;如图是实现效果 vue代码 <el-dialog :title"title" v-model"find" width"600px"…

Pytorch进阶教学——训练一个图像分类模型(GPU)

目录 1、前言 2、数据集介绍 3、获取数据 4、创建网络 5、训练模型 6、测试模型 6.1、测试整个模型准确率 6.2、测试单张图片 1、前言 编写一个可以分类蚂蚁和蜜蜂图片的模型&#xff0c;使用数据集对卷积神经网络进行训练。训练后的模型可以对蚂蚁或蜜蜂的图片进行…

【广州华锐互动】VR沉浸式体验铝厂安全事故让伤害教育更加深刻

随着科技的不断发展&#xff0c;虚拟现实&#xff08;VR&#xff09;技术已经逐渐渗透到各个领域&#xff0c;为我们的生活带来了前所未有的便捷和体验。在安全生产领域&#xff0c;VR技术的应用也日益受到重视。 VR公司广州华锐互动就开发了多款VR安全事故体验系统&#xff0c…

蓝桥杯-03-蓝桥杯学习计划

蓝桥杯-03-蓝桥杯学习计划 参考资料 相关文献 报了蓝桥杯比赛&#xff0c;几乎零基础&#xff0c;如何准备&#xff0c;请大牛指导一下。谢谢&#xff1f; 蓝桥杯2022各组真题汇总(完整可评测) 基础学习 C语言网 ACM竞赛入门,蓝桥杯竞赛指南 廖雪峰的官方官网 算法题单 洛谷…

vue,nvue,uniapp,到底是什么

vue,nvue,uniapp,到底是什么&#xff1f; 发展猜想&#xff1a; 开发移动端软件&#xff0c;一般是控件逻辑&#xff0c;可拖动控件android studio都给你设计好了。 开发web页面时&#xff0c;用vue&#xff0c;vue是前端框架。主要是终端设备通过浏览器进行访问&#xff08…

ubuntu20.04使用LIO-SAM对热室空间进行重建

一、安装LIO-SAM 1.环境配置 默认已经安装过ros sudo apt-get install -y ros-Noetic-navigation sudo apt-get install -y ros-Noetic-robot-localization sudo apt-get install -y ros-Noetic-robot-state-publisher 安装 gtsam(如果是18.04的ubuntu直接按照官网配置&…

C++ 基础篇

目录 C开发概述 C特点 C跨平台的原因 C编译器 C库 操作系统API C基本概念 注释 变量 常量 两种定义常量方式的区别 表示符命名规则 常见的关键字 数据类型 整型 浮点数 字符型 转义字符 字符串型 布尔类型 运算符 算术运算符 赋值运算符 比较运算符 逻…

【VScode】超详细图片讲解下载安装、环境配置、编译执行、调试

这里是目录 VScode是什么&#xff1f;VScode的下载和安装环境介绍安装中文插件 配置VScodeC/C开发环境下载和配置MinGW-w64 编译器套件下载&#xff1a;配置&#xff1a; 安装C/C插件在VScode上编写代码设置C/C编译选项创建执行任务编译执行如果想写其他代码在同一个文件夹在不…

springboot 整合 Spring Security 中篇(RBAC权限控制)

1.先了解RBAC 是什么 RBAC(Role-Based Access control) &#xff0c;也就是基于角色的权限分配解决方案 2.数据库读取用户信息和授权信息 1.上篇用户名好授权等信息都是从内存读取实际情况都是从数据库获取&#xff1b; 主要设计两个类 UserDetails和UserDetailsService 看下…

新媒体营销模拟实训室解决方案

一、引言 随着互联网的发展&#xff0c;新媒体已成为企业进行营销和品牌推广的重要渠道。然而&#xff0c;对于许多企业来说&#xff0c;如何在新媒体上进行有效的营销仍是一大挑战。为了解决这个问题&#xff0c;我们推出了一款新媒体营销模拟实训室解决方案&#xff0c;以帮…

【文末送书】Python OpenCV从入门到精通

文章目录 &#x1f354;简介opencv&#x1f339;内容简介&#x1f6f8;编辑推荐&#x1f384;导读&#x1f33a;彩蛋 &#x1f354;简介opencv OpenCV&#xff08;Open Source Computer Vision Library&#xff09;是一个开源的计算机视觉库&#xff0c;提供了丰富的图像处理和…

java学习part31String

142-常用类与基础API-String的理解与不可变性_哔哩哔哩_bilibili 1.String 2.字符串常量池 变更储存区的原因是加快被gc的频率 比地址&#xff0c;equals比内容 3.字符串连接 s3s4都是字符串常量&#xff0c;后面几个会利用StringBuilder的toString&#xff08;&#xff09;&a…

JAVA全栈开发 day16_MySql01

一、数据库 1.数据储存在哪里&#xff1f; 硬盘、网盘、U盘、光盘、内存&#xff08;临时存储&#xff09; 数据持久化 使用文件来进行存储&#xff0c;数据库也是一种文件&#xff0c;像excel &#xff0c;xml 这些都可以进行数据的存储&#xff0c;但大量数据操作&#x…

C#网络编程TCP程序设计(Socket类、TcpClient类和 TcpListener类)

目录 一、Socket类 1.Socket类的常用属性及说明 2.Socket类的常用方法及说明 二、TcpClient类 三、TcpListener类 四、示例 1.源码 2.生成效果 TCP(Transmission Control Protocol)是一种面向连接的、可靠的、基于字节流的传输层通信协议。在C#中&#xff0c;TCP程序设…