pytorch 自定义函数

pytorch 自定义函数

介绍:https://zhuanlan.zhihu.com/p/344802526
主要构建 static method forward 和 backward

比如 layernorm: 参考:https://github.com/zhangyi-3/KBNet/blob/main/basicsr/models/archs/kb_utils.py

导数的推导:https://blog.csdn.net/qinduohao333/article/details/132309091

import torch
import torch.nn as nn
import torch.nn.functional as Fclass LayerNormFunction(torch.autograd.Function):@staticmethoddef forward(ctx, x, weight, bias, eps):ctx.eps = epsN, C, H, W = x.size()mu = x.mean(1, keepdim=True)var = (x - mu).pow(2).mean(1, keepdim=True)# print('mu, var', mu.mean(), var.mean())# d.append([mu.mean(), var.mean()])y = (x - mu) / (var + eps).sqrt()weight, bias, y = weight.contiguous(), bias.contiguous(), y.contiguous()  # avoid cuda errorctx.save_for_backward(y, var, weight)y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)return y@staticmethoddef backward(ctx, grad_output):eps = ctx.epsN, C, H, W = grad_output.size()# y, var, weight = ctx.saved_variablesy, var, weight = ctx.saved_tensorsg = grad_output * weight.view(1, C, 1, 1)mean_g = g.mean(dim=1, keepdim=True)mean_gy = (g * y).mean(dim=1, keepdim=True)gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(dim=0), Noneclass LayerNorm2d(nn.Module):def __init__(self, channels, eps=1e-6, requires_grad=True):super(LayerNorm2d, self).__init__()self.register_parameter('weight', nn.Parameter(torch.ones(channels), requires_grad=requires_grad))self.register_parameter('bias', nn.Parameter(torch.zeros(channels), requires_grad=requires_grad))self.eps = epsdef forward(self, x):return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/721064.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux常用命令(超详细)

一、基本命令 1.1 关机和重启 关机 shutdown -h now 立刻关机 shutdown -h 5 5分钟后关机 poweroff 立刻关机 重启 shutdown -r now 立刻重启 shutdown -r 5 5分钟后重启 reboot 立刻重启 1.2 帮助命令 –help命令 shutdown --help: ifconfig --help:查看…

SpringBoot 接口报错该如何解决?

在Spring Boot应用中,接口报错可能由多种原因引起,包括但不限于业务逻辑错误、异常处理不当、依赖库问题、配置错误等。解决接口报错的过程需要分析具体的错误信息、排查可能的原因,并采取相应的调试和修复措施。 以下是解决Spring Boot接口…

AWS ECR(AWS云里面的docker镜像私库)

问题 上一篇文章,在AWS云上面部署了k8s集群,这次接下来,需要在一个docker镜像私库。 步骤 创建docker镜像私库 打开AWS ECR主页,创建一个docker镜像私库,如下图: 设置私有镜像库名称,直接创…

AI短视频矩阵运营软件|抖音视频矩阵控制工具

【罐头鱼AI传单功能介绍】 罐头鱼AI传单是一款专为短视频矩阵运营而设计的智能软件,旨在帮助用户高效管理和运营多个抖音账号,并提供一系列强大的功能来优化视频内容创作和发布流程。QQ:290615413以下是软件框架,详细介绍其功能和特点&#…

第一弹:Flutter安装和配置

目标: 1)配置Flutter开发环境 2)创建第一个Flutter Demo项目 Flutter中文开发者网站: https://flutter.cn/ 一、配置Flutter开发环境 Flutter开发环境已经提供集成IDE开发环境,因此需要配置开发环境的时候&#xf…

【项目管理】CMMI-质量保证过程

质量保证过程(PQA):通过质量保证活动,确保过程与产品满足过程、规程及相应的要求,确保问题得到关注与解决,使工作人员和管理者能够客观地了解过程与相关的工作产品。QA工程师应实施质量保证策划活动,客观地…

【云呐】固定资产管理系统选型有哪些

固定资产管理是企业经营管理过程中非常重要的任务。它涉及到公司的核心资产,包括土地、建筑、设备、车辆等。许多企业选择引入固定资产管理系统,以确保这些资产的高效管理和使用。但是,面对市场上众多的固定资产管理系统,如何选择…

【自然语言处理六-最重要的模型-transformer-上】

自然语言处理六-最重要的模型-transformer-上 什么是transformer模型transformer 模型在自然语言处理领域的应用transformer 架构encoderinput处理部分(词嵌入和postional encoding)attention部分addNorm Feedforward & add && NormFeedforw…

固定资产管理系统包括哪些

固定资产管理是企业经营过程中一项非常重要的任务。它涉及到公司的核心资产,包括土地、建筑物、设备、车辆等。为了有效地管理这些资产,许多企业选择使用固定资产管理系统。那么,固定资产管理系统的内容是什么呢?本文将为您进行全…

代码随想录day13(2)栈与队列:用队列实现栈(leetcode225)

题目要求:使用队列实现栈的push、pop、top(取栈顶元素)、empty操作。 思路:首先的思路就是使用两个队列,入栈的操作不变,如果想弹出一个元素,先将队列1元素出队,只保留一个元素&…

力扣hot100:1.两数之和(哈希表)

输入中可能存在重复值 。 分析&#xff1a; 本题需要返回的是数组下标&#xff0c;因此如果需要使用排序然后双指针的话&#xff0c;需要用到哈希表&#xff0c;但是由于输入中可能存在重复值&#xff0c;因此哈希表的value值必须是vector<int>。 使用双指针求目标值targ…

第十七天-反爬与反反爬-验证码识别

目录 反爬虫介绍 基于身份识别反爬和解决思路 Headers反爬-使用User-agent Headers反爬-使用coookie字段 Headers反爬-使用Referer字段 基于参数反爬 验证码反爬 1.验证码介绍 2.验证码分类&#xff1a; 3.验证码作用 4.处理方案 5.图片识别引擎:ocr 6.使用打码平…

基于springboot+vue的周边游平台个人管理系统

博主主页&#xff1a;猫头鹰源码 博主简介&#xff1a;Java领域优质创作者、CSDN博客专家、阿里云专家博主、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战&#xff0c;欢迎高校老师\讲师\同行交流合作 ​主要内容&#xff1a;毕业设计(Javaweb项目|小程序|Pyt…

【leetcode】三数之和 双指针

/*** param {number[]} nums* return {number[][]}*/ var threeSum function(nums) {nums.sort((a,b)>a-b);let result[];for(let i0;i<nums.length-2;i){if(nums[i]>0) return result;//因为求三数之和等于0&#xff0c;如果第一个数已经大于0&#xff0c;后面肯定无…

完蛋!这DLC不行啊?!

作者&#xff1a;胖头鱼的鱼缸&#xff08;尹海文&#xff09; Oracle ACE Associate: Database&#xff08;Oracle与MySQL&#xff09; 国内某科技公司 DBA总监 10年数据库行业经验&#xff0c;现主要从事数据库服务工作 拥有OCM 11g/12c/19c、MySQL 8.0 OCP、Exadata、CDP等认…

pycharm社区版+miniconda 环境配置学习

使用电脑为win10 64bit 一、下载&#xff1a; 1.pycharm社区版地址&#xff1a; 其他版本 - PyCharm 2.miniconda下载地址&#xff1a; Index of /anaconda/ | 清华大学开源软件镜像站 | Tsinghua Open Source Mirror 3.python下载地址&#xff1a; https://www.python.or…

云原生数据库 GaiaDB 支持新的管理工具啦

GaiaDB 是百度智能云自研的新一代企业级关系型数据库&#xff0c;最大容量可扩展 500TB 以上&#xff0c;吞吐达到 150 万以上 QPS。 作为一款 100% 兼容 MySQL 的云原生数据库产品&#xff0c;用户可以通过多种客户端工具连接 GaiaDB 实例&#xff0c;例如 MySQL Workbench、N…

掌握代理IP技术:从基础设置到高级应用,步步进阶教程

代理IP技术是一种网络通信技术&#xff0c;通过代理服务器转发客户端的网络请求&#xff0c;可以实现IP地址隐藏、突破地域限制、提高访问速度等多种功能。以下是一个从基础设置到高级应用的步步进阶教程&#xff1a; 1. 基础设置&#xff1a; - HTTP代理设置&#xff1a;在大多…

【机器学习】生成对抗网络GAN

概述 生成对抗网络&#xff08;Generative Adversarial Network&#xff0c;GAN&#xff09;是一种深度学习模型架构&#xff0c;由生成器&#xff08;Generator&#xff09;和判别器&#xff08;Discriminator&#xff09;两部分组成&#xff0c;旨在通过对抗训练的方式生成逼…

Python之Web开发初学者教程----卸载ubuntu系统

Python之Web开发初学者教程----卸载ubuntu系统 Windows 10自带了Subsytem for Linux (WSL)功能&#xff0c;可以让用户在Windows命令行环境下运行Linux命令。用户可以在Windows应用商店中下载和安装Ubuntu子系统&#xff0c;有时在使用过程中需要完全删除Ubuntu子系统以释放硬…