BN体系理解——类封装复现

 

 

 

 

 

from pathlib import Path
from typing import Optionalimport torch
import torch.nn as nn
from torch import Tensorclass BN(nn.Module):def __init__(self,num_features,momentum=0.1,eps=1e-8):##num_features是通道数"""初始化方法:param num_features:特征属性的数量,也就是通道数目C"""super(BN, self).__init__()##register_buffer:将属性当成parameter进行处理,唯一的区别就是不参与反向传播的梯度求解self.register_buffer('running_mean', torch.zeros(1, num_features, 1, 1))self.register_buffer('running_var', torch.zeros(1, num_features, 1, 1))self.running_mean: Optional[Tensor]self.running_var: Optional[Tensor]self.running_mean=torch.zeros([1,num_features,1,1])self.running_var=torch.zeros([1,num_features,1,1])self.gamma=nn.Parameter(torch.ones([1,num_features,1,1]))self.beta=nn.Parameter(torch.zeros(1,num_features,1,1))self.eps=epsself.momentum=momentumdef forward(self,x):"""前向过程output=(x-μ)/α*γ+β:param x: [N,C,H,W]:return: [N,C,H,W]"""if self.training:#训练阶段--》使用当前批次的数据_mean=torch.mean(x,dim=(0,2,3),keepdim=True)_var = torch.var(x, dim=(0, 2, 3), keepdim=True)#将训练过程中的均值和方差保存下来--方便推理的时候使用--》滑动平均self.running_mean=self.momentum*self.running_mean+(1.0-self.momentum)*_meanself.running_var=self.momentum*self.running_var+(1.0-self.momentum)*_varelse:#推理阶段-->使用的是训练过程中的累积数据_mean=self.running_mean_var=self.running_varz=(x-_mean)/torch.sqrt(_var+self.eps)*self.gamma+self.betareturn zif __name__ == '__main__':torch.manual_seed(28)path_dir=Path("./output/models")path_dir.mkdir(parents=True,exist_ok=True)device=torch.device("cuda" if torch.cuda.is_available() else "cpu")bn=BN(num_features=12)bn.to(device)#只针对子模块和参数进行转换#模拟训练过程bn.train()xs=[torch.randn(8,12,32,32).to(device) for _ in range(10)]for _x in xs:bn(_x)print(bn.running_mean.view(-1))print(bn.running_var.view(-1))#模拟推理过程bn.eval()_r=bn(xs[0])print(_r.shape)bn=bn.cpu()#保存都是以cpu保存,恢复再自己转回GPU上#模拟模型保存torch.save(bn,str(path_dir/'bn_model.pkl'))#state_dict:获取当前模块的所有参数(Parameter+register_buffer)torch.save(bn.state_dict(),str(path_dir/"bn_params.pkl"))#pt结构的保存traced_script_module=torch.jit.trace(bn.eval(),xs[0].cpu())traced_script_module.save("./output/bn_model.pt")#模拟模型恢复bn_model=torch.load(str(path_dir/"bn_model.pkl"),map_location='cpu')bn_params=torch.load(str(path_dir/"bn_params.pkl"),map_location='cpu')print(len(bn_params))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/102735.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

产品升级!全球尺度下原核基因组关键基因共进化无标题

微生物是群落型的生存方式,高通量测序时代到来后,掀起了针对微生物群落整体研究的高潮,比如基于功能基因/16S/ITS/扩增子、宏基因组等进行群落多样性分析。但是,我们基于分离培养等方法获得单菌落,针对单菌开展基因组、…

十一、WSGI与Web框架

目录 一、什么是WSGI1.1 WSGI接口的组成部分1.2 关于environ 二、简易的web框架实现2.1 文件结构2.2 在web/my_web.py定义动态响应内容2.3 在html/index.html中定义静态页面内容2.4 在web_server.py中实现web服务器框架2.5 测试 三、让简易的web框架动态请求支持多页面3.1 修改…

Python in Visual Studio Code 2023年10月发布

排版:Alan Wang 我们很高兴地宣布 Visual Studio Code 的 Python 和 Jupyter 扩展于 2023 年 10 月发布! 此版本包括以下公告: Python 调试器扩展更新弃用 Python 3.7 支持Pylint 扩展更换时的 Lint 选项Mypy 扩展报告的范围和守护程序模式G…

mysql面试题36:MySQL的binlog有几种录入格式?分别有什么区别

该文章专注于面试,面试只要回答关键点即可,不需要对框架有非常深入的回答,如果你想应付面试,是足够了,抓住关键点 面试官:MySQL的binlog有几种录入格式?分别有什么区别 MySQL的binlog(二进制日志)是用于记录数据库的更改操作的一种机制,它可以用于数据恢复、数据复…

Java集合(三)--- List接口

文章目录 一、List接口常用实现类的对比二、List接口中的常用方法代码 提示:以下是本篇文章正文内容,下面案例可供参考 一、List接口常用实现类的对比 二、List接口中的常用方法 代码 package com.tyust.edu;import org.junit.Test;import java.util.A…

2023年【危险化学品生产单位安全生产管理人员】及危险化学品生产单位安全生产管理人员模拟考试题

题库来源:安全生产模拟考试一点通公众号小程序 危险化学品生产单位安全生产管理人员考前必练!安全生产模拟考试一点通每个月更新危险化学品生产单位安全生产管理人员模拟考试题题目及答案!多做几遍,其实通过危险化学品生产单位安…

CRMEB多商户商城系统阿里云集群部署教程

注意: 1.所有服务创建时地域一定要选择一致,这里我用的是杭州K区 2.文件/图片上传一定要用类似oss的云文件服务, 本文不做演示 一、 创建容器镜像服务,容器镜像服务(aliyun.com) ,个人版本就可以 先创建一个命名空间 然后创建一个镜像仓库 查看并记录镜像公网地址…

自定义类型:结构体,枚举,联合 (1)

1 结构体的声明 1.1 结构的基础知识 结构是一些值的集合,这些值称为成员变量。结构的每个成员可以是不同类型的变量。 1.2 结构的声明 struct tag { member-list; }variable-list; 例如描述一个学生: struct是结构体关键字,不能省略。 …

运维大数据平台的建设与实践探索

随着企业数字化转型的推进,运维管理面临着前所未有的挑战和机遇。为应对日益复杂且严峻的挑战,数字免疫系统和智能运维等概念应运而生。数字免疫系统和智能运维作为新兴技术,正引领着运维管理的新趋势。数字免疫系统和智能运维都借助大数据运…

同创永益成为英迈首家签约生态伙伴

日前,同创永益已和英迈签署生态运营战略协议,并正式成为英迈全新打造的GTM生态圈的首位签约合作伙伴。双方将携手对“同创数字韧性平台”产品进行一站式联合解决方案的持续整合,并将大力推动该联合解决方案在市场上的进一步拓展。 云原生时代…

vite+vue3+ts中使用require.context | 报错require is not defined | 获取文件夹中的文件名

vitevue3ts中使用require.context|报错require is not defined|获取文件夹中的文件名 目录 vitevue3ts中使用require.context|报错require is not defined|获取文件夹中的文件名一、问题背景二、报错原因三、解决方法 一、问题背景 如题在vitevue3ts中使用required.context时报…

科技资讯|9月新能源汽车零售74.3万辆,充电桩迎来发展高峰

据中国乘联会发布的初步数据,中国 9 月份乘用车市场零售 202.8 万辆,同比增长 6%,环比增 6%。今年以来,我国乘用车市场累计零售 1,524 万辆,同比增长 2%。 乘联会预计,9 月份新能源车市场零售 74.3 万辆&a…

Java架构师系统架构设计资源估算

目录 1 认识资源估算1.1 预估未来发展1.2 资源估算的意义 2 资源估算方法2.1 确定系统目标2.2 并发用户数2.3 指标数据 3 资源估算的经验法则4 资源估算的常见参考数据4.1 带宽估算4.2 nginx估算4.3 tomcat估算4.4 操作系统估算4.5 redis估算4.6 mysql估算 5 并发人数估算5.1 请…

【Unity3D编辑器开发】Unity3D中制作一个可以随时查看键盘对应KeyCode值面板,方便开发

推荐阅读 CSDN主页GitHub开源地址Unity3D插件分享简书地址我的个人博客 大家好,我是佛系工程师☆恬静的小魔龙☆,不定时更新Unity开发技巧,觉得有用记得一键三连哦。 一、前言 在开发中,会遇到要使用监控键盘输入的KeyCode值来执…

Python- socket编程

Python中的socket模块为网络通信提供了基础API,使我们能够在应用程序中实现低级的网络交互。使用socket编程,可以创建TCP、UDP和RAW sockets来进行数据通信。 以下是Python socket 编程的简要概述: 1. 核心概念 Socket: 通信的端点&#x…

分布式事务入门

文章目录 分布式事务问题本地事务分布式事务演示分布式事务问题 理论基础CAP定理一致性可用性分区容错矛盾 BASE理论 SeataSeata的架构部署TC服务微服务集成seata 动手实践XA模式两阶段提交Seata的XA模型实现XA模式 AT模式Seata的AT模型流程梳理脏写问题实现AT模式 TCC模式流程…

github小记(一):清除github在add或者commit之后缓存区

github清除在add或者commit之后缓存区 前言1. 第一步之后想要撤销2. 第二步之后想要撤销a. 改变一下rrr.txt的内容b. 想提交本地文件的test文件夹c. 我后悔了突然不想提交了 前言 github自用 一般github上代码提交顺序: 第一步: git add . or git ad…

0基础学习VR全景平台篇 第107篇:全景图调色和细节处理(上,地拍)

上课!全体起立~ 大家好,欢迎观看蛙色官方系列全景摄影课程! 今天教给大家的课程是地拍全景图调色和细节处理,下面我们就开始吧! 1.把照片快速导入LR软件 选择【图库】模块 打开软件后,点击【导入】按…

【Ceph Block Device】块设备挂载使用

文章目录 前言创建pool创建user创建image列出image检索image信息调整image大小增加image大小减少image大小 删除image从pool中删除image从pool中“延迟删除”image从pool中移除“延迟删除的image” 恢复image恢复指定pool中延迟删除的image恢复并重命名image 映射块设备格式化i…

总结四:数据库(MySQL)面经

文章目录 一、SQL1、介绍一下数据库分页2、介绍一下SQL中的聚合函数3、表跟表是怎么关联的?4、说一说你对外连接的了解?5、说一说数据库的左连接和右连接?6、SQL中怎么将行转成列?7、谈谈你对SQL注入的理解?8、将一张表的部分数据…