haiku实现三角乘法模块

三角乘法(TriangleMultiplication)是作为一种更对称、更便宜的三角注意力(TriangleAttention)替代模块。

import jax
import haiku
import jax.numpy as jnpdef _layer_norm(axis=-1, name='layer_norm'):return common_modules.LayerNorm(axis=axis,create_scale=True,create_offset=True,eps=1e-5,use_fast_variance=True,scale_init=hk.initializers.Constant(1.),offset_init=hk.initializers.Constant(0.),param_axis=axis,name=name)class TriangleMultiplication(hk.Module):"""Triangle multiplication layer ("outgoing" or "incoming").Jumper et al. (2021) Suppl. Alg. 11 "TriangleMultiplicationOutgoing"Jumper et al. (2021) Suppl. Alg. 12 "TriangleMultiplicationIncoming""""def __init__(self, config, global_config, name='triangle_multiplication'):super().__init__(name=name)self.config = configself.global_config = global_configdef __call__(self, left_act, left_mask, is_training=True):"""Builds TriangleMultiplication module.Arguments:left_act: Pair activations, shape [N_res, N_res, c_z]left_mask: Pair mask, shape [N_res, N_res].is_training: Whether the module is in training mode.Returns:Outputs, same shape/type as left_act."""del is_trainingif self.config.fuse_projection_weights:return self._fused_triangle_multiplication(left_act, left_mask)else:return self._triangle_multiplication(left_act, left_mask)# @hk.transparent 是 Haiku 中的函数修饰器,用于标记函数为透明模式。# 透明模式用于在神经网络模块内共享参数。@hk.transparentdef _triangle_multiplication(self, left_act, left_mask):"""Implementation of TriangleMultiplication used in AF2 and AF-M<2.3."""c = self.configgc = self.global_configmask = left_mask[..., None]act = common_modules.LayerNorm(axis=[-1], create_scale=True, create_offset=True,name='layer_norm_input')(left_act)input_act = actleft_projection = common_modules.Linear(c.num_intermediate_channel,name='left_projection')left_proj_act = mask * left_projection(act)right_projection = common_modules.Linear(c.num_intermediate_channel,name='right_projection')right_proj_act = mask * right_projection(act)left_gate_values = jax.nn.sigmoid(common_modules.Linear(c.num_intermediate_channel,bias_init=1.,initializer=utils.final_init(gc),name='left_gate')(act))right_gate_values = jax.nn.sigmoid(common_modules.Linear(c.num_intermediate_channel,bias_init=1.,initializer=utils.final_init(gc),name='right_gate')(act))left_proj_act *= left_gate_valuesright_proj_act *= right_gate_values# "Outgoing" edges equation: 'ikc,jkc->ijc'# "Incoming" edges equation: 'kjc,kic->ijc'# Note on the Suppl. Alg. 11 & 12 notation:# For the "outgoing" edges, a = left_proj_act and b = right_proj_act# For the "incoming" edges, it's swapped:#   b = left_proj_act and a = right_proj_actact = jnp.einsum(c.equation, left_proj_act, right_proj_act)act = common_modules.LayerNorm(axis=[-1],create_scale=True,create_offset=True,name='center_layer_norm')(act)output_channel = int(input_act.shape[-1])act = common_modules.Linear(output_channel,initializer=utils.final_init(gc),name='output_projection')(act)gate_values = jax.nn.sigmoid(common_modules.Linear(output_channel,bias_init=1.,initializer=utils.final_init(gc),name='gating_linear')(input_act))act *= gate_valuesreturn act@hk.transparentdef _fused_triangle_multiplication(self, left_act, left_mask):"""TriangleMultiplication with fused projection weights."""mask = left_mask[..., None]c = self.configgc = self.global_configleft_act = _layer_norm(axis=-1, name='left_norm_input')(left_act)# Both left and right projections are fused into projection.projection = common_modules.Linear(2*c.num_intermediate_channel, name='projection')proj_act = mask * projection(left_act)# Both left + right gate are fused into gate_values.gate_values = common_modules.Linear(2 * c.num_intermediate_channel,name='gate',bias_init=1.,initializer=utils.final_init(gc))(left_act)proj_act *= jax.nn.sigmoid(gate_values)left_proj_act = proj_act[:, :, :c.num_intermediate_channel]right_proj_act = proj_act[:, :, c.num_intermediate_channel:]act = jnp.einsum(c.equation, left_proj_act, right_proj_act)act = _layer_norm(axis=-1, name='center_norm')(act)output_channel = int(left_act.shape[-1])act = common_modules.Linear(output_channel,initializer=utils.final_init(gc),name='output_projection')(act)gate_values = common_modules.Linear(output_channel,bias_init=1.,initializer=utils.final_init(gc),name='gating_linear')(left_act)act *= jax.nn.sigmoid(gate_values)return act


 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/629423.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【CF闯关练习】—— 1400分(C. Make Good、B. Applejack and Storages)

&#x1f30f;博客主页&#xff1a;PH_modest的博客主页 &#x1f6a9;当前专栏&#xff1a;cf闯关练习 &#x1f48c;其他专栏&#xff1a; &#x1f534;每日一题 &#x1f7e1; C跬步积累 &#x1f7e2; C语言跬步积累 &#x1f308;座右铭&#xff1a;广积粮&#xff0c;缓…

isis小实验

要求: 1.合理规划level1-2 2.r1访问r5走r6且走上面 3.全网可达 个人理解:以重发布的视角:is-level level1即L1可以看做rip,L2可以看做OSPF,L1-2可以看作是既要rip又要OSPF,优点:isis只用在每个路由器上宣告一次 缺点:isis需要每个接口上输isis enable 1(序号)特点:L1-2会自动下…

vue流程图

效果图 组件 <template><div class="processBox" v-if="list.length"><div class="childs"><div class="child" v-for="(item,index) in list" :key="item.id +-child-+index"><div…

C++中的指针、引用和数组

参考文档&#xff1a;《21天学通C&#xff08;第8版&#xff09;》 C中对于指针、引用和数组使用时&#xff0c;充斥着 * 、& 、[]符号&#xff0c;对于像我这样的初学者面对这些符号难免会陷入混乱。 当然&#xff0c;C中对符号 * 、& 、[] 赋予了多重意义也是让人容易…

C //练习 6-5 编写函数undef,它将从由lookup和install维护的表中删除一个变量及其定义。

C程序设计语言 &#xff08;第二版&#xff09; 练习 6-5 练习 6-5 编写函数undef&#xff0c;它将从由lookup和install维护的表中删除一个变量及其定义。 注意&#xff1a;代码在win32控制台运行&#xff0c;在不同的IDE环境下&#xff0c;有部分可能需要变更。 IDE工具&am…

Java http 响应式请求和非响应式请求有什么区别

在Java中&#xff0c;HTTP的响应式请求和非响应式请求有以下区别&#xff1a; HTTP协议本身并不直接支持响应式请求&#xff0c;因为HTTP是基于请求-响应模型的。然而&#xff0c;可以通过使用其他技术和协议来实现响应式请求。 响应方式&#xff1a;响应式请求是指使用响应式编…

SpringBoot 调用错:getWriter() has already been called for this response

这个错误通常表明您尝试从Spring MVC返回一个已使用的HttpServletResponse对象。 原因&#xff1a;这可能是由于直接调用HttpServletResponse的getWriter()或getOutputStream()方法&#xff0c;或者由于在控制器方法中抛出异常而自动调用HttpServletResponse的write()方法。 修…

第10章_多线程扩展练习(Thread类中的方法,线程创建,线程通信)

文章目录 第10章_多线程扩展练习Thread类中的方法1、新年倒计时 线程创建2、奇偶数输出3、强行加塞4、奇偶数打印5、龟兔赛跑友谊赛6、龟兔赛跑冠军赛7、多人过山洞8、奇偶数连续打印9、字母连续打印 线程通信10、奇偶数交替打印11、银行账户-112、银行账户-2 第10章_多线程扩展…

协方差矩阵自适应调整的进化策略(CMA-ES)

关于CMA-ES&#xff0c;其中 CMA 为协方差矩阵自适应(Covariance Matrix Adaptation)&#xff0c;而进化策略&#xff08;Evolution strategies, ES&#xff09;是一种无梯度随机优化算法。CMA-ES 是一种随机或随机化方法&#xff0c;用于非线性、非凸函数的实参数&#xff08;…

SparkSQL——DataFrame

DataFrame Dataframe 是什么 DataFrame 是 SparkSQL中一个表示关系型数据库中 表的函数式抽象, 其作用是让 Spark处理大规模结构化数据的时候更加容易. 一般 DataFrame可以处理结构化的数据, 或者是半结构化的数据, 因为这两类数据中都可以获取到 Schema信息. 也就是说 DataFra…

数据结构之tuple类

前言 tuple 是元组类。tuple 就很有意思了&#xff0c;它和上一篇文章介绍的list 十分相似&#xff0c;都是线性表。最大的不同就是list 可以改变&#xff0c;而tuple 是不可变的。元组就像是列表的补充&#xff0c;我们甚至可以这么理解&#xff1a;元组就是只读的列表。 1.…

自动驾驶模拟器

目录 Carla 自动驾驶模拟器 Udacity自动驾驶模拟器 Carla 自动驾驶模拟器 pip install carla 需要下载地图 Udacity自动驾驶模拟器

一文带你揭秘淘宝终端技术

作者&#xff1a;周杰&#xff08;寻弦&#xff09; 在这个数字化迅速发展的时代&#xff0c;技术的每一次飞跃都不仅仅意味着一个产品的升级&#xff0c;更是对未来世界的一次大胆想象。从 PC 到 iPhone&#xff0c;从 Model 3 到 ChatGPT&#xff0c;都引领了全新的一个行业。…

智慧校园大数据平台功能模块

学校概况模块 智慧校园大数据平台的“学校概况”模块,主要给学校和院系领导使用,能够从宏观、全局把控学校教学、管理、科研、资产等各个方面的整体情况,可以预测学校的发展趋势并且给出决策建议。 比如在消费方面,校领导可以看到近一个月的消费金额和地点的情况,也可以…

AttributeError: module ‘openai‘ has no attribute ‘error‘解决方案

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的…

我的创作纪念日(730天)

机缘 不知不觉来到CSDN已经730天了&#xff0c;这两年来我收获丰富&#xff0c;从原本的只是从CSDN获取知识&#xff0c;到现在的传播知识&#xff0c;我感觉受益良多&#xff0c;一年多的沉淀&#xff0c;让我在这三个月中绽放&#xff0c;粉丝也从原本的两位数到现在的四千&…

【人工智能与深度学习】当输入层维度为1024,输出层维度为100时,为什么全连接层参数量为1024*100+100

当输入层维度为1024&#xff0c;输出层维度为100时&#xff0c;为什么全连接层参数量为1024*100100 在神经网络中&#xff0c;全连接层&#xff08;也称为稠密层或线性层&#xff09;的参数量计算通常包括权重&#xff08;weights&#xff09;和偏置&#xff08;biases&#x…

【ES6】解构语句中的冒号(:)

在解构赋值语法中&#xff0c;冒号&#xff08;:&#xff09;的作用是为提取的字段指定一个新的变量名。 让我们以示例 const { billCode: code, version } route.query 来说明&#xff1a; { billCode: code, version } 表示从 route.query 对象中提取 billCode 和 version…

每日一记:一个windows的bat脚本工具集

最近在工作上遇到要校验文件的问题&#xff0c;例如&#xff0c;下载了一个文件之后&#xff0c;通过查看文件的md5来校验文件是否完整&#xff0c;这个动作在linux上很简单&#xff0c;但在windows上也不难&#xff0c;可以通过 certutil 命令实现&#xff0c;该命令通常可用于…

SpringBoot项目如何优雅的实现操作日志记录

SpringBoot项目如何优雅的实现操作日志记录 前言 在实际开发当中&#xff0c;对于某些关键业务&#xff0c;我们通常需要记录该操作的内容&#xff0c;一个操作调一次记录方法&#xff0c;每次还得去收集参数等等&#xff0c;会造成大量代码重复。 我们希望代码中只有业务相关…