为pytorch前向和反向的Tensor生成描述性统计

为pytorch前向和反向的Tensor生成描述性统计

  • 代码

在调试Megatron-DeepSpeed的精度时,我们希望对比每一层前向和反向传播的输入输出误差。然而,由于数据量过大,直接保存所有数据不太现实。因此,我们生成了输入输出tensor的描述性统计信息,并等间隔抽样N个数据点,以比较这些点的相对误差,从而查找精度异常的位置。为了准确定位,我们通过类名和对象ID生成唯一的对象名称(形式为[类名-创建的第几个])以及前向和反向传播的次数。通过保存上述信息,我们可以详细记录并回溯当时的实际输入输出数据。

代码

cat > linear_test.py <<-'EOF'
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from datetime import datetime# 设置设备
device = "cpu"if torch.cuda.is_available():device = "cuda:4"def is_tensor(val):# 判断是否为tensor或Parameterreturn isinstance(val, (torch.Tensor, nn.Parameter))def describe_tensor(tensor):# 返回tensor的描述,包括形状和部分数据统计信息shape = list(tensor.shape)tensor_data = tensor.cpu().float().detach().numpy().ravel()num_points = min(16, len(tensor_data))indices = np.linspace(0, len(tensor_data) - 1, num_points, dtype=int)stats = [np.max(tensor_data), np.min(tensor_data), np.mean(tensor_data), np.std(tensor_data)]sample_data = tensor_data[indices]stats_str = ",".join(f"{x:.5f}" for x in stats)sample_str = ",".join(f"{x:.5f}" for x in sample_data)return f"{shape}-{stats_str},{sample_str}"def generate_random_data(shape):# 生成符合指定形状的随机数据max_val, min_val, mean, std = 0.04025, -0.04651, 0.0, 0.00134data = np.random.normal(mean, std, shape)data = (data - data.min()) / (data.max() - data.min()) * (max_val - min_val) + min_valreturn dataindex_counter = 0def log_tensor_data(name, tensor):# 打印tensor的日志数据global index_counterindex_counter += 1timestamp = datetime.now().strftime("%H%M%S%f")if is_tensor(tensor):print(f"{timestamp},{index_counter},{name},0,{describe_tensor(tensor)}")elif isinstance(tensor, (tuple, list)):for idx, t in enumerate(tensor):if is_tensor(t):print(f"{timestamp},{index_counter},{name},{idx},{describe_tensor(t)}")def log_gradient(model):# 打印模型参数梯度信息for name, param in model.named_parameters():if param.grad is not None:log_tensor_data(f"grad-{name}", param.grad)# 对象和类名缓存
object_cache = {}
class_name_count = {}def get_unique_name(class_name, obj_id):# 生成唯一的对象名称if class_name not in class_name_count:class_name_count[class_name] = 0uid = f"{class_name}_{obj_id}"if uid not in object_cache:class_name_count[class_name] += 1object_cache[uid] = {"idx": class_name_count[class_name]}return f'{class_name}-{object_cache[uid]["idx"]}'def initialize_module_attributes(module):# 初始化模块属性if not hasattr(module, 'uuid'):module.uuid = get_unique_name(module.__class__.__name__, id(module))if not hasattr(module, 'backward_step'):module.backward_step = 0if not hasattr(module, 'forward_step'):module.forward_step = 0def forward_decorator():# 包装forward函数的修饰器def decorator(func):def wrapped(*args, **kwargs):module = args[0]initialize_module_attributes(module)module.forward_step += 1log_tensor_data(f"forward-{module.uuid}-{module.forward_step}-input", args)output = func(*args, **kwargs)log_tensor_data(f"forward-{module.uuid}-{module.forward_step}-output", output)return outputreturn wrappedreturn decoratordef pre_backward_hook(module, grad_input):# 反向传播前的钩子函数initialize_module_attributes(module)module.backward_step += 1log_tensor_data(f"backward-{module.uuid}-{module.backward_step}-input", grad_input)def post_backward_hook(module, grad_input, grad_output):# 反向传播后的钩子函数initialize_module_attributes(module)log_tensor_data(f"backward-{module.uuid}-{module.backward_step}-output", grad_output)def register_backward_hooks(module):# 注册反向传播钩子module.register_full_backward_pre_hook(pre_backward_hook)module.register_full_backward_hook(post_backward_hook)class CustomLinear(nn.Module):def __init__(self, shape):super(CustomLinear, self).__init__()weight_data = torch.from_numpy(generate_random_data(shape)).half().to(device)self.weight = nn.Parameter(weight_data)self.register_parameter('bias', None)register_backward_hooks(self)@forward_decorator()def forward(self, input_):return F.linear(input_, self.weight, self.bias)class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.layer1 = CustomLinear((5504, 4096))self.layer2 = CustomLinear((4096, 5504))@forward_decorator()def forward(self, input_):out = self.layer1(input_)out = self.layer2(out)return out
# 设置随机种子
np.random.seed(1)
torch.manual_seed(2)# 创建和训练模型
model = MyModel().half().to(device)
model.train()input_data = torch.from_numpy(generate_random_data((1024, 12, 4096))).half().to(device)
target_data = torch.from_numpy(generate_random_data((1024, 12, 4096))).half().to(device)for _ in range(2):outputs = model(input_data)outputs.backward(target_data)  # 使用全一的梯度来反向传播log_gradient(model)
EOF
python3 linear_test.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/12970.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

有哪些好用的3dMax大神插件?

有哪些好用的3dMax大神插件&#xff1f; Mesh Insert 3DMAX网格插入插件Mesh Insert&#xff0c;在选择的面上安门窗、打螺丝、挖洞、插入眼耳口鼻及其它网格模型等可以分分钟搞定&#xff01;它通过将面选择替换为库中的资源来加快建模过程。非常适合硬网格和有机建模&#xf…

iOS ------ 多线程基础

一&#xff0c;进程和线程 1&#xff0c;进程 定义&#xff1a; 进程是指在系统中正在运行的一个应用程序每个进程之间是独立的&#xff0c;每个进程均运行在其专有的且受保护的内存进程是系统进行资源分配和调度的一个独立单位 补充&#xff1a;iOS系统是相对封闭的系统&a…

服务网格 SolarMesh v1.13 重磅发布

SolarMesh是行云创新推出的流量治理平台&#xff0c;它基于Istio&#xff0c;为部署在K8s集群上的应用提供全面的流量治理能力。 在之前的版本中&#xff0c;SolarMesh提供的能力有&#xff1a;流量视图&#xff0c;流量控制策略批量配置&#xff0c;API级别的流量数据采集和展…

【上海大学计算机组成原理实验报告】五、机器语言程序实验

一、实验目的 理解计算机执行程序的实际过程。 学习编制机器语言简单程序的方法。 二、实验原理 根据实验指导书的相关内容&#xff0c;指令的形式化表示是指采用一种规范化的符号系统&#xff0c;以更清晰、精确地描述和表示指令的逻辑功能和操作步骤。 汇编是一种编程语言…

MM模块学习二 (供应商,物料后台相关配置)

公司代码配置 新建条目&#xff08;只是建了一个名字出来&#xff0c;后面很多表都是没有得&#xff09; 接下来定义公司代码&#xff1a; 公司代码复制完成&#xff08;后续修改交给财务顾问去做&#xff09; 复制工厂&#xff1a; 复制工厂完成&#xff1a; 修改复制过去的工…

Linux服务器lvm磁盘管理fdisk和df磁盘大小不同修改

服务器端由于硬盘是通过VCenter原来100G磁盘复制的虚拟机,复制完成后,原来100G的磁盘通过选择 磁盘重新复制出150G的磁盘,开机后发现还是原来的100G的磁盘,通过fdisk -l 查看有个sdb是150G, 但是已经划转的lvm盘只有100G, 通过df查看也是原来的100G: pvs查看pv里也是10…

用c++实现快速排序、最大子段和问题

6.2.2 快速排序 【问题】快速排序(quick sort)的分治策略如下&#xff08;图6-5)。 (1)划分&#xff1a;&#xff08;选定一个记录作为轴值&#xff0c;以轴值为基准将整个序列划分为两个子序列&#xff0c;轴值的位置在划分的过程中确定&#xff0c;并且左侧子序列的所有记录…

26 分钟惊讶世界,GPT-4o 引领未来人机交互

前言 原文链接&#xff1a;OpenAI最新模型——GPT-4o&#xff0c;实时语音视频交互&#xff0c;未来人机交互近在眼前 - Kaiho小站 北京时间 5 月 14 日凌晨&#xff0c;OpenAI 发布新一代模型——GPT-4o&#xff0c;仅在 ChatGPT 面世 17 个月后&#xff0c;OpenAI 再次通过…

【EasyX】快速入门——静态图形篇

1.基本说明 EasyX 是针对 C 的图形库&#xff0c;可以帮助 C/C 初学者快速上手图形和游戏编程。 比如&#xff0c;可以基于 EasyX 图形库很快的用几何图形画一个房子&#xff0c;或者一辆移动的小车&#xff0c;可以编写俄罗斯方块、贪吃蛇、黑白棋等小游戏&#xff0c;可以练…

使用VMware或VirtualBox安装eNSP Pro并使用CRT连接设备

文章目录 使用Oracle Virtual Box安装eNSP Pro创建虚拟机配置网卡配置带外管理网络 使用VMware Workstation安装eNSP Pro转换文件格式及虚拟磁盘模式配置网卡创建虚拟机配置使用CRT连接管理设备 前一段时间是开放了eNSP Pro的账号权限&#xff0c;但是在写博客时&#xff0c;权…

27.哀家要长脑子了!

目录 1.316. 去除重复字母 - 力扣&#xff08;LeetCode&#xff09; 2. 1209. 删除字符串中的所有相邻重复项 II - 力扣&#xff08;LeetCode 哎哟 烦死了 刚刚不小心退出又没保存 又要写一遍 烦死了 最近刷题不得劲啊 感觉这脑子没长一点 1.316. 去除重复字母 - 力扣&am…

(实测验证)【移远EC800M-CN 】GNSS功能打开和关闭关闭步骤验证

引言 本文章使用自研“超小体积TTL转4GGPS集成模块”进行实测验证&#xff1b; 一、打开GNSS功能 步骤一、通过 ATQGPSCFG 配置 GNSS 参数 &#xff08;1&#xff09;该命令用于查询和配置 GNSS 不同的设置&#xff0c;包括 NMEA 语句输出端口、NMEA 语句的输出类型等。 1.1…

NSSCTF | [SWPUCTF 2021 新生赛]easyupload2.0

先传一个普通的一句话木马试一试 GIF89a <?php eval($_POST[shell]);?> 可以看到回显&#xff0c;不允许上传php文件。 使用Burpsuite抓包只修改ContentType后发现也不能绕过&#xff0c;说明服务器使用了黑名单后缀限制&#xff0c;那么我们可以使用其他的后缀代替ph…

电路板维修【四】

【开关电源输出电压偏低不稳&#xff0c;用示波器立马锁定故障范围】&#xff1a;https://www.bilibili.com/video/BV1pf421D73K?vd_source3cc3c07b09206097d0d8b0aefdf07958 可以用示波器查看MOS的输出波形来查看其是否损坏&#xff1a; 电源芯片的供电电压来回跳变&#xf…

基于卷积神经网络CNN,使用二维卷积Conv2D实现MNIST数字识别的四种方法

前言 系列专栏&#xff1a;机器学习&#xff1a;高级应用与实践【项目实战100】【2024】✨︎ 在本专栏中不仅包含一些适合初学者的最新机器学习项目&#xff0c;每个项目都处理一组不同的问题&#xff0c;包括监督和无监督学习、分类、回归和聚类&#xff0c;而且涉及创建深度学…

ROS 2边学边练(48)-- 将URDF与robot_state_publisher一起使用

前言 本篇将完成一个行走的机器人&#xff0c;并以tf2消息的方式实时发布机器人状态&#xff0c;以便我们在Rviz中同步查看。 首先&#xff0c;我们创建描述机器人装配的URDF模型。接下来&#xff0c;我们编写一个节点&#xff0c;用于模拟运动并发布JointState和位姿变换。然后…

醉了,面个功能测试,还问我Python装饰器

Python 装饰器是个强大的工具&#xff0c;可帮你生成整洁、可重用和可维护的代码。某种意义上说&#xff0c;会不会用装饰器是区分新手和老鸟的重要标志。如果你不熟悉装饰器&#xff0c;你可以将它们视为将函数作为输入并在不改变其主要用途的情况下扩展其功能的函数。装饰器可…

dhcp(接口和全局地址池模式)

接口地址池和全局地址池 dhcp应用 1.全部开启dhcp功能 2.ar5 0口接口地址池 1口全局地址池 3.ar6和ar7配置&#xff0c;查看能否自动获取ip 左右不同两个网络&#xff0c;接口和全局地址池的区别 部分截图 ar6 ar7 ar5

(实测验证)【移远EC800M-CN 】TCP 透传

引言 本文章使用自研“超小体积TTL转4GGPS集成模块”进行实测验证&#xff1b; 1、配置移远EC800M-CN TCP 透传 串口助手发送&#xff1a; ATQIOPEN1,0,"TCP","36.137.226.30",39755,0,2 //配置服务器地址和端口号&#xff1b; 4G模组返回…

07-Fortran基础--Fortran指针(Pointer)的使用

07-Fortran基础--Fortran指针Pointer的使用 0 引言1 指针&#xff08;Poionter&#xff09;的有关内容1.1 一般类型指针1.2 数组指针1.3 派生类(type)指针1.4 函数指针 2 可运行code 0 引言 Fortran是一种广泛使用的编程语言&#xff0c;特别适合科学计算和数值分析。Fortran 9…