weight-tying探索

在一些领域,将嵌入层和输出层的权重绑定,以达到减少参数量并使得相同token保持统一的embedding空间的作用。

下面的nn.Linear(3, 10)的权重矩阵的尺寸是10*3,即y = W @ x + b,因此跟nn.Embedding(10, 3)的权重矩阵大小相等。

import torch
import torch.nn as nn
import torch.nn.functional as Fclass Model_1(nn.Module):def __init__(self):super(Model_1, self).__init__()self.embedding = nn.Embedding(10, 3)self.head = nn.Linear(3, 10)# self.embedding.weight = self.head.weightdef forward(self, x):output = self.embedding(x)output = self.head(output)return F.softmax(output, dim=-1)    class Model_2(nn.Module):def __init__(self):super(Model_2, self).__init__()self.embedding = nn.Embedding(10, 3)self.head = nn.Linear(3, 10)# 使用下面这行代码,二者权重会同步更新self.embedding.weight = self.head.weightdef forward(self, x):output = self.embedding(x)output = self.head(output)return F.softmax(output, dim=-1)model_1 = Model_1()
model_2 = Model_2()torch.manual_seed(0)
input_indexes = torch.randint(0, 10, (2, 3))
target = torch.zeros(2, 3, 10)
for i in range(2):for j in range(3):target[i, j, input_indexes[i, j]] = 1
print(target)# criterion = nn.CrossEntropyLoss()
criterion = nn.MSELoss()
optimizer_1 = torch.optim.Adam(model_1.parameters(), lr=0.001)
optimizer_2 = torch.optim.Adam(model_2.parameters(), lr=0.001)
loss_tying = []
loss_no_tying = []for _ in range(2000):output_1 = model_1(input_indexes)loss = criterion(output_1, target)optimizer_1.zero_grad()loss.backward()optimizer_1.step()loss_no_tying.append(loss.item())output_2 = model_2(input_indexes)loss = criterion(output_2, target)optimizer_2.zero_grad()loss.backward()optimizer_2.step()loss_tying.append(loss.item())# print(output)
print(model_1.embedding.weight==model_1.head.weight)
print(model_2.embedding.weight==model_2.head.weight)
import matplotlib.pyplot as plt
plt.plot(loss_tying, label="use weight tying")
plt.plot(loss_no_tying, label="not use weight tying")
plt.legend()
plt.show()

在这里插入图片描述
可以看到,在这个例子中,使用 weight-tying 后 loss 收敛更快。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/799884.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

语音特征的反应——语谱图

语谱图的横坐标为时间,纵坐标为对应时间点的频率。坐标中的每个点用不同颜色表示,颜色越亮表示频率越大,颜色越淡表示频率越小。可以说语谱图是一个在二维平面展示三维信息的图,既能够表示频率信息,又能够表示时间信息。 创建和绘制语谱图的…

卫星遥感监测森林植被健康度

随着地球环境的日益恶化,森林作为地球上最重要的生态系统之一,其变化对全球气候、生态环境和人类社会经济发展产生深远影响。因此,及时、准确地监测森林变化对于保护生态环境、维护生态平衡、推进可持续发展具有重要意义。卫星遥感影像技术因…

若依框架学习——分页查询列表

条件查询【多条件】列表展示【分页】SaCheckPermissionTableName TableId NotBlank Page分页 响应数据封装类

C#速览入门

C# & .NET C# 程序在 .NET 上运行,而 .NET 是名为公共语言运行时 (CLR) 的虚执行系统和一组类库。 CLR 是 Microsoft 对公共语言基础结构 (CLI) 国际标准的实现。 CLI 是创建执行和开发环境的基础,语言和库可以在其中无缝地协同工作。 用 C# 编写的…

基于springboot实现教师人事档案管理系统项目【项目源码+论文说明】

基于springboot实现IT技术交流和分享平台系统演示 摘要 我国科学技术的不断发展,计算机的应用日渐成熟,其强大的功能给人们留下深刻的印象,它已经应用到了人类社会的各个层次的领域,发挥着重要的不可替换的作用。信息管理作为计算…

asm磁盘组无法写入问题-处理中

有个11204的rac环境,没应用补丁,5号突然报归档满,登录环境后发现奇怪,一个1T磁盘建成的DATA磁盘组使用了近800G,读写正常,一个1.5T磁盘建成的FRA磁盘组,目前还剩余729551M,无法写入归…

SAP ABAP ALV转换例程的问题

为关键用户开发了一个ALV报表,因为导出Excel导致 curr性质的字段 例程的 问题 ,使得负号后置,Excel不能直接运算,需要转换你成数值后才可以,经过调试发现是对应的域 的转换例程的问题 FUNCTION CONVERSION_EXIT_AC152_…

雷达学习之多普勒频率

一、多普勒频率如何产生? 雷达的原理是发射一些无线电脉冲来探测目标,并通过回波的延时来计算目标与雷达的距离,但当目标为运动物体时,在回波向目标传输的同时,目标也会远离或接近回波,所以会导致回波信号…

ctfshow web入门 文件包含 web151--web161

web151 打算用bp改文件形式(可能没操作好)我重新试了一下抓不到 文件上传不成功 改网页前端 鼠标右键&#xff08;检查&#xff09;&#xff0c;把png改为php访问&#xff0c;执行命令 我上传的马是<?php eval($_POST[a]);?> 查看 web152 上传马 把Content-Type改为…

【nnUNetv2实践】二、nnUNetv2快速入门-训练验证推理集成一条龙教程

nnUNet是一个自适应的深度学习框架&#xff0c;专为医学图像分割任务设计。以下是关于nnUNet的详细解释和特点&#xff1a; 自适应框架&#xff1a;nnUNet能够根据具体的医学图像分割任务自动调整模型结构、训练参数等&#xff0c;从而避免了繁琐的手工调参过程。 自动化流程&a…

C++流程控制语句:嵌套循环案例分析【九九乘法表】

在C++编程中,循环语句的嵌套是一种常见且强大的技术手段,它允许我们将多个循环结构相互嵌套,形成多维循环。不论是for循环、while循环还是do…while循环,均可以进行嵌套。 而在实践中,由于for循环具有明确的循环变量初始化、条件判断和更新机制,因此在嵌套循环中,for循…

C语言第四十一弹---猜数字游戏

✨个人主页&#xff1a; 熬夜学编程的小林 &#x1f497;系列专栏&#xff1a; 【C语言详解】 【数据结构详解】 猜数字游戏 1、随机数生成 1.1、rand 1.2、srand 1.3、time 1.4、设置随机数的范围 2、猜数字游戏的分析和设计 2.1、猜数字游戏功能说明 2.2、猜数字游戏…

如何用Java后端处理JS.XHR请求

Touching searching engine destroies dream to utilize php in tomcat vector.The brave isn’t knocked down&#xff0c;turn its path to java back-end. Java Servlet Bible schematic of interaction between JS front-end and Java back-end Question 如何利用Java…

[C++][算法基础]最大异或对(Trie树)

在给定的 N 个整数 &#xff0c;...... 中选出两个进行 xor&#xff08;异或&#xff09;运算&#xff0c;得到的结果最大是多少&#xff1f; 输入格式 第一行输入一个整数 N。 第二行输入 N 个整数 ~ 。 输出格式 输出一个整数表示答案。 数据范围 1≤N≤, 0≤< 输…

【数据结构与算法】力扣 19. 删除链表的倒数第 N 个结点

题目描述 给你一个链表&#xff0c;删除链表的倒数第 n 个结点&#xff0c;并且返回链表的头结点。 示例 1&#xff1a; 输入&#xff1a; head [1,2,3,4,5], n 2 输出&#xff1a; [1,2,3,5]示例 2&#xff1a; 输入&#xff1a; head [1], n 1 输出&#xff1a; []示例…

Mamba入局遥感图像分割 | Samba: 首个基于SSM的遥感高分图像语义分割框架

文章目录 1、导读 2、背景 3、动机 4、方法 5、实验 6、总结 标题&#xff1a;《Samba: Semantic Segmentation of Remotely Sensed Images with State Space Model》论文&#xff1a;https://arxiv.org/abs/2404.01705源码&#xff1a;https://github.com/zhuqinfeng1999…

在展会上如何介绍产品和公司,柯桥俄语培训

1.Приглашаем Вас… 邀请您…… 2. Позвольте пригласить Вас… 请允许邀请您…… 3.Имеем честь пригласить Вас … 诚挚邀请您…… 4. Посылаем Вам приглашение на… 给您&#xff0…

Vue - 你知道Vue中key的工作原理吗

难度级别:中级及以上 提问概率:80% 在Vue项目开发中,并不推荐使用索引做为key,以为key必须是唯一的,可以使用服务端下发的唯一ID值,也不推荐使用随机值做为key,因为如果每次渲染都监听到不一样的key,那么节点将无法复用,这与Vue节省…

android gradle版本无法下载

android gradle版本无法下载问题解决方法 在引入一个新的android项目的时候&#xff0c;通常会因为无法下载gradle版本而一直卡在同步界面&#xff0c;类似于下面的情况。 这是因为gradle运行时首先会检查distributionUrlhttps://services.gradle.org/distributions/gradle-5.6…

JavaScript逆向爬虫——无限debugger的原理与绕过

debugger 是 JavaScript 中定义的一个专门用于断点调试的关键字&#xff0c;只要遇到它&#xff0c;JavaScript 的执行便会在此处中断&#xff0c;进入调试模式。 有了 debugger 这个关键字&#xff0c;就可以非常方便地对 JavaScript 代码进行调试&#xff0c;比如使用 JavaSc…