Pytorch中nn.Linear使用方法

nn.Linear定义一个神经网络的线性层:

torch.nn.Linear(in_features,             # 输入的神经元个数out_features,            # 输出神经元个数bias=True                # 是否包含偏置)

nn.Linear其实就是对输入x_{n\times i}(n表示样本数量,i表示样本特征数)执行了一个线性变换,即:

Y_{n\times o } = X_{n\times i}W_{i\times o} + b

其中W矩阵是模型要学习的参数,b是1*O的向量偏置(即1行O列),n表示输入向量的个数(也可以理解为行数,比如一次输入100个样本数据,则n=100),i为每个样本的特征数,也可以理解为神经元的个数,O为输出样本的特征数,即输出神经元的个数。

from torch import nn
import torchmodel = nn.Linear(3, 1)           # 每个样本输入特征数设置为3,输出特征数设置为1input = torch.Tensor([2, 4, 6])   # 给一个样本,该样本有3个特征,这3个特征分别是2、4、6
output = model(input)print("nn.Linear 输出大小:{}".format(output.shape))
print(output)
print("")print("查看模型参数W和b的值")
# 查看模型参数
for param in model.parameters():print(param)输出:
nn.Linear 输出大小:torch.Size([1])    #输出结果表示只有一个样本输出,且该样本只有一个特征值1
tensor([-0.7842], grad_fn=<AddBackward0>)查看模型参数W和b的值
Parameter containing:
tensor([[ 0.2353, -0.5686,  0.1759]], requires_grad=True)
Parameter containing:
tensor([-0.0356], requires_grad=True)

可以看到,模型有4个参数,分别为W的三个权重和b的一个偏置。手动计算验证结果:

0.2353*2 + (-0.5686)*4 + 0.1759*6 + (-0.0356) = -0.7839999999999997

假设有5个输入样本A、B、C、D、E(即batch_size为5),每个样本的特征数量为3,定义线性层时,输入特征为3,所以in_feature=3,想让下一层的神经元个数为5,所以out_feature=5,则模型参数为:

model = nn.Linear(in_features=3, out_features=5, bias=True)

此时参数矩阵W大小为3行3列

from torch import nn
import torchmodel = nn.Linear(3, 5)           # 每个样本输入特征数设置为3,输出特征数设置为1input = torch.Tensor([[2, 4, 6],[8,10,12],[14,16,18],[20,22,24],[26,28,30]])   # 给一个样本,该样本有3个特征,这3个特征分别是2、4、6print(input)output = model(input)print("nn.Linear 输出大小:{}".format(output.shape))
print(output)
print("")print("查看模型参数W和b的值")
# 查看模型参数
for param in model.parameters():print(param)输出:
tensor([[ 2.,  4.,  6.],[ 8., 10., 12.],[14., 16., 18.],[20., 22., 24.],[26., 28., 30.]])
nn.Linear 输出大小:torch.Size([5, 5])
tensor([[ -0.9616,  -0.9744,   2.6266,  -0.5605,  -4.2236],[ -1.7251,  -4.4417,   5.9969,  -1.3649, -11.0200],[ -2.4886,  -7.9090,   9.3673,  -2.1692, -17.8163],[ -3.2522, -11.3763,  12.7376,  -2.9736, -24.6127],[ -4.0157, -14.8436,  16.1079,  -3.7779, -31.4090]],grad_fn=<AddmmBackward>)查看模型参数W和b的值
Parameter containing:
tensor([[ 0.0714,  0.1456, -0.3443],[-0.5098, -0.0893,  0.0211],[ 0.3489, -0.2682,  0.4811],[ 0.0768, -0.3863,  0.1755],[-0.2832, -0.4325, -0.4170]], requires_grad=True)
Parameter containing:
tensor([ 0.3789,  0.2753,  0.1153, -0.2216,  0.5748], requires_grad=True)

第一个样本特征为[2、4、6],输出为[ -0.9616,  -0.9744,   2.6266,  -0.5605,  -4.2236],验证过程如下:

%w是模型参数矩阵
w = [[ 0.0714,  0.1456, -0.3443],[-0.5098, -0.0893,  0.0211],[ 0.3489, -0.2682,  0.4811],[ 0.0768, -0.3863,  0.1755],[-0.2832, -0.4325, -0.4170]];
x = [2,4,6];
b = [0.3789,  0.2753,  0.1153, -0.2216,  0.5748];   %偏置向量
x*w'+b输出:-0.9617   -0.9749    2.6269   -0.5602   -4.2236

第2个样本验证:

w = [[ 0.0714,  0.1456, -0.3443],[-0.5098, -0.0893,  0.0211],[ 0.3489, -0.2682,  0.4811],[ 0.0768, -0.3863,  0.1755],[-0.2832, -0.4325, -0.4170]];
x = [8,10,12];
b = [0.3789,  0.2753,  0.1153, -0.2216,  0.5748];
x*w'+b输出:
-1.7255   -4.4429    5.9977   -1.3642  -11.0198

第3、4、5个样本的验证过程类似,从以上验证可以看出,所有样本共享参数矩阵W和偏置b

因为有5个样本,所以相当于依次进行了5次以上操作。

该操作重复了5次,每个样本重复一次:Y_{1\times 5}=X_{1\times 3}W_{3\times 5} + b_{1\times 5}

然后再将5个Y _{1 \times 5}叠加在一起,得到5*5的输出
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/802062.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【数据结构与算法】力扣 142. 环形链表 II

题目描述 给定一个链表的头节点 head &#xff0c;返回链表开始入环的第一个节点。 如果链表无环&#xff0c;则返回 null。 如果链表中有某个节点&#xff0c;可以通过连续跟踪 next 指针再次到达&#xff0c;则链表中存在环。 为了表示给定链表中的环&#xff0c;评测系统…

华为海思校园招聘-芯片-数字 IC 方向 题目分享——第二套

华为海思校园招聘-芯片-数字 IC 方向 题目分享&#xff08;有参考答案&#xff09;——第二套&#xff08;共九套&#xff0c;每套四十个选择题&#xff09; 部分题目分享&#xff0c;完整版获取&#xff08;WX:didadidadidida313&#xff0c;加我备注&#xff1a;CSDN huawei…

Git-LFS 远程命令执行漏洞 CVE-2020-27955 漏洞复现

今天遇到了一个比较有意思的洞&#xff0c;复现一下下.......... 漏洞描述 Git LFS 是 Github 开发的一个 Git 的扩展&#xff0c;用于实现 Git 对大文件的支持 一些受影响的产品包括Git&#xff0c;GitHub CLI&#xff0c;GitHub Desktop&#xff0c;Visual Studio&#xff0…

51单片机之自己配串口寄存器实现波特率9600

本配置是根据手册进行开发配置的 1、首先配置SCON 所以综上所诉 SCON 0x40 &#xff08;0100 0000&#xff09; 2、PCON不用配置 3、配置定时器1 4、波特率的计算 5、配置AUXR 6、对比 7、实现 8、优化&#xff08;实现字符串&#xff09; 引入TI &#xff08;智能延时&…

对于嵌入式工程师,需要掌握的知识是广还是精?

我刚开始接触嵌入式的时候&#xff0c;感觉学这个好变态啊。 要学的东西太多了&#xff0c;数字电路、模拟电路、C语言、汇编、51单片机、Protel 99SE、Pcb Layout、STM32单片机、RTOS、Linux、ARM等等.... 可以说&#xff0c;随便拿个魔法电路出来&#xff0c;想达到精的程度&…

【C++】C++11可变参数模板

&#x1f440;樊梓慕&#xff1a;个人主页 &#x1f3a5;个人专栏&#xff1a;《C语言》《数据结构》《蓝桥杯试题》《LeetCode刷题笔记》《实训项目》《C》《Linux》《算法》 &#x1f31d;每一个不曾起舞的日子&#xff0c;都是对生命的辜负 目录 前言 可变参数模板的定义…

Java绘图坐标体系

一、介绍 下图说明了Java坐标系。坐标原点位于左上角&#xff0c;以像素为单位。在Java坐标系中&#xff0c;第一个是x坐标&#xff0c;表示当前位置为水平方向&#xff0c;距离坐标原点x个像素&#xff1b;第二个是y坐标&#xff0c;表示当前位置为垂直方向&#xff0c;距离坐…

LLM大语言模型(九):LangChain封装自定义的LLM

背景 想基于ChatGLM3-6B用LangChain做LLM应用&#xff0c;需要先了解下LangChain中对LLM的封装。本文以一个hello world的封装来示例。 LangChain中对LLM的封装 继承关系&#xff1a;BaseLanguageModel——》BaseLLM——》LLM LLM类 简化和LLM的交互 _call抽象方法定义 ab…

操作系统理论知识快速总览

操作系统整体架构 搬出考研时的思维导图 操作系统主要分为 批处理系统(老古董&#xff0c;基本不用了)实时操作系统(嵌入式中使用较多&#xff0c;RTOS)分时操作系统(PC中使用较多&#xff0c;Linux&#xff0c;Windows) 分时操作系统和实时操作系统的使用场景不同&#xf…

【蓝桥杯第十二届省赛B】(部分详解)

空间 8位1b 1kb1024b(2^10) 1mb1024kb(2^20) 时间显示 #include <iostream> using LLlong long; using namespace std; int main() {LL t;cin>>t;int HH,MM,SS;t/1000;SSt%60;//like370000ms370s,最后360转成分余下10st/60;MMt%60;t/60;HHt%24;printf("%02d:…

[C语言]——动态内存管理

目录 一.为什么要有动态内存分配 二.malloc和free 1.malloc 2.free 三.calloc和realloc 1.calloc 2.realloc 3.空间的释放​编辑 四.常见的动态内存的错误 1.对NULL指针的解引用操作 2.对动态开辟空间的越界访问 3.对非动态开辟内存使用free释放 4.使用free释放⼀块…

外汇110:谷歌起诉应用程序开发商伪造加密投资APP诈骗!

谷歌&#xff08;Google&#xff09;已对两家应用程序开发商提起诉讼&#xff0c;指控其参与“国际在线消费者投资欺诈计划”。该计划欺骗用户从 Google Play 商店和其他渠道下载虚假的安卓&#xff08;Android&#xff09;应用程序&#xff0c;并以承诺更高回报为幌子窃取他们…

SinoDB用户权限

SinoDB用户权限是由数据库对象和操作类型两个要素组成的&#xff0c;定义一个用户的权限就是定义这个用户可以对哪些数据对象进行哪些类型的操作。 SinoDB使用了三级权限来保证数据的安全性&#xff0c;它们分别是数据库级权限&#xff0c;表级权限和字段级权限。 1. 数据库级…

备考ICA----Istio实验17---TCP流量授权

备考ICA----Istio实验17—TCP流量授权 1. 环境准备 1.1 环境部署 kubectl apply -f <(istioctl kube-inject -f istio/samples/tcp-echo/tcp-echo.yaml) -n kim kubectl apply -f <(istioctl kube-inject -f istio/samples/sleep/sleep.yaml) -n kim1.2 测试环境 检测…

LangChain-14 Moderation OpenAI提供的功能:检测内容中是否有违反条例的内容

背景描述 我们在调用OpenAI的接口时&#xff0c;有些内容可能是违反条例的&#xff0c;所以官方提供了一个工具来检测。 安装依赖 pip install --upgrade --quiet langchain-core langchain langchain-openai编写代码 下文中我们使用了: OpenAIModerationChain 这个工具来…

PHP运算符与流程控制

华子目录 运算符赋值运算符算术运算符比较运算符逻辑运算符连接运算符错误抑制符三目运算符自操作运算符 计算机码位运算符 运算符优先级流程控制控制分类顺序结构分支结构if分支switch分支 循环结构for循环while循环continuebreak 运算符 运算符&#xff1a;operator&#xf…

JNA、JNI、原生C++函数调用效率及测试过程

结论 如果JAVA要高效调用C函数&#xff0c;则需要通过JNI封装C函数后进行native方法调用&#xff0c;JNI的执行效率比JNA高600倍左右。从开发效率上来说&#xff0c;JNA开发速度比JNI快许多&#xff0c;因为不需要做二次封装 测试对比 纯C调用&#xff1a; Function call to…

深入了解iOS内存(WWDC 2018)笔记-内存诊断

主要记录下用于分析iOS/macOS 内存问题的笔记。 主要分析命令&#xff1a; vmmap, leaks, malloc_history 一&#xff1a;前言 有 3 种思考方式 你想看到对象的创建吗&#xff1f;你想要查看内存中引用对象或地址的内容吗&#xff1f;或者你只是想看看 一个实例有多大&#…

【强化学习】Actor-Critic

Actor-Critic算法 欢迎访问Blog全部目录&#xff01; 文章目录 Actor-Critic算法1.Actor-Critic原理1.1.简述1.1.优劣势1.3.策略网络和价值网络1.3.1.策略网络&#xff08;Actor)1.3.2.价值网络&#xff08;Critic) 1.4.程序框图和伪代码 2.算法案例&#xff1a;Pendulum-v12…

T-Mamba:用于牙齿 3D CBCT 分割的频率增强门控长程依赖性

T-Mamba&#xff1a;用于牙齿 3D CBCT 分割的频率增强门控长程依赖性 摘要Introduction方法T-Mamba architectureTim block T-Mamba: Frequency-Enhanced Gated Long-Range Dependendcy for Tooth 3D CBCT Segmentation 摘要 三维成像中的高效牙齿分割对于正畸诊断至关重要&am…