stable diffusion 量化加速点

文章目录

    • 一、导出为dynamic shape
      • 1)函数讲解(函数导出、输出检查)
      • 2)代码展示
    • 二、导出为static shape
      • 1)函数讲解(略)
      • 2)代码展示
    • 三、序列化为FP32测速
      • 1)测速
      • 2)代码
    • 四、序列化为FP16测速
      • 1)测速
      • 2)代码同上
    • 五、发现并解决解决CLIP FP16溢出,并测速
      • 1)如何找到溢出的算子
      • 2)CLIP溢出算子解决方案
      • 3)其他FP16算子溢出的解决方案
    • 六、cuda-graph代码优化并测速
    • 七、图片迭代次数优化PD、合并GroupNorm算子制作plugin,UNet和ControlNet拼batch测试
      • 1)迭代次数优化
      • 2)合并GroupNorm算子
      • 3)UNet和ControlNet拼batch
    • 八、根据smooth-quant算法优化INT8量化,对比测速PD
      • 1)smooth-quant算法原理
      • 2)smooth-quant算法代码
      • 3)测速PD损失

一、导出为dynamic shape

1)函数讲解(函数导出、输出检查)

①torch.onnx.export

    torch.onnx.export(clip_model,(tokens),onnx_path,verbose=True,opset_version=18,do_constant_folding=True,input_names=input_names,output_names=output_names,dynamic_axes=dynamic_axes,)
(1)export_params:默认为true,表示导出的 ONNX 模型文件会包含模型的所有参数(如权重、偏置等)。而当设置为 False 时,导出的 ONNX 模型文件仅包含模型的计算图结构,不包含模型的参数。这意味着导出的 ONNX 文件会小很多,因为它没有存储大量的参数数据
(2)verbose:为true表示,将会输出大量打印日志信息
(3)do_constant_folding:一般为true,是一个布尔类型的参数,其作用是控制在导出 ONNX 模型时是否进行常量折叠优化从而提高推理性能。为TRUE开启常量折叠优化。在导出 ONNX 模型时,会对图中所有仅包含常量输入的操作进行预先计算,并用计算结果替换这些操作,以此简化计算图,减少模型的计算量和复杂度。
(4)input_names和output_names:输入、输出参数
(5)dynamic_axes:是一个字典,其键为输入或输出张量的名称,值也是一个字典,用于指定该张量中哪些维度是动态的。内层字典的键是维度索引(从 0 开始),值是一个字符串,用于标识这个动态维度,通常在 ONNX 运行时会使用这个标识来指定具体的维度大小
(6)opset_version:指定optset的版本输入参数举例:dynamic_axes = {"x": {0: "batch_size"},"hint": {0: "batch_size"},"timesteps": {0: "batch_size"},"context": {0: "batch_size", 1: "sequence_length"},"output": {0: "batch_size", 1: "hint_height", 2: "hint_width"}}dynamic_axes = {"input_ids": {1: "S"}, "last_hidden_state": {1: "S"}}dynamic_axes = {"x": {0: "latent"},}

②误差检查

#onnx_path onnx文件目录
#input_dicts  输入参数
#torch_outputs  模型输出结果
def onnxruntime_check(onnx_path, input_dicts, torch_outputs):onnx_model = onnx.load(onnx_path)# onnx.checker.check_model(onnx_model)sess = rt.InferenceSession(onnx_path)# outputs = self.get_output_names()# latent input# data = np.zeros((4, 77), dtype=np.int32)result = sess.run(None, input_dicts)cnt = 0for i in range(0, len(torch_outputs)):ret = np.allclose(result[i], torch_outputs[i].detach().numpy(), rtol=1e-03, atol=1e-05, equal_nan=False)cnt = cnt +1if ret is False:#print(f"onnxruntime_check {i} ret:{ret}  result[i]:{result[i]}  torch_outputs[i]:{torch_outputs[i].detach().numpy()} ")print("Error onnxruntime_check")# import pdb; pdb.set_trace()#print("cnt:", cnt)

2)代码展示

  • 代码
import numpy as np
from pytorch_fid import fid_score
from pytorch_fid.inception import InceptionV3
import cv2
import datetime
from share import *
import configimport cv2
import einops
import gradio as gr
import numpy as np
import torch
import random
import osfrom pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3
from annotator.canny import CannyDetector
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler
from onnx import shape_inference
import onnx_graphsurgeon as gs
import onnx
import onnxruntime as rtdef optimize(onnx_path, opt_onnx_path):from onnxsim import simplifymodel = onnx.load(onnx_path)graph = gs.import_onnx(model)print(f"{onnx_path} simplify start !")# self.info("init", graph)model_simp, check = simplify(model)# self.info("opt", gs.import_onnx(model_simp))onnx.save(model_simp, opt_onnx_path, save_as_external_data=True)assert check, "Simplified ONNX model could not be validated"print(f"{onnx_path} simplify done !")def onnxruntime_check(onnx_path, input_dicts, torch_outputs):onnx_model = onnx.load(onnx_path)# onnx.checker.check_model(onnx_model)sess = rt.InferenceSession(onnx_path)# outputs = self.get_output_names()# latent input# data = np.zeros((4, 77), dtype=np.int32)result = sess.run(None, input_dicts)cnt = 0for i in range(0, len(torch_outputs)):ret = np.allclose(result[i], torch_outputs[i].detach().numpy(), rtol=1e-03, atol=1e-05, equal_nan=False)cnt = cnt +1if ret is False:#print(f"onnxruntime_check {i} ret:{ret}  result[i]:{result[i]}  torch_outputs[i]:{torch_outputs[i].detach().numpy()} ")print("Error onnxruntime_check")# import pdb; pdb.set_trace()#print("cnt:", cnt)class hackathon():def initialize(self):self.apply_canny = CannyDetector()self.model = create_model('./models/cldm_v15.yaml').cpu()self.model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cpu'))# self.model.load_state_dict(load_state_dict('/home/player/ControlNet/models/control_sd15_canny.pth', location='cuda'))self.model = self.model.cpu()self.model.eval()self.ddim_sampler = DDIMSampler(self.model)hk = hackathon()
hk.initialize()def export_clip_model():clip_model = hk.model.cond_stage_modelimport typesdef forward(self, tokens):outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")if self.layer == "last":z = outputs.last_hidden_stateelif self.layer == "pooled":z = outputs.pooler_output[:, None, :]else:z = outputs.hidden_states[self.layer_idx]return zclip_model.forward = types.MethodType(forward, clip_model)onnx_path = "./onnx/CLIP.onnx"tokens = torch.zeros(1, 77, dtype=torch.int32)input_names = ["input_ids"]output_names = ["last_hidden_state"]dynamic_axes = {"input_ids": {1: "S"}, "last_hidden_state": {1: "S"}}torch.onnx.export(clip_model,(tokens),onnx_path,verbose=True,opset_version=18,do_constant_folding=True,input_names=input_names,output_names=output_names,dynamic_axes=dynamic_axes,)print("======================= CLIP model export onnx done!")# verify onnx modeloutput = clip_model(tokens)input_dicts = {"input_ids": tokens.numpy()}onnxruntime_check(onnx_path, input_dicts, [output])print("======================= CLIP onnx model verify done!")# opt_onnx_path = "./onnx/CLIP.opt.onnx"# optimize(onnx_path, opt_onnx_path)def export_control_net_model():control_net_model = hk.model.control_modelonnx_path = "./onnx/control_net_model.onnx"def get_shape(B=1,S=64):return [(B, 4, 32, 48),(B, 3, 256, 384),tuple([B])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/75349.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

7-openwrt-one通过web页面配置访客网络、无线中继等功能

前几个章节一直在介绍编译、分区之类的,都还没正常开始使用这个路由器的wifi。默认wifi是没有启动的,前面还是通过手动修改uci配置启动的,这个章节介绍下官方web页面的使用。特别是访客网络、无线中继 1、开启wifi,配置wifi基本信息 我们使用有线连接路由器,通过192.168.…

AcWing 6099. 座位

原题目链接 问题描述 有 n 头奶牛(n ≥ 5),编号为 1 ∼ n,按照某种顺序围着一张圆桌坐成一圈。 奶牛之间存在如下的朋友关系: 如果两头奶牛相邻,则它们是朋友;如果两头奶牛之间只隔着一头奶…

44、Spring Boot 详细讲义(一)

Spring Boot 详细讲义 目录 Spring Boot 简介Spring Boot 快速入门Spring Boot 核心功能Spring Boot 技术栈与集成Spring Boot 高级主题Spring Boot 项目实战Spring Boot 最佳实践总结 一、Spring Boot 简介 1. Spring Boot 概念和核心特点 1.1、什么是 Spring Boot&#…

配置mac mini M4 的一些软件

最近更换了 mac mini M4 ,想要重新下载配置软件 ,记录一下。 Homebrew是什么? homebrew是一款Mac OS平台下的软件包管理工具,拥有安装、卸载、更新、查看、搜索等功能。通过简单的指令可以实现包管理,而不用关心各种…

网络空间安全(54)CSRF

一、定义与原理 CSRF(Cross-Site Request Forgery),全称为跨站请求伪造,也被称为One Click Attack或Session Riding,缩写为CSRF或XSRF。它是一种网络安全漏洞,攻击者通过伪造用户的请求,利用用户…

分布式文件存储系统FastDFS

文章目录 1 分布式文件存储1_分布式文件存储的由来2_常见的分布式存储框架 2 FastDFS介绍3 FastDFS安装1_拉取镜像文件2_构建Tracker服务3_构建Storage服务4_测试图片上传 4 客户端操作1_Fastdfs-java-client2_文件上传3_文件下载4_获取文件信息5_问题 5 SpringBoot整合 1 分布…

安装了VM Tools,仍无法复制拖动-解决方案

今天在安装ubuntu时遇到了困扰许久的问题,安装了VM Tools,仍无法拖动主机文件到虚拟机,主要有两种原因并对应解决办法。 1.相关虚拟机设置选项卡中-客户机隔离-两个功能没有勾选 解决方案:勾选重启虚拟机即可 2.(这个…

Jmeter分布式测试启动

代理客户端配置 打开jmeter.properties文件,取消注释并设置端口(如server_port1099), 并添加server.rmi.ssl.disabletrue禁用SSL加密。 (Linux系统)修改jmeter-server文件中的RMI_HOST_DEF为代理机实际IP。…

火语言RPA--Oracle-导入数据表格

【组件功能】:导入特定的表格数据到包含同样字段的数据表 将表格对象数据通过数据库操作对象导入到指定数据库。 配置预览 配置说明 源表格 表格来源有“来自表格对象”和“来自表达式”2种,表达式支持DataTable类型变量。 对象 对应来自表格对象&…

Java的Selenium的特殊元素操作与定位之验证码

1.使用OCR技术识别验证 步骤: 截取整个网页的截图。 定位验证码图片元素。 根据验证码图片的位置和大小,从截图中裁剪出验证码图片。 使用OCR工具(如Tesseract)识别验证码图片中的文本。 2.手动处理验证码 步骤:…

OpenStack Yoga版安装笔记(十七)安全组笔记

一、安全组与iptables的关系 OpenStack的安全组(Security Group)默认是通过Linux的iptables实现的。以下是其主要实现原理和机制: 安全组与iptables的关系 OpenStack的安全组规则通过iptables的规则链实现。每条安全组规则会被转换为相应的i…

starrocks split函数和trino split函数差异性

在trino419和starrocks3.2.8中分别执行下面这两条sql,出来的结果是不一样的 select split(,,,)[1] as t1 select coalesce(split(,,&#

Spring Data JPA中的List底层:深入解析ArrayList的奥秘!!!

&#x1f31f; Spring Data JPA中的List底层&#xff1a;深入解析ArrayList的奥秘 &#x1f4a1; 你是否好奇过&#xff0c;为什么Spring Data JPA的查询方法返回的List<T>总是默认为ArrayList&#xff1f;本文将通过技术原理解析、验证实验和性能优化指南&#xff0c;为…

腾讯云智测试开发面经

1、投递时间线 2.20投递简历,3.11第一轮面试,3.30第二轮面试,4.4第三轮面试,4.10第四轮面试,4.11offer意向书 2、第一轮面试 第一轮面试技术面,面试官是导师,面试时长40多分钟 1)自我介绍 2)数组和列表的区别 3)了解哪些数据库 4)进程和线程的区别 5)了解哪…

【深度学习】【目标检测】【Ultralytics-YOLO系列】YOLOV3源码整体结构解析

【深度学习】【目标检测】【Ultralytics-YOLO系列】YOLOV3源码整体结构解析 文章目录 【深度学习】【目标检测】【Ultralytics-YOLO系列】YOLOV3源码整体结构解析前言代码结构整体data文件结构模型训练超参数配置文件解析数据集配置文件解析 models文件结构utils文件结构runs文…

Python常用排序算法

1. 冒泡排序 冒泡排序是一种简单的排序算法&#xff0c;它重复地遍历要排序的列表&#xff0c;比较相邻的元素&#xff0c;如果他们的顺序错误就交换他们。 def bubble_sort(arr):# 遍历所有数组元素for i in range(len(arr)):# 最后i个元素是已经排序好的for j in range(0, …

解锁塔能科技,开启工厂绿色转型与可持续发展双引擎

在全球积极推进可持续发展的大背景下&#xff0c;能源的高效利用与节能减排&#xff0c;已成为各行各业迈向高质量发展进程中无法回避的核心任务。工厂作为能源消耗大户与污染排放重点源头&#xff0c;其绿色转型迫在眉睫&#xff0c;这不仅关乎企业自身的长远发展&#xff0c;…

Spring Boot 线程池配置详解

Spring Boot 线程池配置详解 一、核心配置参数及作用 基础参数核心线程数 (corePoolSize)‌ 作用‌:线程池中始终保持存活的线程数量,即使空闲也不回收‌。 建议‌:根据任务类型设定(如 I/O 密集型任务可设为 CPU 核心数 2)‌。 最大线程数 (maxPoolSize)‌ 作用‌:…

入侵检测系统(IDS)和入侵防御系统(IPS)有啥区别?

入侵检测系统&#xff08;IDS&#xff09;和入侵防御系统&#xff08;IPS&#xff09;是网络安全中的两种关键技术&#xff0c;它们的核心区别在于 检测后的响应方式 和 部署位置。以下是详细对比&#xff1a; 1. 核心功能 - IDS&#xff08;入侵检测系统&#xff09; - 仅监…

【MySQL 数据库】数据表的操作

&#x1f525;博客主页&#x1f525;&#xff1a;【 坊钰_CSDN博客 】 欢迎各位点赞&#x1f44d;评论✍收藏⭐ 目录 1. 表的查看 1.1 语法 2. 表的创建 2.1 语法 2.2 练习 3. 查看表结构 3.1 语法 3.2 示例 4. 表的修改 4.1 语法 4.2 示例操作 4.2.1 向表中添加字段…