模型量化笔记--KL散度量化

KL散度量化

前面介绍的非对称量化中,是将数据中的min值和max值直接映射到[-128, 127]。
同样的,前面介绍的对称量化是将数据的最大绝对值 ∣ m a x ∣ |max| max直接映射到127。
上面两种直接映射的方法比较粗暴,而TensorRT中的int8量化是基于KL散度来选取最佳的阈值T来映射到127中。超出阈值 ± ∣ T ∣ \pm|T| ±T的数据会直接映射为阈值(类似于截断映射)。

KL散度定义

KL散度常用来衡量两个分布P和Q之间的差异,KL散度越小,两个分布越相似,其公式定义如下:
D K L = ∑ i P ( x i ) l o g ( P ( x i ) Q ( x i ) ) D_{KL} = \sum_{i}P(x_i)log(\frac{P(x_i)}{Q(x_i)}) DKL=iP(xi)log(Q(xi)P(xi))

TensorRT实现KL散度量化的步骤

  1. 基于原始输入数据生成拥有2048个bin的直方图
hist, bin_edges = np.histogram(P, bins = 2048)
  1. 在[128, 2048]返回内循环执行3-5步,寻找最佳的划分 b i n i bin_{i} bini
  2. [ 0 , b i n i ] [0,bin{i}] [0,bini]范围内的直方图数据作为原始P, 并将 b i n i bin_{i} bini之后的直方图数据进行求和,并累加到 b i n i − 1 bin_{i-1} bini1中形成以 b i n i bin_{i} bini作为划分的最终P分布
  3. 对P分布进行量化形成Q分布(一般是划分和合并bins,计算合并后的平均值作为Q分布对应bins的值)
  1. 计算P分布和Q分布的KL散度
  2. 根据最小的KL散度来选取最佳的 b i n b e s t bin_{best} binbest,将bin_edges[ b i n b e s t bin_{best} binbest]作为最终的阈值threshold,即映射到127的阈值T
  3. 根据最佳的阈值T来计算scale
    s c a l e = T i n t m a x = T 127 scale = \frac{T}{int_{max}} = \frac{T}{127} scale=intmaxT=127T
  4. 根据对称量化来量化原始数据(权重、激活值等等)

TensorRT使用KL散度量化的目的

通过KL散度选取合适的阈值T,根据阈值计算对应的缩放系数scale,力求int8量化后的数值能更准确表示出量化前的FP32数值。

代码案例

import random
import numpy as np
import matplotlib.pyplot as plt  
import copy
import scipy.stats as stats# 随机生成测试数据
def generator_P(size):walk = []avg = random.uniform(3.000, 600.999)std = random.uniform(500.000, 1024.959)for _ in range(size):walk.append(random.gauss(avg, std)) # 生成符合高斯分布的随机数return walk# 平滑p和q,防止出现nan值,因为KL散度会计算log(p/q), 当q为0值时会出现nan
def smooth_distribution(p, eps=0.0001):is_zeros = (p == 0).astype(np.float32)is_nonzeros = (p != 0).astype(np.float32)n_zeros = is_zeros.sum()n_nonzeros = p.size - n_zerosif not n_nonzeros:raise ValueError('The discrete probability distribution is malformed. All entries are 0.')eps1 = eps * float(n_zeros) / float(n_nonzeros)assert eps1 < 1.0, 'n_zeros=%d, n_nonzeros=%d, eps1=%f' % (n_zeros, n_nonzeros, eps1)hist = p.astype(np.float32)hist += eps * is_zeros + (-eps1) * is_nonzerosassert (hist <= 0).sum() == 0return histdef threshold_distribution(distribution, target_bin = 128):distribution = distribution[1:]length = distribution.size # 2047threshold_sum = sum(distribution[target_bin:]) # [128: ]kl_divergence = np.zeros(length - target_bin) # 初始化 2047 - 128 = 1919 个KL散度值for threshold in range(target_bin, length): # 遍历threshold寻找KL散度最低的阈值sliced_nd_hist = copy.deepcopy(distribution[:threshold]) # [0, threshold)内的作为Pp = sliced_nd_hist.copy() # 生成pp[threshold - 1] += threshold_sum # 把 [threshold:] 后的累加和加到 p[threshold - 1] 中threshold_sum = threshold_sum - distribution[threshold] # 更新下一轮的累加和,即上一轮的累加和减去即将移入P分布的区间数据is_nonzeros = (p != 0).astype(np.int64) # [0:threshold]内不为0的区间quantized_bins = np.zeros(target_bin, dtype = np.int64) # 初始化量化后的binsnum_merged_bins = sliced_nd_hist.size // target_bin # 计算多少个区间需要合并来计算平均值,例如最初有8个bins,需要合并到4个bins,则每两个bins需要进行合并# 合并binsfor j in range(target_bin): start = j * num_merged_bins # 合并开始的binsstop = start + num_merged_bins # 合并结束的binsquantized_bins[j] = sliced_nd_hist[start:stop].sum() # 计算区间内bins的总和quantized_bins[-1] += sliced_nd_hist[target_bin * num_merged_bins:].sum()# 计算qq = np.zeros(sliced_nd_hist.size, dtype = np.float64) # 初始化量化后的qfor j in range(target_bin):start = j * num_merged_binsif j == target_bin - 1:stop = -1else:stop = start + num_merged_bins # 每num_merged_bins个bins进行合并组成qnorm = is_nonzeros[start:stop].sum() # 看看合并区间里,不为0的区间个数if norm != 0:q[start:stop] = float(quantized_bins[j]) / float(norm) # 用均值(假如区间内都不为0)填充q# 平滑p和qp = smooth_distribution(p)q = smooth_distribution(q)# 计算p和q之间的KL散度kl_divergence[threshold - target_bin] = stats.entropy(p, q)# 寻找最小KL散度对应threshold的索引min_kl_divergence = np.argmin(kl_divergence)threshold_value = min_kl_divergence + target_bin # 计算真正的threshold, 基于最初的128, 因为一开始就是从128开始不断向外计算来扩大P的范围return threshold_valueif __name__ == '__main__':# 随机初始化测试数据size = 20480 P = generator_P(size) P = np.array(P)P = P[P > 0] # 保留大于0的数# print("maximum activation value", max(np.absolute(P))) # 最大的激活值hist, bin_edges = np.histogram(P, bins = 2048) # 生成直方图 hist表示每一个bins对应的数量, bins表示截止 threshold = threshold_distribution(hist, target_bin = 128) # 返回KL散度最小的划分binsprint("threshold: ", threshold)print("threshold edges:", bin_edges[threshold]) # 截止到threshold对应的bins, 能够表示的范围 bin_edges[-1]表示上面最大的激活值,即能够表示所有数# 计算scale# scale = bin_edges[threshold] / int_max # 即bin_edges[threshold] / 127 # 在最初的对称量化中,我们是用绝对值最大的数值作为bin_edges[threhold], 而TensorRT就是利用KL散度来评估最佳的bin_edges[threshold]# 分成 split_zie 组, density表示是否要normedplt.title("Relu activation value Histogram")plt.xlabel("Activation values")plt.ylabel("Normalized number of Counts")plt.hist(P, bins=2047)plt.vlines(bin_edges[threshold], 0, 30, colors = "r", linestyles = "dashed") # 红线向左就是能够表示的所有范围plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/109179.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

家中种绿植有什么风水讲究?

现在越来越多的人&#xff0c;都居住在小区高楼里&#xff0c;与绿植的接触也越来越少&#xff0c; 因此&#xff0c;很多人会选择在自己家中种上几株绿植。在家里种植植物&#xff0c;不仅美观&#xff0c;陶冶情操&#xff0c;还能净化空气&#xff0c;为家中增添好的风水。 …

凉鞋的 Unity 笔记 109. 专题一 小结

109. 专题一 小结 在这一篇&#xff0c;我们来对第一个专题做一个小的总结。 到目前为止&#xff0c;大家应该能够感受到此教程的基调。 内容的难度非常简单&#xff0c;接近于零基础的程度&#xff0c;不过通过这些零基础内容所介绍的通识内容其实是笔者好多年的时间一点点…

下拉选择器的树状结构图

类似&#xff1a;【Vue-Treeselect 和 vue3-treeselect】树形下拉框 一&#xff1a;图 二&#xff1a;如果有多层级的数据结构&#xff0c;可以用treeselect插件实现 1、安装&#xff1a; npm install --save riophae/vue-treeselect 2、实现&#xff1a; <el-form ref&qu…

树叶识别系统python+Django网页界面+TensorFlow+算法模型+数据集+图像识别分类

一、介绍 树叶识别系统。使用Python作为主要编程语言开发&#xff0c;通过收集常见的6中树叶&#xff08;‘广玉兰’, ‘杜鹃’, ‘梧桐’, ‘樟叶’, ‘芭蕉’, ‘银杏’&#xff09;图片作为数据集&#xff0c;然后使用TensorFlow搭建ResNet50算法网络模型&#xff0c;通过对…

vue3弹窗中循环生成表单的校验和重置问题

应用场景&#xff1a; 1、弹框里的表单是根据后台返回的时段生成的&#xff0c;后台返回几个时段&#xff0c;就渲染几组表单。 -1- 重置&#xff1a;遍历每个表单&#xff0c;获取当前表单的引用&#xff0c;在resetFields() -2- 校验&#xff1a;创建一个数组来存储每个表单的…

java线程

1. 总体路线 pom依赖 <properties> <maven.compiler.source>1.8</maven.compiler.source> <maven.compiler.target>1.8</maven.compiler.target> </properties> <dependencies><dependency> <groupId>org.projectlombo…

浏览器SSL证书过期怎么解决?

SSL证书是互联网安全的基石&#xff0c;它们用于保护网站和应用程序的数据传输。然而&#xff0c;SSL证书有一定的有效期&#xff0c;一旦证书过期&#xff0c;将导致浏览器显示安全警告&#xff0c;可能影响用户体验并降低网站的可信度。本文将详细介绍浏览器SSL证书过期问题的…

Jmeter执行接口自动化测试-如何初始化清空旧数据

需求分析&#xff1a; 每次执行完自动化测试&#xff0c;我们不会执行删除接口把数据删除&#xff0c;而需要留着手工测试&#xff0c;此时会导致下次执行测试有旧数据我们手工可能也会新增数据&#xff0c;导致下次执行自动化测试有旧数据 下面介绍两种清空数据的方法 一、通过…

QT的QStringList的使用

初始 化 默认构造函数创建一个空列表。可以使用初始值设定项列表构造函数创建包含元素的列表&#xff1a; QStringList fonts { "Arial", "Helvetica", "Times" }; 添加字符串 可以使用insert 、append&#xff08;&#xff09; 和 operator…

产品需求分析师的基本职责(合集)

产品需求分析师的基本职责1 职责 1、主要对用友司库云产品进行调研及产品规划; 2、根据司库云业务需求进行详细需求的用户故事、原型设计、需求分析、详细需求文档编写等; 3、进行产品的需求管理、需求验证、产品演示等需求工作; 4、配合开发、UE人员完成对产品的开发任务;…

酒店报修管理系统哪家好?设备巡检系统对酒店运营有什么帮助?

酒店报修管理系统是一款关键的软件工具&#xff0c;可以帮助酒店员工和客户更有效地管理酒店的各项运营活动。下面我们将通过问答形式&#xff0c;深入探讨酒店管理系统的特性和功效&#xff0c;以便了解它如何提升酒店员工的工作效率&#xff0c;以及如何将酒店的各个部门和员…

【区间 DP】运用区间 DP 解决古老原题

题目描述 这是 LeetCode 上的 「664. 奇怪的打印机」 &#xff0c;难度为 「困难」。 Tag : 「区间 DP」 有台奇怪的打印机有以下两个特殊要求&#xff1a; 打印机每次只能打印由 同一个字符 组成的序列。 每次可以在任意起始和结束位置打印新字符&#xff0c;并且会覆盖掉原来…

DNDC模型土壤碳储量、温室气体排放、农田减排、土地变化、气候变化中的实践应用

查看原文>>>DNDC模型土壤碳储量、温室气体排放、农田减排、土地变化、气候变化中的实践应用 目录 一、DNDC模型介绍 二、DNDC初步操作 三、遥感和GIS基础 四、DNDC气象数据 五、DNDC土地数据 六、DNDC土壤数据 七、DNDC结果分析 八、DNDC率定验证 九、土壤碳…

autox.js的三个版本universal、armeabi-v7a、arm64-v8a的区别

APK版本说明&#xff1a; universal: 通用版&#xff08;不在乎安装包大小/懒得选就用这个版本&#xff0c;包含以下2种CPU架构so&#xff09; armeabi-v7a: 32位ARM设备&#xff08;备用机首选&#xff09; arm64-v8a: 64位ARM设备&#xff08;主流旗舰机&#xff09; ABI在…

【Hello Algorithm】暴力递归到动态规划(四)

动态规划的数组压缩技巧 - 机器人走格子问题 题目是leetcode62题目原题 表示如下 一个机器人位于一个 m x n 网格的左上角 &#xff08;起始点在下图中标记为 “Start” &#xff09;。 机器人每次只能向下或者向右移动一步。机器人试图达到网格的右下角&#xff08;在下图中…

分享Java NET Python三大技术下AutojsPro7云控代码

引言 有图有真相&#xff0c;那短视频就更是真相了。下面是三大语言的短视频。 Java源码版云控示例&#xff1a; Java源码版云控示例在线视频 Net源码版云控示例&#xff1a; Net源码版云控示例在线视频亚丁号-知识付费平台 支付后可见 扫码付费可见 Python源码版云控示例&…

openGauss Meetup(天津站)精彩回顾 | openGauss天津用户组正式成立

由openGauss社区、天开发展集团、天津市软件行业协会、天大智图&#xff08;天津&#xff09;科技有限公司联合主办的“openGauss Meetup • 天津站”已于10月13日落下帷幕&#xff0c;此次活动邀请到众多业内技术专家&#xff0c;从技术创新、学术创新、发展创新、以及生态共建…

【Python机器学习】零基础掌握CalibratedClassifierCV概率校准

有没有想过如何提高分类模型的可靠性? 在现实生活中,许多决策都依赖于分类模型。例如在医疗诊断中,一个模型可能用于预测一个肿瘤是良性还是恶性的。但是这些模型有时会给出不准确的概率估计。 考虑一个场景,医生使用一个模型来预测肿瘤性质。假设有以下模拟数据: 患者I…

基于内存的分布式NoSQL数据库Redis(五)数据存储与RDB设计

文章目录 知识点18&#xff1a;数据存储设计知识点19&#xff1a;Redis持久化&#xff1a;RDB设计知识点20&#xff1a;Redis持久化&#xff1a;RDB测试后记 知识点18&#xff1a;数据存储设计 目标&#xff1a;掌握常见数据存储的设计 实施 问题 数据存储如何保证数据安全&am…

QT实现凸凹边形等距缩放

参考&#xff1a;https://blog.csdn.net/weixin_39383896/article/details/99615371和https://blog.csdn.net/qq_15821883/article/details/117421400 代码逻辑思路&#xff1a; 1、获取向量AB、BC的坐标。 2、计算向量AB、BC的长度。 3、根据点乘获取cosθ大小。 4、根据cosθ…