Pytorch深度学习-----神经网络之线性层用法

系列文章目录

PyTorch深度学习——Anaconda和PyTorch安装
Pytorch深度学习-----数据模块Dataset类
Pytorch深度学习------TensorBoard的使用
Pytorch深度学习------Torchvision中Transforms的使用(ToTensor,Normalize,Resize ,Compose,RandomCrop)
Pytorch深度学习------torchvision中dataset数据集的使用(CIFAR10)
Pytorch深度学习-----DataLoader的用法
Pytorch深度学习-----神经网络的基本骨架-nn.Module的使用
Pytorch深度学习-----神经网络的卷积操作
Pytorch深度学习-----神经网络之卷积层用法详解
Pytorch深度学习-----神经网络之池化层用法详解及其最大池化的使用
Pytorch深度学习-----神经网络之非线性激活的使用(ReLu、Sigmoid)


文章目录

  • 系列文章目录
  • 一、线性层是什么?
    • 1.官网解释
    • 2.nn.Linear函数参数介绍
  • 二、实战演示
    • 1.将CIFAR10图片数据集进行线性变换


一、线性层是什么?

线性层是深度学习中常用的一种基本层类型。它也被称为全连接层或仿射层。线性层的作用是将输入数据与权重矩阵相乘,然后加上偏置向量,最后输出一个新的特征表示。

具体来说,线性层可以表示为 Y = XW + b,其中 X 是输入数据W 是权重矩阵b 是偏置向量Y 是输出结果。这个过程可以看作是对输入数据进行线性变换的操作。

1.官网解释

官网访问:LINEAR
如下图所示
在这里插入图片描述
在这里插入图片描述
由此可见,每一层的某个神经元的值都为前一层所有神经元的值的总和。

2.nn.Linear函数参数介绍

torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)

其中最重要的三个参数为in_features, out_features, bias

in_features, 表示输入的特征值大小,即输入的神经元个数
out_features,表示输出的特征值大小,即经过线性变换后输出的神经元个数
bias,表示是否添加偏置

二、实战演示

在这里插入图片描述
预定要的in_features为1,1,x形式
out_features为1,1,y的形式

1.将CIFAR10图片数据集进行线性变换

代码如下:

import torch
import torchvision
from torch.utils.data import DataLoader# 准备数据
test_set = torchvision.datasets.CIFAR10("dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)
# 加载数据集
dataloader = DataLoader(test_set,batch_size=64)# 查看输入的通道数
# for data in dataloader:
#     imgs, target = data
#     print(imgs.shape)  # torch.Size([64, 3, 32, 32])
#     # 将img进行reshape成1,1,x的形式
#     input = torch.reshape(imgs,(1,1,1,-1)) # 每次一张图,1通道,1*自动计算x
#     print(input.shape) # torch.Size([1, 1, 1, 196608])# 搭建神经网络,设置预定的输出特征值为10
class Lgl(torch.nn.Module):def __init__(self):super(Lgl, self).__init__()self.linear1 = torch.nn.Linear(196608,10)  # 输入数据的特征值196608,输出特征值10def forward(self, input):output = self.linear1(input)return output
# 实例化
l = Lgl()
# 进行线性操作for data in dataloader:imgs, target = dataprint(imgs.shape)  # torch.Size([64, 3, 32, 32])# 将img进行reshape成1,1,x的形式input = torch.reshape(imgs,(1,1,1,-1)) # 每次一张图,1通道,1*自动计算xoutput = l(input)print(output.shape) # torch.Size([1, 1, 1, 10])
原先的图片shape:torch.Size([64, 3, 32, 32])
reshape后的图片shape:torch.Size([1, 1, 1, 196608])
经过线性后的图片shape:torch.Size([1, 1, 1, 10])
原先的图片shape:torch.Size([64, 3, 32, 32])
reshape后的图片shape:torch.Size([1, 1, 1, 196608])
经过线性后的图片shape:torch.Size([1, 1, 1, 10])
……

除了使用reshape后,还可以使用torch.flatten()进行修改尺寸,将其自动修改为一维。
torch.flatten(input, start_dim=0, end_dim=- 1)
将输入tensor的第start_dim维到end_dim维之间的数据“拉平”成一维tensor

修改成flatten后代码如下

import torch
import torchvision
from torch.utils.data import DataLoader# 准备数据
test_set = torchvision.datasets.CIFAR10("dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)
# 加载数据集
dataloader = DataLoader(test_set,batch_size=64)# 查看输入的通道数
# for data in dataloader:
#     imgs, target = data
#     print(imgs.shape)  # torch.Size([64, 3, 32, 32])
#     # 将img进行reshape成1,1,x的形式
#     input = torch.reshape(imgs,(1,1,1,-1)) # 每次一张图,1通道,1*自动计算x
#     print(input.shape) # torch.Size([1, 1, 1, 196608])# 搭建神经网络,设置预定的输出特征值为10
class Lgl(torch.nn.Module):def __init__(self):super(Lgl, self).__init__()self.linear1 = torch.nn.Linear(196608,10)  # 输入数据的特征值196608,输出特征值10def forward(self, input):output = self.linear1(input)return output
# 实例化
l = Lgl()
# 进行线性操作for data in dataloader:imgs, target = dataprint(f"原先的图片shape:{imgs.shape}")  # torch.Size([64, 3, 32, 32])# 将img进行reshape成1,1,x的形式input = torch.flatten(imgs) # 每次一张图,1通道,1*自动计算xprint(f"flatten后的图片shape:{input.shape}")output = l(input)print(f"经过线性后的图片shape:{output.shape}") # torch.Size([1, 1, 1, 10])
原先的图片shape:torch.Size([64, 3, 32, 32])
flatten后的图片shape:torch.Size([196608])
经过线性后的图片shape:torch.Size([10])
原先的图片shape:torch.Size([64, 3, 32, 32])
flatten后的图片shape:torch.Size([196608])
经过线性后的图片shape:torch.Size([10])
……

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/25507.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL索引3——Explain关键字和索引使用规则(SQL提示、索引失效、最左前缀法则)

目录 Explain关键字 索引性能分析 Id ——select的查询序列号 Select_type——select查询的类型 Table——表名称 Type——select的连接类型 Possible_key ——显示可能应用在这张表的索引 Key——实际用到的索引 Key_len——实际索引使用到的字节数 Ref ——索引命…

【Linux】五、进程

一、冯诺依曼体系结构 存储器:指的是内存; 输入设备:键盘、摄像头、话筒,磁盘,网卡; 输出设备:显示器、音响、磁盘、网卡; 中央处理器(CPU):运算器…

【开源项目--稻草】Day04

【开源项目--稻草】Day04 1. 续 VUE1.1 完善VUEAJAX完成注册功能 Spring验证框架什么是Spring验证框架使用Spring-Validation 稻草问答-学生首页显示首页制作首页的流程开发标签列表标签列表显示原理 从业务逻辑层开始编写控制层代码开发问题列表开发业务逻辑层开发页面和JS代码…

HTML5 Canvas(画布)

<canvas>标签定义图形&#xff0c;比如图表和其他图像&#xff0c;你必须用脚本来绘制图形。 在画布上&#xff08; Canvas &#xff09;画一个共红色矩形&#xff0c;渐变矩形&#xff0c;彩色矩形&#xff0c;和一些彩色文字。 什么是 Canvas&#xff1f; HTML5<c…

机器学习深度学习——序列模型(NLP启动!)

&#x1f468;‍&#x1f393;作者简介&#xff1a;一位即将上大四&#xff0c;正专攻机器学习的保研er &#x1f30c;上期文章&#xff1a;机器学习&&深度学习——卷积神经网络&#xff08;LeNet&#xff09; &#x1f4da;订阅专栏&#xff1a;机器学习&&深度…

vue3+vite项目配置ESlint、pritter插件

配置ESlint、pritter插件 在 Vue 3 Vite 项目中&#xff0c;你可以通过以下步骤配置 ESLint 和 Prettier 插件&#xff1a; 安装插件&#xff1a; 在项目根目录下&#xff0c;打开终端并执行以下命令安装 ESLint 和 Prettier 插件&#xff1a; npm install eslint prettier e…

Mr. Cappuccino的第55杯咖啡——Mybatis一级缓存二级缓存

Mybatis一级缓存&二级缓存 概述一级缓存特点演示前准备效果演示在同一个SqlSession中在不同的SqlSession中 源代码怎么禁止使用一级缓存一级缓存在什么情况下会被清除 二级缓存特点演示前准备效果演示在不同的SqlSession中 源代码怎么关闭二级缓存 一级缓存&#xff08;Spr…

ubuntu20.4 sgx环境配置

一、driver安装 1.在该下载地址将3个.bin文件下载下来&#xff0c;下载地址&#xff1a;https://download.01.org/intel-sgx/latest/linux-latest/distro/ubuntu20.04-server/ 2.到下载文件夹下输入下面命令&#xff0c;以赋予.bin文件的执行权限 sudo chmod 777 sgx_linux_x64…

HTTP 常用状态码 301 302 304 403

HTTP 常用状态码 301 302 304 403 301 永久重定向&#xff0c;浏览器会把重定向后的地址缓存起来&#xff0c;将来用户再次访问原始地址时&#xff0c;直接引导用户访问新地址 302 临时重定向&#xff0c;浏览器会引导用户进入新地址&#xff0c;但不会缓存原始地址&#xff0c…

Python模块—Pytest模块

文章目录 PyTest1. args参数2. pytest-ordering3. fixture&#xff08;前置函数&#xff09;4. parametrize&#xff08;参数化&#xff09;5. fixture 与 parametrize 结合6. pyyaml&#xff08;数据源&#xff09;7. pytest-xdist&#xff08;分布式测试&#xff09;8. allur…

LA@行列式性质

文章目录 行列式性质&#x1f388;转置不变性质交换性质多重交换移动(抽出插入)&#x1f47a; 因子提取性质拆和性质倍加性质 手算行列式的主要方法原理:任何行列式都可以化为三角行列式 行列式性质&#x1f388; 设行列式 ∣ A ∣ d e t ( a i j ) |A|\mathrm{det}(a_{ij}) …

vue 关于axios的使用方法

axios定义&#xff1a; axios 前端 ajax请求工具 1. 在浏览器与nodejs可以使用 2. 可以拦截请求与相应 3. 扩展与封装自定义方法 4. 不依赖dom节点 安装 npm i axios -S 先在vue全局中挂载 import axios from ‘axios’ Vue.prototype.$h…

Docker 安装 Tomcat

目录 一、查看 tomcat 版本 二、拉取 Tomcat Docker 镜像 三、创建 Tomcat 容器 四、访问 Tomcat 五、停止和启动容器 一、查看 tomcat 版本 访问 tomcat 镜像库地址&#xff1a;https://hub.docker.com/_/tomcat&#xff0c;可以通过 Tags 查看其他版本的 tomcat; 二、拉…

Elasticsearch8.8.0 SpringBoot实战操作各种案例(索引操作、聚合、复杂查询、嵌套等)

Elasticsearch8.8.0 全网最新版教程 从入门到精通 通俗易懂 配置项目 引入依赖 <dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId><version>5.8.16</version></dependency><dependency>&l…

Android Studio 的Gradle版本修改

使用Android Studio构建项目时&#xff0c;需要配置Gradle&#xff0c;与Gradle插件。 Gradle是一个构建工具&#xff0c;用于管理和自动化Android项目的构建过程。它使用Groovy或Kotlin作为脚本语言&#xff0c;并提供了强大的配置能力来定义项目的依赖关系、编译选项、打包方…

Jtti:linux如何配置dns域名解析服务器

要配置Linux上的DNS域名解析服务器&#xff0c;您可以按照以下步骤进行操作&#xff1a; 1. 安装BIND软件包&#xff1a;BIND是Linux上最常用的DNS服务器软件&#xff0c;您可以使用以下命令安装它&#xff1a; sudo apt-get install bind9 2. 配置BIND&#xff1a;BIND的配置…

Spring Cloud常见问题处理和代码分析

目录 1. 问题&#xff1a;如何在 Spring Cloud 中实现服务注册和发现&#xff1f;2. 问题&#xff1a;如何在 Spring Cloud 中实现分布式配置&#xff1f;3. 问题&#xff1a;如何在 Spring Cloud 中实现服务间的调用&#xff1f;4. 问题&#xff1a;如何在 Spring Cloud 中实现…

HCIA---OSI/RM--开放式系统互联参考模型

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、pandas是什么&#xff1f;二、使用步骤 1.引入库2.读入数据总结 一.OSI--开放式系统互联参考模型简介 OSI开放式系统互联参考模型是一种用于计算机网络通信…

解密Redis:应对面试中的缓存相关问题2

面试官&#xff1a;Redis集群有哪些方案&#xff0c;知道嘛&#xff1f; 候选人&#xff1a;嗯~~&#xff0c;在Redis中提供的集群方案总共有三种&#xff1a;主从复制、哨兵模式、Redis分片集群。 面试官&#xff1a;那你来介绍一下主从同步。 候选人&#xff1a;嗯&#xff…

基于WebRTC升级的低延时直播

快直播-基于WebRTC升级的低延时直播-腾讯云开发者社区-腾讯云 标准WebRTC支持的音视频编码格式已经无法满足国内直播行业需求。标准WebRTC支持的视频编码格式是VP8/VP9和H.264&#xff0c;音频编码格式是Opus&#xff0c;而国内推流的音视频格式基本上是H.264/H.265AAC的形式。…