pytest学习-pytorch单元测试

pytorch单元测试

  • 一.公共模块[common.py]
  • 二.普通算子测试[test_clone.py]
  • 三.集合通信测试[test_ccl.py]
  • 四.测试命令
  • 五.测试报告

希望测试pytorch各种算子、block、网络等在不同硬件平台,不同软件版本下的计算误差、耗时、内存占用等指标.

本文基于torch.testing._internal

一.公共模块[common.py]

import torch
from torch import nn
import math
import torch.nn.functional as F
import time
import os
import socket
import sys
from datetime import datetime
import numpy as np
import collections
import math
import json
import copy
import traceback
import subprocess
import unittest
import torch
import inspect
from torch.testing._internal.common_utils import TestCase, run_tests,parametrize,instantiate_parametrized_tests
from torch.testing._internal.common_distributed import MultiProcessTestCase
import torch.distributed as distos.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
os.environ["RANDOM_SEED"] = "0" device="cpu"
device_type="cpu"
device_name="cpu"try:if torch.cuda.is_available():     device_name=torch.cuda.get_device_name().replace(" ","")device="cuda:0"device_type="cuda"ccl_backend='nccl'
except:passhost_name=socket.gethostname()    
sdk_version=os.getenv("SDK_VERSION","")   						 #从环境变量中获取sdk版本号
metric_data_root=os.getenv("TORCH_UT_METRICS_DATA","./ut_data")  #日志存放的目录
device_count=torch.cuda.device_count()if not os.path.exists(metric_data_root):os.makedirs(metric_data_root)def device_warmup(device):'''设备warmup,确保设备已经正常工作,排除设备初始化的耗时'''left = torch.rand([128,512], dtype = torch.float16).to(device)right = torch.rand([512,128], dtype = torch.float16).to(device)out=torch.matmul(left,right)torch.cuda.synchronize()torch.manual_seed(1) 
np.random.seed(1)def loop_decorator(loops,rank=0):'''循环装饰器,用于统计函数的执行时间,内存占用等'''def decorator(func):def wrapper(*args,**kwargs):latency=[]memory_allocated_t0=torch.cuda.memory_allocated(rank)for _ in range(loops):input_copy=[x.clone() for x in args]beg= datetime.now().timestamp() * 1e6pred= func(*input_copy)gt=kwargs["golden"]torch.cuda.synchronize()end=datetime.now().timestamp() * 1e6mse = torch.mean(torch.pow(pred.cpu().float()- gt.cpu().float(), 2)).item()latency.append(end-beg)memory_allocated_t1=torch.cuda.memory_allocated(rank)avg_latency=np.mean(latency[len(latency)//2:]).round(3)first_latency=latency[0]return { "first_latency":first_latency,"avg_latency":avg_latency,"memory_allocated":memory_allocated_t1-memory_allocated_t0,"mse":mse}return wrapperreturn decoratorclass TorchUtMetrics:'''用于统计测试结果,比较之前的最小值'''def __init__(self,ut_name,thresold=0.2,rank=0):self.ut_name=f"{ut_name}_{rank}"self.thresold=thresoldself.rank=rankself.data={"ut_name":self.ut_name,"metrics":[]}self.metrics_path=os.path.join(metric_data_root,f"{self.ut_name}_{self.rank}.jon")try:with open(self.metrics_path,"r") as f:self.data=json.loads(f.read())except:passdef __enter__(self):self.beg= datetime.now().timestamp() * 1e6return selfdef __exit__(self, exc_type, exc_val, exc_tb):        self.report()self.save_data()def save_data(self):with open(self.metrics_path,"w") as f:f.write(json.dumps(self.data,indent=4))def set_metrics(self,metrics):self.end=datetime.now().timestamp() * 1e6item=collections.OrderedDict()item["time"]=datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')item["sdk_version"]=sdk_versionitem["device_name"]=device_nameitem["host_name"]=host_nameitem["metrics"]=metricsitem["metrics"]["e2e_time"]=self.end-self.begself.cur_item=itemself.data["metrics"].append(self.cur_item)def get_metric_names(self):return self.data["metrics"][0]["metrics"].keys()def get_min_metric(self,metric_name,devicename=None):min_value=0min_value_index=-1for idx,item in enumerate(self.data["metrics"]):if devicename and (devicename!=item['device_name']):                continue            val=float(item["metrics"][metric_name])if min_value_index==-1 or val<min_value:min_value=valmin_value_index=idxreturn min_value,min_value_indexdef get_metric_info(self,index):metrics=self.data["metrics"][index]return f'{metrics["device_name"]}@{metrics["sdk_version"]}'def report(self):assert len(self.data["metrics"])>0for metric_name in self.get_metric_names():min_value,min_value_index=self.get_min_metric(metric_name)min_value_same_dev,min_value_index_same_dev=self.get_min_metric(metric_name,device_name)cur_value=float(self.cur_item["metrics"][metric_name])print(f"-------------------------------{metric_name}-------------------------------")print(f"{cur_value}#{device_name}@{sdk_version}")if min_value_index_same_dev>=0:print(f"{min_value_same_dev}#{self.get_metric_info(min_value_index_same_dev)}")if min_value_index>=0:print(f"{min_value}#{self.get_metric_info(min_value_index)}")

二.普通算子测试[test_clone.py]

from common import *
class TestCaseClone(TestCase):#如果不满足条件,则跳过这个测试@unittest.skipIf(device_count>1, "Not enough devices") def test_todo(self):print(".TODO")#框架会自动遍历以下参数组合@parametrize("shape", [(10240,20480),(128,256)])@parametrize("dtype", [torch.float16,torch.float32])def test_clone(self,shape,dtype):#让这个函数循环执行loops次,统计第一次执行的耗时、后半段的平均时间、整个执行过程总的GPU内存使用量@loop_decorator(loops=5)def run(input_dev):output=input_dev.clone()return output#记录整个测试的总耗时,保存统计量,输出摘要(self._testMethodName:测试方法,result:函数返回值,metrics:统计量)with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2) as m:input_host=torch.ones(shape,dtype=dtype)*np.random.rand()input_dev=input_host.to(device)metrics=run(input_dev,golden=input_host.cpu())m.set_metrics(metrics)assert(metrics["mse"]==0)instantiate_parametrized_tests(TestCaseClone)if __name__ == "__main__":run_tests()

三.集合通信测试[test_ccl.py]

from common import *
class TestCCL(MultiProcessTestCase):'''CCL测试用例'''def _create_process_group_vccl(self, world_size, store):dist.init_process_group(ccl_backend, world_size=world_size, rank=self.rank, store=store)        pg = dist.distributed_c10d._get_default_group()return pgdef setUp(self):super().setUp()self._spawn_processes()def tearDown(self):super().tearDown()try:os.remove(self.file_name)except OSError:pass@propertydef world_size(self):return 4#框架会自动遍历以下参数组合@unittest.skipIf(device_count<4, "Not enough devices") @parametrize("op",[dist.ReduceOp.SUM])@parametrize("shape", [(1024,8192)])@parametrize("dtype", [torch.int64])def test_allreduce(self,op,shape,dtype):if self.rank >= self.world_size:returnstore = dist.FileStore(self.file_name, self.world_size)pg = self._create_process_group_vccl(self.world_size, store)if not torch.distributed.is_initialized():returntorch.cuda.set_device(self.rank)device = torch.device(device_type,self.rank)device_warmup(device)#让这个函数循环执行loops次,统计第一次执行的耗时、后半段的平均时间、整个执行过程总的GPU内存使用量@loop_decorator(loops=5,rank=self.rank)def run(input_dev):dist.all_reduce(input_dev, op=op)return input_dev#记录整个测试的总耗时,保存统计量,输出摘要(self._testMethodName:测试方法,result:函数返回值,metrics:统计量)with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2,rank=self.rank) as m:input_host=torch.ones(shape,dtype=dtype)*(100+self.rank)gt=[torch.ones(shape,dtype=dtype)*(100+i) for i in range(self.world_size)]gt_=gt[0]for i in range(1,self.world_size):gt_=gt_+gt[i]input_dev=input_host.to(device)metrics=run(input_dev,golden=gt_)m.set_metrics(metrics)assert(metrics["mse"]==0)dist.destroy_process_group(pg)instantiate_parametrized_tests(TestCCL)if __name__ == "__main__":run_tests()

四.测试命令

# 运行所有的测试
pytest -v -s -p no:warnings --html=torch_report.html --self-contained-html --capture=sys ./# 运行某一个测试
python3 test_clone.py -k "test_clone_shape_(128, 256)_float32"

五.测试报告

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/824509.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

wsl安装与日常使用

文章目录 一、前向配置1、搜索功能2、勾选下面几个功能&#xff0c;进行安装二、安装WSL1、打开Windows PowerShell,查找你要安装的linux版本2、选择对应版本进行安装3、输入用户名以及密码 三、配置终端代理1、打开powershell,查看自己的IP把以下信息加入到~/.bashrc中 四、更…

Transformer with Transfer CNN for Remote-Sensing-Image Object Detection

遥感图像&#xff08;RSI&#xff09;中的目标检测始终是遥感界一个充满活力的研究主题。 最近&#xff0c;基于深度卷积神经网络 (CNN) 的方法&#xff0c;包括基于区域 CNN 和基于 You-Only-Look-Once 的方法&#xff0c;已成为 RSI 目标检测的事实上的标准。 CNN 擅长局部特…

夸克AI PPT初体验:一键生成大纲,一键生成PPT,一键更换模板!

大家好&#xff0c;我是木易&#xff0c;一个持续关注AI领域的互联网技术产品经理&#xff0c;国内Top2本科&#xff0c;美国Top10 CS研究生&#xff0c;MBA。我坚信AI是普通人变强的“外挂”&#xff0c;所以创建了“AI信息Gap”这个公众号&#xff0c;专注于分享AI全维度知识…

JavaScript(JS)三种使用方式,三种输出方式,及快速注释。---[用于后续web渗透内容]

JavaScript&#xff08;JS&#xff09;是一种广泛使用的编程语言&#xff0c;允许在网页中添加交互性和动态效果。在HTML中&#xff0c;<script>标签用于引入和执行JavaScript代码。 JS代码 js1.html \\js三种使用方式<!DOCTYPE html> <html lang"en&quo…

vulhub weblogic全系列靶场

简介 Oracle WebLogic Server 是一个统一的可扩展平台&#xff0c;专用于开发、部署和运行 Java 应用等适用于本地环境和云环境的企业应用。它提供了一种强健、成熟和可扩展的 Java Enterprise Edition (EE) 和 Jakarta EE 实施方式。 需要使用的工具 ysoserial使用不同库制作的…

自动驾驶时代的物联网与车载系统安全:挑战与应对策略

随着特斯拉CEO埃隆马斯克近日对未来出行景象的描绘——几乎所有汽车都将实现自动驾驶&#xff0c;这一愿景愈发接近现实。马斯克生动比喻&#xff0c;未来的乘客步入汽车就如同走进一部自动化的电梯&#xff0c;无需任何手动操作。这一转变预示着汽车行业正朝着高度智能化的方向…

Python学习之-typing详解

前言&#xff1a; Python的typing模块自Python 3.5开始引入&#xff0c;提供了类型系统的扩展&#xff0c;能够帮助程序员定义变量、函数的参数和返回值类型等。这使得代码更易于理解和检查&#xff0c;也方便了IDE和一些工具进行类型检查&#xff0c;提升了代码的质量。 typ…

【每日刷题】Day17

【每日刷题】Day17 &#x1f955;个人主页&#xff1a;开敲&#x1f349; &#x1f525;所属专栏&#xff1a;每日刷题&#x1f34d; &#x1f33c;文章目录&#x1f33c; 1. 19. 删除链表的倒数第 N 个结点 - 力扣&#xff08;LeetCode&#xff09; 2. 162. 寻找峰值 - 力扣…

Scratch四级:第02讲 字符串

第02讲 字符串 教练:老马的程序人生 微信:ProgrammingAssistant 博客:https://lsgogroup.blog.csdn.net/ 讲课目录 运算模块:有关字符串的积木块遍历字符串项目制作:“解密”项目制作:“成语接龙”项目制作:“加减法混合运算器”字符串 计算机学会(GESP)中属于三级的内…

YOLOv9改进策略 | 损失函数篇 | EIoU、SIoU、WIoU、DIoU、FocusIoU等二十余种损失函数

一、本文介绍 这篇文章介绍了YOLOv9的重大改进&#xff0c;特别是在损失函数方面的创新。它不仅包括了多种IoU损失函数的改进和变体&#xff0c;如SIoU、WIoU、GIoU、DIoU、EIOU、CIoU&#xff0c;还融合了“Focus”思想&#xff0c;创造了一系列新的损失函数。这些组合形式的…

腾讯AI Lab:“自我对抗”提升大模型的推理能力

本文介绍了一种名为“对抗性禁忌”&#xff08;Adversarial Taboo&#xff09;的双人对抗语言游戏&#xff0c;用于通过自我对弈提升大型语言模型的推理能力。 &#x1f449; 具体的流程 1️⃣ 游戏设计&#xff1a;在这个游戏中&#xff0c;有两个角色&#xff1a;攻击者和防守…

基于Ultrascale+系列GTY收发器64b/66b编码方式的数据传输(一)——Async Gearbox使用及上板测试

于20世纪80年代左右由IBM提出的传统8B/10B编码方式在编码效率上较低&#xff08;仅为80%&#xff09;&#xff0c;为了提升编码效率&#xff0c;Dgilent Techologies公司于2000年左右提出了64b/66b编码并应用于10G以太网中。Xilinx GT手册中没有过多64b/66b编码介绍&#xff0c…

绝地求生:PUBG地形破坏功能上线!分享你的游玩感受及反馈赢丰厚奖励

随着29.1版本更新&#xff0c;地形破坏功能及新道具“镐”正式在荣都地图亮相&#xff01;大家现在可以在荣都地图体验“动手挖呀挖”啦。 快来分享你的游玩感受及反馈&#xff0c;即可参与活动赢取精美奖励&#xff01; 参与方式 以发帖/投稿的形式&#xff0c;在 #一决镐下#…

【记录】Python|Selenium 下载 PDF 不预览不弹窗(2024年)

版本&#xff1a; Chrome 124Python 12Selenium 4.19.0 版本与我有差异不要紧&#xff0c;只要别差异太大比如 Chrome 用 57 之前的版本了&#xff0c;就可以看本文。 如果你从前完全没使用过、没安装过Selenium&#xff0c;可以参考这篇博客《【记录】Python3&#xff5c;Sele…

kafka---topic详解

一、分区与高可用 在Kafka中,事件(events 事件即消息)是以topic的形式进行组织的;同时topic是分区(partitioned)的,这意味着一个topic分布在Kafka broker上的多个“存储桶”(buckets)上。这种数据的分布式放置对于可伸缩性非常重要,因为它允许客户端应用程序同时从多个…

Stable Diffusion WebUI 控制网络 ControlNet 插件实现精准控图-详细教程

本文收录于《AI绘画从入门到精通》专栏&#xff0c;专栏总目录&#xff1a;点这里&#xff0c;订阅后可阅读专栏内所有文章。 大家好&#xff0c;我是水滴~~ 本文主要介绍 Stable Diffusion WebUI 一个比较重要的插件 ControlNet&#xff08;控制网络&#xff09;&#xff0c;主…

PHP货运搬家/拉货小程序二开源码搭建的功能

运搬家/拉货小程序的二次开发可以添加许多功能&#xff0c;以增强用户体验和提高业务效率。以下是一些可能的功能&#xff1a; 用户端功能&#xff1a; 注册登录&#xff1a;允许用户创建个人账户并登录以使用应用程序。货物发布&#xff1a;允许用户发布他们需要搬运的货物信息…

HTML转EXE 各平台版本(Windows, IOS, Android)

前言&#xff1a; 在几年前&#xff0c;我在盒子论坛中看到有人提供了一个将HTML打包成EXE文件的程序的软件&#xff0c;好像是外国人做的&#xff0c;该软件是收费的。当时我在想&#xff0c;这个功能不是很难实现呀&#xff0c;于是我就有了开发一个HTML转EXE的工具想法&…

数据可视化-ECharts Html项目实战(13)

在之前的文章中&#xff0c;我们深入学习ECharts动态主题切换和自定义ECharts主题。想了解的朋友可以查看这篇文章。同时&#xff0c;希望我的文章能帮助到你&#xff0c;如果觉得我的文章写的不错&#xff0c;请留下你宝贵的点赞&#xff0c;谢谢。 数据可视化-ECharts Html项…

写后端项目的分页查询时,解决分页不更新

写基于VueSpringBoot项目&#xff0c;实现分页查询功能时&#xff0c;改完代码后&#xff0c;发现页数不更新&#xff1a; 更改处如下&#xff1a; 显示如图&#xff1a; 发现页数没有变化&#xff0c;两条数据还是显示在同一页&#xff0c;而且每页都10条。且重启项目也没有更…