动手学深度学习——多层感知机

1. 感知机

感知机本质上是一个二分类问题。给定输入x、权重w、偏置b,感知机输出:

以猫和狗的分类问题为例,它本质上就是找到下面这条黑色的分割线,使得所有的猫和狗都能被正确的分类。

与线性回归和softmax的不同点:

  • vs 线性回归:输出的都是一个数,但线性回归输出的是实数,而感知机输出的是离散的分类。
  • vs softmax: softmax是一个多分类(如果有n个分类,softmax就会输出n个元素),而感知机只输出一个元素。

感知机存在的问题: 它只能产生线性分割面,对于XOR(异或)函数,无法拟合(一条线不论怎么分割,都无法将绿色和红色分类正确)。

2. 多层感知机(MLP)

对于上面单层感知机的问题,一个改进思想是:一层函数如果做不了,就用多层函数来做,而多层就带来了网络,用不同层解决不同的问题,多层配合来解决更复杂的问题。

可以使用蓝线对所有数据进行x轴方向的正负分类,再使用黄线对所有数据进行y轴方向的正负分类,最后再将两次分类结果进行xor运算就能得到结果。

多层感知机使用隐藏层和激活函数来得到非线性模型。

在softmax基础上多了隐藏层。可选超参:

  • 隐藏层数
  • 每个隐藏层的宽度,通常选择2的若干次冥作为层的宽度

这两个参数的选择取决于输入和输出的复杂度

对复杂的输入,输入维度一般比较高,输出一般会比较少,有两种处理办法:

  1. 做单隐藏层,把模型做平,层的大小设大一点
  2. 做多隐藏层,把模型做深,层的大小可以设小一点,每层的维度逐步减少(如果每层维度都高,则会导致模型太大)

复杂输入到简单输出本质上是一个信息压缩的过程,多层逐步压缩能避免一次压缩太大导致信息损失太严重,例如:128->64->32->16->8
也可以先expand,从128->256->64->32->16->8

3. 激活函数

作用:在神经网络中引入非线性,可以理解为一个开关,当输入信号超过一定阀值时,神经元会被激活并产生输出,而未超过阀值时神经元将会被抑制。

在没有激活函数的情况下,神经网络只能表示线性映射,无法处理复杂的非线性关系。激活函数的作用就是线性结果映射到一个非线性的输出,以帮助神经网络更好的适应输入数据,提高非线性拟合能力。

举例:一个邮件过滤模型中的神经元,负责对输入邮件的特征(长度、关键词等)进行加权求和,但这个结果只是一个连续的数值我们交

激活函数不能是线性函数,否则会变成单层感知机,依然会存在线性分割面无法处理XOR的问题。

激活函数主要作用于隐藏层。

激活函数的几种选择:

  1. sigmoid: 对于任意输入x,都能投影到0~1区间内。

  2. tanh(x): 将输入投影到[-1,1]区间内

  1. ReLU: 就是一个Max函数(常用),特点是计算很快,相比前面基于指数运算的sigmoid和tanh函数都快很多(一次指数运算要100个时钟周期)

对ReLU函数求导,小于等于0时都是0,大于0时都是1,最终结果就是一个二分类。

4. 代码实现

4.1 初始化参数

我们将实现一个具有单隐藏层的多层感知机, 这个隐藏层包含128个隐藏单元。

对于每一层我们都要记录一个权重矩阵和一个偏置向量,并指定requires_grad=True来记录参数梯度。

import torch
from torch import nn
from d2l import torch as d2lnum_inputs, num_outputs, num_hiddens = 784, 10, 128W1 = nn.Parameter(torch.randn(num_inputs, num_hiddens, requires_grad=True) * 0.01)
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
W2 = nn.Parameter(torch.randn(num_hiddens, num_outputs, requires_grad=True) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))params = [W1, b1, W2, b2]

通常,我们选择2的若干次幂作为层的宽度。 因为内存在硬件中的分配和寻址方式,这么做往往可以在计算上更高效。

4.2 加载数据集

这里继续使用Fashion-MNIST图像分类数据集。

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

4.3 激活函数

Relu函数的实现比较简单,就是一个max函数的调用, 它将输入的负值部分截断为0,保留正值部分不变。

def relu(X):a = torch.zeros_like(X)return torch.max(X, a)
  • torch.zeros_like(X): 创建了一个与X具有相同形状的全零张量a。
  • torch.max(X, a): 对于输入X中的每个元素,如果它是正值,则该元素保留不变;如果它是负值,则将其替换为0。

4.4 模型

def net(X):X = X.reshape((-1, num_inputs))    H = relu(X@W1 + b1)  # 隐藏层,这里“@”代表矩阵乘法return (H@W2 + b2)   # 输出层
  1. 使用reshape将输入的二维图像转换为一个长度为num_inputs=784的向量;
  2. 用ReLu函数对隐藏层的线性输出进行激活,得到输出张量H;
  3. 最后,由张量H和权重矩阵W2进行矩阵乘法操作,将偏置向量b2加到结果上,得到预测输出结果。

4.5 损失函数

这里直接使用pytorch中内置的交叉熵损失函数。

loss = nn.CrossEntropyLoss(reduction='none')

4.6 训练

多层感知机的训练过程与softmax的训练过程完全相同,可以直接调用之前定义过的train_ch3函数。

# 将迭代周期数设置为10,并将学习率设置为0.1.
num_epochs, lr = 10, 0.1
updater = torch.optim.SGD(params, lr=lr)
train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

训练过程中的模型损失和精度的收敛变化:

epoch: 1, loss: 1.1021366075515746, test_acc: 0.7544
epoch: 2, loss: 0.6142196039199829, test_acc: 0.8004
epoch: 3, loss: 0.5257990721384684, test_acc: 0.8061
epoch: 4, loss: 0.4842481053034465, test_acc: 0.7988
epoch: 5, loss: 0.4575055497487386, test_acc: 0.8266
epoch: 6, loss: 0.4389862974802653, test_acc: 0.8382
epoch: 7, loss: 0.42252545185089113, test_acc: 0.8443
epoch: 8, loss: 0.40933472124735515, test_acc: 0.8458
epoch: 9, loss: 0.3975078603744507, test_acc: 0.8467
epoch: 10, loss: 0.38488629398345947, test_acc: 0.8527

基于之前softmax模型上定义的预测函数,在测试数据集上使用这个模型做验证:

predict_ch3(net, test_iter)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/9884.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Docker技能

文章目录 Docker2024心得优秀博客 Docker2024 心得 感觉这块目前学习用处不大。工作中用到的大多是编程的技巧,这块是运维的技能。 优秀博客 快速使用Docker部署MySQL、Redis、NginxIDEA集成Docker构建SpringBoot镜像上传服务器Docker常用命令总结docker-compos…

Ubuntu/Linux 安装Docker + PyTorch

文章目录 1. 提前准备2. 安装Docker2.1. 卸载冲突软件(非必要)2.2. 在Ubuntu系统上添加Docker的官方GPG密钥2.3. 将Docker的仓库添加到Ubuntu系统的APT源列表中2.4. 安装最新Docker2.5. 检查 3. 安装Nvidia Container Toolkit3.1. 在Ubuntu系统上添加官方…

求一个B站屏蔽竖屏视频的脚本

求一个B站屏蔽竖屏视频的脚本 现在B站竖屏竖屏越来越多了,手机还好点给我一个按钮,选择不喜欢,但是我一般都用网页版看视屏,网页版不给我选择不喜欢的按钮,目测大概1/4到1/3的视频都是竖屏视频。 目前网页版唯一的进…

Union内存分布

最近研究union,发现union内存分布挺有意思。 Union定义是什么? Union是中文名是联合体,类似于struct,但是跟struct有很多区别,里面参数公用内存。 Union和struct的区别 ①结构体(struct)中所有变量是“共存”的——优…

MarkText 下载安装和运行

1 官网页面 2 Github 页面 3 选择合适的版本,下载后运行。 附录: 官网: https://www.marktext.cc/ Github 地址: https://github.com/marktext/marktext/releases 目前最新版 v0.17.1,Mar 8, 2022。

二叉树的遍历(前序 中序 后序)

一、前序遍历 顺序为: 根-->左子树---->右子树 先访问根节点,再递归进入根节点的左子树(通过递归不断往下遍历),直到访问的节点没有左子树,此时递归进入其右子树(通过递归进行相同操作&a…

~MAY~

一时间不知道要干啥了&#xff0c;随便写点题目吧 亚运奖牌榜 很简单的模拟&#xff0c;题目看错了整了半天。。 #include<bits/stdc.h> using namespace std; int n; struct a {int pai[4]; }c1,c2; int main() {cin>>n;memset(c1.pai,0,sizeof(c1.pai));memse…

有限单元法-编程与软件应用(崔济东、沈雪龙)【PDF下载】

专栏导读 作者简介&#xff1a;工学博士&#xff0c;高级工程师&#xff0c;专注于工业软件算法研究本文已收录于专栏&#xff1a;《有限元编程从入门到精通》本专栏旨在提供 1.以案例的形式讲解各类有限元问题的程序实现&#xff0c;并提供所有案例完整源码&#xff1b;2.单元…

在centos7中运行向量数据库PostgreSQL连接不上如何排查?

1. 检查 PostgreSQL 服务状态 首先&#xff0c;您需要确认 PostgreSQL 服务是否正在运行。您可以使用以下命令来检查服务状态&#xff1a; sudo systemctl status postgresql如果服务没有运行&#xff0c;您需要启动它&#xff1a; sudo systemctl start postgresql2. 确认 …

OSPF链路状态数据库

原理概述 OSPF是一种基于链路状态的动态路由协议&#xff0c;每台OSPF路由器都会生成相关的LSA&#xff0c;并将这些LSA通告出去。路由器收到LSA后&#xff0c;会将它们存放在链路状态数据库LSDB中。 LSA有多种不同的类型&#xff0c;不同类型的LSA的功能和作用是不同的&…

【智能优化算法】金豺狼优化算法(Golden jackal optimization,GJO)

金豺狼优化(Golden jackal optimization,GJO)是期刊“Expert Systems with Applications”&#xff08;中科院一区IF 8.3&#xff09;的2022年智能优化算法 01.引言 金豺狼优化(Golden jackal optimization,GJO)旨在为解决实际工程问题提供一种替代的优化方法。GJO的灵感来自金…

【智能优化算法】卷尾猴搜索算法(Capuchin search algorithm,CapSA)

【智能优化算法】卷尾猴搜索算法(Capuchin search algorithm,CapSA)是期刊“NEURAL COMPUTING & APPLICATIONS”&#xff08;IF 6.0&#xff09;的2021年智能优化算法 01.引言 【智能优化算法】卷尾猴搜索算法(Capuchin search algorithm,CapSA)用于解决约束和全局优化问…

VMware Workstation 17 Player 创建虚拟机教程

本教程是以windows server 2012物理机服务器安装好的VMware Workstation 17 Player为例进行演示&#xff0c;安装VMware Workstation 17 Player大家可以自行网上搜索安装。 1、新建虚拟机 双击安装好的VMvare图标&#xff0c;点击创建虚拟机。 2、选择是否安装系统 本步骤选…

23 内核开发- Linux 内核下半段的实现方式

23 内核开发- Linux 内核下半段的实现方式 1.定义 下半部&#xff0c;就是执行中断处理密切相关但是中断处理程序本身不执行的工作。 2.为什么要用下半部执行&#xff1f; 中断处理程序不在进程上下文中运行&#xff0c;所以他们不能阻塞&#xff1b;为什么要推后执行&#xf…

【静态分析】软件分析课程实验A4-类层次结构分析与过程间常量传播

官网&#xff1a;作业 4&#xff1a;类层次结构分析与过程间常量传播 | Tai-e 参考&#xff1a;https://www.cnblogs.com/gonghr/p/17984124 ----------------------------------------------------------------------- 1 作业导览 为 Java 实现一个类层次结构分析&#xf…

shiro-quickstart启动报错

说明&#xff1a;最近在学登录框架&#xff0c;记录一下学习刚shiro框架&#xff0c;启动快速入门样例的错误&#xff1b; 场景 把shiro代码download下来&#xff0c;打开samples&#xff08;样例&#xff09;包&#xff0c;打开快速入门&#xff0c;启动&#xff0c;报错&am…

Apache代理服务器使用注意事项

文章目录 一、配置二、服务器配置三、代理配置四、常用命令1、重启服务器 运行环境&#xff1a;ubuntu 一、配置 服务器配置文件路径&#xff1a;/etc/apache2/apache2.conf不加密代理配置文件路径&#xff1a;/etc/apache2/sites-enabled/000-default.conf加密代理配置文件路…

WiFi网络的重要性

 WiFi网络的重要性 便利性&#xff1a;WiFi网络提供了无线连接&#xff0c;使得用户可以在没有线缆束缚的情况下访问互联网。这种便利性使得用户可以随时随地使用笔记本电脑、平板电脑、智能手机等设备上网&#xff0c;无论是在家中、办公室、咖啡馆还是公共场所。移动性&am…

聊天室项目思路

发起群聊&#xff1a; 从好友表选取人发送到服务器&#xff0c;服务器随机生成不重复的群号&#xff0c;存储在数据库&#xff0c;同时建立中间表&#xff0c;处理用户与群聊的关系 申请入群&#xff1a; 输入群号&#xff0c;发消息给服务器&#xff0c;服务器查询是否存在…

06-xss攻防于绕过

xss的攻击于防御 攻击的利用方式 1&#xff09;获取cookie&#xff0c;实现越权&#xff0c;如果是获取到网站管理员的cookie&#xff0c;也可以叫提权。注意尽量尽快退出账号&#xff0c;删除session&#xff0c;让session失效 2&#xff09;钓鱼网站&#xff0c;模拟真实的…