最简单知识点PyTorch中的nn.Linear(1, 1)

一、nn.Linear(1, 1)

nn.Linear(1, 1) 是 PyTorch 中的一个线性层(全连接层)的定义。

nn 是 PyTorch 的神经网络模块(torch.nn)的常用缩写。

nn.Linear(1, 1) 的含义如下:

  • 第一个参数 1:输入特征的数量。这表示该层接受一个长度为 1 的向量作为输入
  • 第二个参数 1:输出特征的数量。这表示该层产生一个长度为 1 的向量作为输出

因此,nn.Linear(1, 1) 定义了一个简单的线性变换,其数学形式为:y=x⋅w+b
其中:

  • x 是输入向量(长度为 1)。
  • w 是权重(也是一个长度为 1 的向量)。
  • b 是偏置项(一个标量)。
  • y 是输出向量(长度为 1)。

在实际应用中,这样的线性层可能不常用,因为对于从长度为 1 的输入到长度为 1 的输出的映射,这实际上就是一个简单的线性变换,但在某些特定场景或作为更复杂模型的一部分时,它仍然可能是有用的。

二、简单举例

假设我们有一个简单的任务,需要预测一个线性关系,比如根据给定的输入值 x 来预测输出值 y,其中 y 是 x 的线性变换。在这种情况下,nn.Linear(1, 1) 可以用来表示这个线性关系。

以下是一个使用 PyTorch 和 nn.Linear(1, 1) 的简单例子:

import torch
import torch.nn as nn
import torch.optim as optim
# 定义模型
model = nn.Linear(1, 1)
# 定义损失函数和优化器
criterion = nn.MSELoss() # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.01) # 随机梯度下降优化器
# 假设我们有一些简单的线性数据
x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]], dtype=torch.float32)
y_train = torch.tensor([[2.0], [4.0], [6.0], [8.0]], dtype=torch.float32) # 假设 y = 2 * x
# 训练模型
for epoch in range(100): # 假设我们训练 100 个 epoch
        # 前向传播
        outputs = model(x_train)
        loss = criterion(outputs, y_train)
        # 反向传播和优化
        optimizer.zero_grad() # 清除梯度
        loss.backward() # 反向传播计算梯度
        optimizer.step() # 应用梯度更新权重
        # 打印损失值(可选)
        if (epoch+1) % 10 == 0:
                print(f'Epoch [{epoch+1}/100], Loss: {loss.item():.4f}')
# 测试模型
with torch.no_grad(): # 不需要计算梯度
        x_test = torch.tensor([[5.0]], dtype=torch.float32)
        y_pred = model(x_test)
        print(f'Predicted output for x=5: {y_pred.item()}')

运行截图:

图1 上述代码运行输出

在这个例子中,我们创建了一个简单的线性模型 nn.Linear(1, 1) 来学习输入 x 和输出 y 之间的线性关系。我们使用均方误差损失函数 nn.MSELoss() 随机梯度下降优化器 optim.SGD() 来训练模型。通过多次迭代(epoch),模型逐渐学习权重和偏置项(w, b)以最小化预测值与实际值之间的误差。最后,我们使用训练好的模型对新的输入值 x=5 进行预测,并打印出预测结果。

三、举一反三——nn.Linear(2, 1) 

nn.Linear(2, 1) 是PyTorch深度学习框架中用于定义一个线性层的语句。在深度学习中,线性层(也被称为全连接层或密集层)是一种非常基础的神经网络层,用于执行线性变换。

含义

nn.Linear(2, 1) 表示一个线性层,它接收一个具有2个特征的输入,并输出一个具有1个特征的结果。具体来说:

  • 第一个参数 2 表示输入特征的数量,即该层期望的输入维度是2。
  • 第二个参数 1 表示输出特征的数量,即该层输出的维度是1。

作用

这个线性层的作用是对输入的2个特征进行线性组合,然后输出一个单一的数值。数学上,这个过程可以表示为:

y = x1 * w1 + x2 * w2 + b

其中:

  • x1 和 x2 是输入特征。
  • w1 和 w2 是权重,它们在训练过程中会被学习。
  • b 是偏置项,也是一个在训练过程中会被学习的参数。
  • y 是该层的输出。

可能的应用场景

nn.Linear(2, 1) 可以应用于多种场景,特别是当需要将两个特征合并为一个单一特征时。以下是一些具体的例子:

  1. 回归问题:在简单的回归问题中,如果你有两个特征并希望预测一个连续的数值输出,你可以使用 nn.Linear(2, 1)。例如,预测房价时,你可能会根据房屋的面积和卧室数量来预测价格。

  2. 特征压缩:在某些情况下,你可能希望将多个特征压缩成一个特征,以便于后续处理或可视化。例如,在降维或特征工程中,nn.Linear(2, 1) 可以用于将两个特征转换为一个新的综合特征。

  3. 神经网络的一部分:在构建更复杂的神经网络时,nn.Linear(2, 1) 可以作为神经网络的一部分。例如,在多层感知机(MLP)中,这样的层可以与其他层(如激活层、dropout层等)结合使用,以构建能够处理复杂任务的模型。

需要注意的是,虽然 nn.Linear(2, 1) 本身只能执行线性变换,但在实际使用时,通常会与其他非线性层(如ReLU或sigmoid激活函数)结合使用,以构建能够学习非线性关系的模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/802081.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【人工智能】AI赋能城市交通 未来城市的驱动力

前言 随着城市化进程的不断加速,交通拥堵、环境污染等问题日益凸显,人们对交通系统的效率和可持续性提出了更高的要求。在这样的背景下,智能交通技术正成为改善城市交通的重要驱动力。本文将探讨智能交通技术在解决城市交通挑战方面的应用和未…

谷歌留痕霸屏要怎么做?

谷歌留痕霸屏,就是让你的网站或者页面在谷歌搜索结果里尽可能多地出现,就像是在你的潜在客户眼前留下深刻印象一样,你要做的就是在一些高权重平台发布有价值的信息,同时巧妙地留下你的品牌名、产品名或者任何你想要推广的关键词&a…

css实现各级标题自动编号

本文在博客同步发布,您也可以在这里看到最新的文章 Markdown编辑器大多不会提供分级标题的自动编号功能,但我们可以通过简单的css样式设置实现。 本文介绍了使用css实现各级标题自动编号的方法,本方法同样适用于typora编辑器和wordpress主题…

六角螺母缺陷分类数据集:3440张图像

六角螺母缺陷数据集:包含变形,划痕,断裂,生锈,以及优质螺母图片数据,共计3440张,无标注 一.变形螺母-1839 二.断裂螺母-287 三.划痕螺母-473 四.生锈螺母-529 五.优良螺母-312 适用于CV项目&am…

Flutter之Flex组件布局

目录 Flex属性值 轴向:direction:Axis.horizontal 主轴方向:mainAxisAlignment:MainAxisAlignment.center 交叉轴方向:crossAxisAlignment:CrossAxisAlignment 主轴尺寸:mainAxisSize 文字方向:textDirection:TextDirection 竖直方向排序:verticalDirection:VerticalDir…

灵猫论文好用吗 #媒体#笔记

灵猫论文是一款专门用于论文写作、查重降重的工具,它的使用方便、高效,深受广大论文作者的喜爱。那么,灵猫论文到底好用吗?答案是肯定的! 首先,灵猫论文提供了强大的查重降重功能,能够帮助用户快…

MySQL8.0新特性详解及全局优化

文章目录 一、前言二、开窗函数三、新增函数索引四、group by不再隐式排序五、新增降序索引六、binlog日志文件过期时间精确到秒七、undo文件不再使用系统表空间八、默认字符集由latin1变为utf8mb4九、自增变量持久化十、删除了.frm等文件 一、前言 目前MySQL8.0及以上版本在我…

Commitizen:规范化你的 Git 提交信息

简介 在团队协作开发过程中,规范化的 Git 提交信息可以提高代码维护的效率,便于追踪和定位问题。Commitizen 是一个帮助我们规范化 Git 提交信息的工具,它提供了一种交互式的方式来生成符合约定格式的提交信息。 原理 Commitizen 的核心原…

Vue3跟Vue2比,性能真的有所提升吗?

答案是肯定的。 说起Vue3的改进,很多人都会说出响应式的改变,与Vue2相比,Vue3采用了proxy的方式对响应式做了重写,而Vue2则是采用defineProperty的方式将对象的属性进行深度遍历,而这种方式想要实现响应式的前与后&am…

每日学习笔记:C++ STL算法之容器元素复制与搬移

本文API 复制元素: copy() copy_if(....,op) copy_n() copy_backward() 搬移元素: move() move_backward() 复制元素 搬移元素

SQL注入利用学习-Union联合注入

联合注入的原理 在SQL语句中查询数据时,使用select 相关语句与where 条件子句筛选符合条件的记录。 select * from person where id 1; #在person表中,筛选出id1的记录如果该id1 中的1 是用户可以控制输入的部分时,就有可能存在SQL注入漏洞…

Python爬虫与API交互:如何爬取并解析JSON数据

目录 前言 一、什么是API和JSON数据 二、准备环境 三、发送API请求并获取数据 四、解析JSON数据 五、完整代码示例 六、总结 前言 随着互联网的发展,越来越多的网站提供了API接口,供开发者获取实时数据。在爬虫领域中,与API交互并解析…

Pytorch中nn.Linear使用方法

nn.Linear定义一个神经网络的线性层: torch.nn.Linear(in_features, # 输入的神经元个数out_features, # 输出神经元个数biasTrue # 是否包含偏置)nn.Linear其实就是对输入(n表示样本数量,i表示样本特…

【数据结构与算法】力扣 142. 环形链表 II

题目描述 给定一个链表的头节点 head ,返回链表开始入环的第一个节点。 如果链表无环,则返回 null。 如果链表中有某个节点,可以通过连续跟踪 next 指针再次到达,则链表中存在环。 为了表示给定链表中的环,评测系统…

华为海思校园招聘-芯片-数字 IC 方向 题目分享——第二套

华为海思校园招聘-芯片-数字 IC 方向 题目分享(有参考答案)——第二套(共九套,每套四十个选择题) 部分题目分享,完整版获取(WX:didadidadidida313,加我备注:CSDN huawei…

Git-LFS 远程命令执行漏洞 CVE-2020-27955 漏洞复现

今天遇到了一个比较有意思的洞,复现一下下.......... 漏洞描述 Git LFS 是 Github 开发的一个 Git 的扩展,用于实现 Git 对大文件的支持 一些受影响的产品包括Git,GitHub CLI,GitHub Desktop,Visual Studio&#xff0…

51单片机之自己配串口寄存器实现波特率9600

本配置是根据手册进行开发配置的 1、首先配置SCON 所以综上所诉 SCON 0x40 (0100 0000) 2、PCON不用配置 3、配置定时器1 4、波特率的计算 5、配置AUXR 6、对比 7、实现 8、优化(实现字符串) 引入TI (智能延时&…

对于嵌入式工程师,需要掌握的知识是广还是精?

我刚开始接触嵌入式的时候,感觉学这个好变态啊。 要学的东西太多了,数字电路、模拟电路、C语言、汇编、51单片机、Protel 99SE、Pcb Layout、STM32单片机、RTOS、Linux、ARM等等.... 可以说,随便拿个魔法电路出来,想达到精的程度&…

【C++】C++11可变参数模板

👀樊梓慕:个人主页 🎥个人专栏:《C语言》《数据结构》《蓝桥杯试题》《LeetCode刷题笔记》《实训项目》《C》《Linux》《算法》 🌝每一个不曾起舞的日子,都是对生命的辜负 目录 前言 可变参数模板的定义…

Java绘图坐标体系

一、介绍 下图说明了Java坐标系。坐标原点位于左上角,以像素为单位。在Java坐标系中,第一个是x坐标,表示当前位置为水平方向,距离坐标原点x个像素;第二个是y坐标,表示当前位置为垂直方向,距离坐…