Mixtral Moe代码解读

一直对稀疏专家网络好奇,有些专家没被选中,那么梯度是否为0,这一轮被选中有梯度,下一轮没被选中无梯度,模型可以训练收敛吗?

  • 由于每个token都会选择topk个专家,所以在每一轮epoch中,所有专家都参与了前向传播,所以梯度都能得到更新
  • 即使真有专家一直没被选中,那么其梯度保持不变,没有参与更新而已
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)# 获取到每个token的mlp层输入特征 
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)# 得到每个专家的打分,维度是batch * sequence, num_experts,取topk个专家
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)# 取到topk个专家的打分,需要计算在归一化一下,用于对后面的expert计算出来的结果进行加权
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# routing_weights、selected_experts 维度是一致的,取了topk   (bs * sl, topk)
routing_weights = routing_weights.to(hidden_states.dtype)final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device)# 如果不做后面的维度切换,那expert_mask的维度是 (bs*sl, topk, n_experts),但是后面要遍历n_experts来计算,所以颠倒一下,得到(n_experts, topk, bs * sl); 
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)for expert_idx in range(self.num_experts):expert_layer = self.experts[expert_idx]idx, top_x = torch.where(expert_mask[expert_idx])"""这样取到expert_mask[expert_idx],从上面的注释可以知道维度是[topk, bs * sl];torch.where的结果,第一个结果代表选到了哪一行,第二个代表选择了哪一列对应到实际意义,top_x表示取的列,也就是取哪些token而行表示,取到的这些token,根据路由gate计算,当前expert是排行第几;所以这里变量名字可能有点混淆,"""# 没有token需要当前的expert计算if top_x.shape[0] == 0:continue# tensor index使用list比tensor快top_x_list = top_x.tolist()idx_list = idx.tolist()# 前面hidden states已经转成了 [bs * sl, hs],根据top_x 可以找到需要计算的token,这些token依旧是有序的current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)# 找到这个expert对应的权重 乘进去# 上面计算的权重是routing_weights,维度是bs * sl, topk# 根据top_x_list 对应的token,idx_list表示topk中第几个# 可以直接取到相应的权重current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]# 合到最终的特征里边去final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)

参考

  • 理解Mixtral Moe模型原理与代码实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/612145.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python工具-udp-tcp-client-server-demo

python工具-udp-tcp-client-server-demo server tcp-server: python xxx.py -type tcp -ip “127.0.0.1” -port 1234udp-server: python xxx.py -type udp -ip “127.0.0.1” -port 1234 client python xxx.py -type udp -ip “127.0.0.1” -port 1111python xxx.py -type tc…

依赖Kafka的Go单元测试例解

Kafka[1]是Apache基金会开源的一个分布式事件流处理平台,是Java阵营(最初为Scala)中的一款杀手级应用,其提供的高可靠性、高吞吐量和低延迟的数据传输能力,让其到目前为止依旧是现代企业级应用系统以及云原生应用系统中使用的重要中间件。 在…

pytorch无法把共享内存写入文件

环境: 在容器中跑pytorch模型的训练 问题表现: ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm). Traceback (most recent call last): File "/root/anaconda3/lib/python3.8/m…

flutter getTemporaryDirectory()的使用

下面是上传音视频流的截图,先保存在缓存,然后请求接口,成功或者失败会删除文件。 可以在Device File Explorer查看, Android: 会返回 /data/data//cache 这个目录,是应用私有的缓存目录。 iOS: 会返回 Library/Caches 下的一个…

Java学习笔记-day06-响应式编程Reactor API大全(上)

Reactor 是一个基于响应式编程的库&#xff0c;主要用于构建异步和事件驱动的应用程序。Reactor 提供了丰富的 API&#xff0c;包括创建、转换、过滤、组合等操作符&#xff0c;用于处理异步数据流。以下是一些 Reactor 的主要 API 示例&#xff1a; pom依赖 <dependencyMan…

191. 位1的个数

编写一个函数&#xff0c;输入是一个无符号整数&#xff08;以二进制串的形式&#xff09;&#xff0c;返回其二进制表达式中数字位数为 1 的个数&#xff08;也被称为汉明重量&#xff09;。 提示&#xff1a; 请注意&#xff0c;在某些语言&#xff08;如 Java&#xff09;中…

LeetCode 2645. 构造有效字符串的最少插入数

一、题目 1、题目描述 LeetCode 给你一个字符串 word &#xff0c;你可以向其中任何位置插入 "a"、"b" 或 "c" 任意次&#xff0c;返回使 word 有效 需要插入的最少字母数。如果2645. 构造有效字符串的最少插入数 2、接口描述 ​ class Solut…

SpringCloud微服务 【实用篇】| RabbitMQ快速入门、SpringAMQP

目录 一&#xff1a;初始RabbitMQ 1. 同步和异步通讯 1.1 同步调用 1.2 异步调用 2. MQ常见框架 二&#xff1a;RabbitMQ快速入门 1. RabbitMQ概述和安装 2. 常见消息队列模型 3. 快速入门案例 三&#xff1a;SpringAMQP 1. Basic Queue 简单队列模型 2. Work Queu…

Hive事务表转换为非事务表

环境&#xff1a;hive3.1.0 由于建表时默认会建为非事务表 CREATE TABLE bucket_text_table2(column1 string,column2 string,column3 int) CLUSTERED BY (column3) into 5 BUCKETS STORED AS TEXTFILE; 执行完成后&#xff0c;查看默认建表语句&#xff1a; ---------------…

PHP 微信小程序获取 手机号码

PHP代码 $param $_POST; $app_id ""; $app_secret "";$url_get https://api.weixin.qq.com/cgi-bin/token?grant_typeclient_credential&appid . $app_id . &secret . $app_secret;$tmptoken json_decode(curl($url_get), true);$token $tm…

ORACLE RAC DG文件路径错乱解决办法

最近接手了一个客户的RAC-RAC dg环境的维护,登录上去之后发现dg延迟了8天,由于主库的空间非常紧张,归档日志早就删除了,所以准备使用rman基于scn点的备份恢复的方案恢复dg同步 在备份完成之后,使用新的控制文件进行数据恢复的时候报错datafile 43 not found: 检查了一下发现当…

SpringBoot中使用单例模式+ScheduledExecutorService实现异步多线程任务(若依源码学习)

场景 若依前后端分离版手把手教你本地搭建环境并运行项目&#xff1a; 若依前后端分离版手把手教你本地搭建环境并运行项目_本地运行若依前后端分离-CSDN博客 设计模式-单例模式-饿汉式单例模式、懒汉式单例模式、静态内部类在Java中的使用示例&#xff1a; 设计模式-单例模…

数据库 MySQL 索引的原理

在数据库中&#xff0c;索引是一种重要的数据结构&#xff0c;它用于加快数据的检索速度和提高查询性能。在 MySQL 中&#xff0c;索引的实现基于 B树结构。 索引的基本思想是通过维护一个有序的数据结构&#xff0c;来快速定位和访问表中的数据。B树是一种自平衡的二叉搜索树…

Python requirements.txt 详解

文章目录 1 概述1.1 作用1.2 注意 2 操作2.1 生成 requirements.txt2.2 安装 requirements.txt 3 示例3.1 新建 Django 项目3.2 找到 Scripts 目录&#xff0c;执行生成 requirements.txt 命令 1 概述 1.1 作用 作用&#xff1a;记录 当前项目下 所有 依赖包及其版本号&#…

Polars使用指南(一)

pandas是Python数据处理中非常经典的一个科学计算库&#xff0c;表形式的数据结构、丰富的API和灵活的编程语法使得pandas成为最常用的的数据分析工具。但是pandas也有一个最致命的缺陷&#xff0c;就是效率问题&#xff0c;尤其是不支持并行计算。pandas2在性能方面有了极大的…

不知道题目是啥

本题是学校的集训里的题&#xff0c;所有不知道题目名字是啥&#xff0c;直接看题目就好 解题思路&#xff1a;因为字符串只含有小写字母&#xff0c;所以可以创建两个数组分别来存s和t的每个字母出现次数&#xff0c;然后遍历数组&#xff0c;如果s字符串中的某个字母比t的小&…

输电线路分布式故障诊断装置的四大特点介绍-深圳鼎信

输电线路分布式故障诊断装置是一种利用行波测距、无线通信等技术手段实现电网故障定位的设备。这对于电网的故障处理和恢复具有重要意义&#xff0c;可以帮助运维人员提高故障处理的效率&#xff0c;缩短故障处理时间&#xff0c;减少停电时间&#xff0c;提高用户的供电可靠性…

premiere简约大气3D动画logo片头Pr模板Mogrt免费下载

Premiere简约大气3D动画logo片头pr模板mogrt下载&#xff0c;无需插件&#xff0c;高清分辨率&#xff0c;易于自定义&#xff0c;包括教程&#xff0c;不包括音频和图像。免费下载&#xff1a;https://prmuban.com/37065.html

Linux学习(1):目录结构、编辑器和用户管理

Linux学习&#xff08;1&#xff09;&#xff1a;目录结构、编辑器和用户管理 1 Linux目录结构2 vi和vim编辑器2.1 快捷键练习 3 用户管理3.1 添加用户3.2 删除用户即主目录3.3 切换用户 4 用户组 1 Linux目录结构 在linux世界里&#xff0c;一切皆为文件。 linux目录结构&a…

test fuzz-05-模糊测试 kelinci AFL-based fuzzing for Java

拓展阅读 开源 Auto generate mock data for java test.(便于 Java 测试自动生成对象信息) 开源 Junit performance rely on junit5 and jdk8.(java 性能测试框架。性能测试。压测。测试报告生成。) test fuzz-01-模糊测试&#xff08;Fuzz Testing&#xff09; test fuzz-…