模型量化(二)—— 训练后量化PTQ(全代码)

训练后量化(Post-training Quantization,PTQ)是一种常见的模型量化技术,它在模型训练完成之后应用,旨在减少模型的大小和提高推理速度,同时尽量保持模型的性能。训练后量化对于部署到资源受限的设备上,如移动设备和嵌入式设备,特别有用。

在我们量化时,量化操作可以应用于模型的输入、权重 和 激活(即神经元输出值)上。

但我们发现,对于激活值,我们执行反量化时,并不知道这些激活值对应的浮点数矩阵的最大值和最小值,即我们执行非对称或对称量化里面的 𝛼, β 参数,所以我们拿到一个模型时,最多只能对它的权重W和输入X做量化,对于激活值Y的反量化,我们需要一组小的calibration set数据来初步计算对于Y的S和Z参数。

不熟悉非对称或对称量化的朋友可以康康这篇:《模型量化(一)—— 非对称量化、对称量化(全代码)》

在这里插入图片描述
 

目录

  • PTQ流程:
  • 全代码
    • 预训练模型
    • 加入Observer
    • 校准模型
      • 量化模型

 

PTQ流程:

在这里插入图片描述
Observer,顾名思义就是模型在正常inference的时候会被记录下正常的浮点激活值,用来算激活值对应的S和Z参数。

Calibrate后模型的W和Y都有对应的S和Z了,模型名义上量化完成。浮点的输入X也能off-line地实时算它对应的S和Z。

所以量化后的模型运行时,先对浮点输入进行量化,然后与整型的W矩阵相乘,得到整型的激活值,这时再反量化为浮点激活值,对应于下一个神经元的浮点输入,依次循环。
大家可能会想吗,这么麻烦,又是量化又是反量化,怎么还会压缩模型和加速模型呢?

压缩模型:原本所有的W都是浮点数存储,比如float32,现在转换为int8存储,模型尺寸减了大概4倍;再额外存一些神经元或网络层的S和Z参数(取决于量化的粗粒度),相对于W来说占内存很小(如果是很细粒度的量化可能这部分也得好好考虑,量化的粒度分为权重级量化、层级量化、通道级量化等)。

加速模型:主要的收益是使得模型中占大头的 W * X 操作变成了整型相乘,功耗和时延最低(浮点数相乘时功耗和时延最大)。3 * 100 * 100 * 10的全连接网络中,有213个神经元,但是有 3 * 100 * 100 * 10 = 300M个参数!这还是忽略了bias。量化相当于就是让这 300M 次乘法更轻量。而相对的 overhead 就是对开头的3个输入进行一下量化 和 对210和神经元的输出进行一下反量化,这部分开销随着网络层数与参数的增加几乎可以忽略不计。
一些专门的深度学习加速器和现代CPU/GPU提供了对低位宽整数(如int8)的优化支持,用这些硬件后可以更加体现模型量化的优势。

量化会带来一定的量化误差,即模型精度会受影响,这肯定的,但按经验来说几乎没什么影响,不要压到int4或int2这么极限就行。

 

全代码

预训练模型

import torch
import torchvision.datasets as datasets 
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os# Make torch deterministic
_ = torch.manual_seed(0)transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)# Define the device
device = "cpu"# Define the model
class VerySimpleNet(nn.Module):def __init__(self, hidden_size_1=100, hidden_size_2=100):super(VerySimpleNet,self).__init__()self.linear1 = nn.Linear(28*28, hidden_size_1) self.linear2 = nn.Linear(hidden_size_1, hidden_size_2) self.linear3 = nn.Linear(hidden_size_2, 10)self.relu = nn.ReLU()def forward(self, img):x = img.view(-1, 28*28)x = self.relu(self.linear1(x))x = self.relu(self.linear2(x))x = self.linear3(x)return xnet = VerySimpleNet().to(device)# Train the model
def train(train_loader, net, epochs=5, total_iterations_limit=None):cross_el = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(net.parameters(), lr=0.001)total_iterations = 0for epoch in range(epochs):net.train()loss_sum = 0num_iterations = 0data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')if total_iterations_limit is not None:data_iterator.total = total_iterations_limitfor data in data_iterator:num_iterations += 1total_iterations += 1x, y = datax = x.to(device)y = y.to(device)optimizer.zero_grad()output = net(x.view(-1, 28*28))loss = cross_el(output, y)loss_sum += loss.item()avg_loss = loss_sum / num_iterationsdata_iterator.set_postfix(loss=avg_loss)loss.backward()optimizer.step()if total_iterations_limit is not None and total_iterations >= total_iterations_limit:returndef print_size_of_model(model):torch.save(model.state_dict(), "temp_delme.p")print('Size (KB):', os.path.getsize("temp_delme.p")/1e3)os.remove('temp_delme.p')MODEL_FILENAME = 'simplenet_ptq.pt'if Path(MODEL_FILENAME).exists():net.load_state_dict(torch.load(MODEL_FILENAME))print('Loaded model from disk')
else:train(train_loader, net, epochs=1)# Save the model to disktorch.save(net.state_dict(), MODEL_FILENAME)# Define the testing loop
def test(model: nn.Module, total_iterations: int = None):correct = 0total = 0iterations = 0model.eval()with torch.no_grad():for data in tqdm(test_loader, desc='Testing'):x, y = datax = x.to(device)y = y.to(device)output = model(x.view(-1, 784))for idx, i in enumerate(output):if torch.argmax(i) == y[idx]:correct +=1total +=1iterations += 1if total_iterations is not None and iterations >= total_iterations:breakprint(f'Accuracy: {round(correct/total, 3)}')# Print weights and size of the model before quantization# Print the weights matrix of the model before quantization
print('Weights before quantization')
print(net.linear1.weight)
print(net.linear1.weight.dtype)print('Size of the model before quantization')
print_size_of_model(net)print(f'Accuracy of the model before quantization: ')
test(net)

加入Observer

# Insert min-max observers in the modelclass QuantizedVerySimpleNet(nn.Module):def __init__(self, hidden_size_1=100, hidden_size_2=100):super(QuantizedVerySimpleNet,self).__init__()self.quant = torch.quantization.QuantStub()self.linear1 = nn.Linear(28*28, hidden_size_1) self.linear2 = nn.Linear(hidden_size_1, hidden_size_2) self.linear3 = nn.Linear(hidden_size_2, 10)self.relu = nn.ReLU()self.dequant = torch.quantization.DeQuantStub()def forward(self, img):x = img.view(-1, 28*28)x = self.quant(x)x = self.relu(self.linear1(x))x = self.relu(self.linear2(x))x = self.linear3(x)x = self.dequant(x)return xnet_quantized = QuantizedVerySimpleNet().to(device)
# Copy weights from unquantized model
net_quantized.load_state_dict(net.state_dict())
net_quantized.eval()net_quantized.qconfig = torch.ao.quantization.default_qconfig
net_quantized = torch.ao.quantization.prepare(net_quantized) # Insert observers
net_quantized

校准模型

#用测试集再跑一次装了observer的模型
test(net_quantized)print(f'Check statistics of the various layers')
net_quantized

在这里插入图片描述
这时看到激活层的 𝛼, β 都有了,good!

量化模型

# Quantize the model using the statistics collectednet_quantized = torch.ao.quantization.convert(net_quantized)print(f'Check statistics of the various layers')
net_quantized

在这里插入图片描述

# Print the weights matrix of the model after quantization
print('Weights after quantization')
print(torch.int_repr(net_quantized.linear1.weight()))# Compare the dequantized weights and the original weights
print('Original weights: ')
print(net.linear1.weight)
print('')
print(f'Dequantized weights: ')
print(torch.dequantize(net_quantized.linear1.weight()))
print('')# Print size and accuracy of the quantized model
print('Size of the model after quantization')
print_size_of_model(net_quantized)
print('Testing the model after quantization')
test(net_quantized)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/741973.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

通过对话式人工智能实现个性化用户体验

智能交流新时代:如何选择对话式人工智能产品 在快速发展的数字环境中,对话式人工智能正在彻底改变企业与客户互动的方式。 通过集成机器学习、自然语言处理和语音识别等先进技术,对话式人工智能可提供个性化、无缝的用户体验。 了解对话式人…

中间件 | Kafka - [常见问题]

INDEX 1 消息丢失1.1 消息丢失位置1.2 如何避免消息丢失 2 顺序消费 1 消息丢失 1.1 消息丢失位置 1:producer 向 kafka 投递消息时2:kafka-topic 中 leader 已经写入了消息,向副本写入消息前挂了时3:消费者从 kafka 拉取了消息&…

OSCP靶场--Depreciated

OSCP靶场–Depreciated 考点(1. graphql枚举 2.CVE-2021-4034提权) 1.nmap扫描 ┌──(root㉿kali)-[~/Desktop] └─# nmap -sV -sC -p- 192.168.155.170 --min-rate 2500 Starting Nmap 7.92 ( https://nmap.org ) at 2024-03-13 04:19 EDT Nmap scan report for 192.168.…

平台靠不住了,独立站,自主权,LTD营销枢纽助力企业应对全球化挑战

当今全球化的市场环境中,我国的出海品牌和供应链面临着很大的挑战,但同时也蕴含着机遇。随着跨境电商的兴起,像亚马逊、TikTok等大的电商平台成为中国卖家走向世界的重要桥梁。不过,平台的政策改变和外部环境的不确定性因素给依赖…

Rabbit算法:轻量高效的加密利器

title: Rabbit算法:轻量高效的加密利器 date: 2024/3/13 18:14:31 updated: 2024/3/13 18:14:31 tags: Rabbit算法流密码高安全性高性能密钥调度加密解密抗攻击性 Rabbit算法起源: Rabbit算法是由Martin Boesgaard和Mette Vesterager提出的一种流密码算…

(C语言)strcpy与strcpy详解,与模拟实现

目录 1. strcpy strcpy模拟实现&#xff1a; 实现方法1&#xff1a; 实现方法2&#xff1a; 2. strcat strcat模拟实现&#xff1a; 1. strcpy 作用&#xff1a;完成字符串的复制。 头文件&#xff1a;<string.h> destination是字符串要复制到的地点&#xff0c;s…

这款自动引流软件居然能让你的营销效果翻倍提升!

在数字化时代&#xff0c;营销策略的高效执行对企业来说至关重要。自动引流软件作为现代企业营销工具箱中的一员&#xff0c;其重要性不言而喻。这类软件通过智能化、自动化的方式&#xff0c;将潜在客户吸引到企业的销售渠道中&#xff0c;从而为企业带来可观的收益和品牌曝光…

SpringBoot集成netty实现websocket通信

实现推送消息给指定的用户 一、依赖 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://m…

武汉儿童医院变电所电力运维平台系统的设计及应用

彭姝麟 Acrelpsl 1 引言 2015年国务院发布《中共中央、国务院关于进一步深化电力体制改革的若干意见》&#xff08;中发[2015]9号&#xff09;&#xff0c;简称“电改9号文”。而本次新电改的重点是“三放开一独立三强化”&#xff1a;输配以外的经营性电价放开、售电业务放开…

法规解读 | 坚持总体国家安全观,新修订的《保守国家秘密法》今年5月1日起施行!

2024年2月27日&#xff0c;第十四届全国人大常委会第八次会议表决通过新修订的《中华人民共和国保守国家秘密法》&#xff08;以下简称保密法&#xff09;&#xff0c;自2024年5月1日起施行。 本次保密法修订坚持总体国家安全观&#xff0c;统筹发展与安全。 一方面吸收了一些工…

政务云安全风险分析与解决思路探讨

1.1概述 为了掌握某市政务网站的网络安全整体情况&#xff0c;在相关监管机构授权后&#xff0c;我们组织人员抽取了某市78个政务网站进行安全扫描&#xff0c;通过安全扫描&#xff0c;对该市政务网站的整体安全情况进行预估。 1.2工具扫描结果 本次利用漏洞扫描服务VSS共扫…

app逆向-ratel框架-sekiro框架的安装使用

文章目录 一、前言二、初次尝试三、原⽣APP的使⽤四、ratel框架结合sekiro框架使用 一、前言 sekiro主要支持多节点的程序调用&#xff0c;所以他归属于RPC&#xff08;Remote Procedure Call&#xff09;框架&#xff1a;API管理、鉴权、分布式、负载均衡、跨语言 开源文档&…

如何在群晖NAS部署WPS容器并实现无公网IP远程访问本地office软件

文章目录 1. 拉取WPS Office镜像2. 运行WPS Office镜像容器3. 本地访问WPS Office4. 群晖安装Cpolar5. 配置WPS Office远程地址6. 远程访问WPS Office小结 7. 固定公网地址 wps-office是一个在Linux服务器上部署WPS Office的镜像。它基于WPS Office的Linux版本&#xff0c;通过…

【C语言】指针详解2

&#x1f451;个人主页&#xff1a;啊Q闻 &#x1f387;收录专栏&#xff1a;《C语言》 &#x1f389;道阻且长&#xff0c;行则将至 前言 这篇博客分享的指针部分为与数组有关的指针知识&#xff0c;包括一位数组和二维数组 指针详解1的博客 【C语言】指针…

算法思想总结:双指针算法

一、移动零 . - 力扣&#xff08;LeetCode&#xff09; 移动零 该题重要信息&#xff1a;1、保持非0元素的相对位置。2、原地对数组进行操作 思路&#xff1a;双指针算法 class Solution { public:void moveZeroes(vector<int>& nums){int nnums.size();for(int cur…

【Linux】Shell编程【二】

目录 Shell流程控制条件测试注意事项示例[ condition ]与[[ condition ]]的区别 if条件单分支语法示例1&#xff1a;统计根分区使用率示例2&#xff1a;创建目录 双分支if条件语句语法案例1&#xff1a;备份mysql数据库案例2&#xff1a;判断apache是否启动&#xff0c;如果没有…

网络学习:9个计算机的“网络层”知识点

目录 一、IP 地址 1.1 分类表示法&#xff1a; 1.1.1 分类表示地址的其他说明 1.2 无分类编址 CIDR 二、IP 数据报文格式 Q: IP 报文里有什么&#xff1f;可以不按顺序或者字节来讲一讲 三、 路由概念 3.1 路由表 3.2 路由网络匹配 3.3 ARP 解析 3.4 RARP 逆地址解析…

考试题库:华为HCIA-Datacom易错题⑦(含答案解析)

华为认证HCIA-Datacom易错题举例和答案分析。 1、现有一台交换机通过某端口与一个指定端口相连&#xff0c;但是该端口不转发任何报文&#xff0c;却可以通过接收BPDU来监听网络变化&#xff0c;那么该端口的角色应该是&#xff08; &#xff09;。 A、Designated端口 B、Al…

分布式搜索elasticsearch(1)

1.初识elasticsearch 1.1.了解ES 1.1.1.elasticsearch的作用 elasticsearch是一款非常强大的开源搜索引擎&#xff0c;具备非常多强大功能&#xff0c;可以帮助我们从海量数据中快速找到需要的内容 例如&#xff1a; 在GitHub搜索代码 在电商网站搜索商品 在百度搜索答案…

c++函数SetConsoleTextAttribute

前言 正文 1.作用&#xff1a; 2.函数格式(重点)&#xff1a; 3.参数(重点)&#xff1a; 前言 实用(真的) 正文 1.作用&#xff1a; 更改cmd的背景色与字体颜色 2.函数格式(重点)&#xff1a; SetConsoleTextAttribute(GetStdHandle(STD_OUTPUT_HANDLE),10进制参数); …