商汤提出的BRECQ量化框架是个什么?
引言
近年来,深度学习在多个领域取得了显著进展,但其巨大的计算成本和内存占用问题逐渐凸显。为了压缩和加速已训练好的网络,量化成为了一种重要的技术手段。量化主要分为两类:量化感知训练(QAT)和后训练量化(PTQ)。尽管QAT能够取得较好的量化效果,但其需要完整的训练数据集和大量的计算资源。相比之下,PTQ则更加轻量级,仅需要一小部分校准数据即可进行量化,但低比特量化时的精度损失较大。本文提出了一种新的PTQ框架BRECQ(Block Reconstruction based Quantization),首次实现了INT2比特量化的极限突破。
方法概述
BRECQ通过分析量化过程中的二阶误差,并基于神经网络的基本构建块进行重建。其主要贡献包括:
-
基于二阶误差分析定义重建单元:本文定义了一组重建单元,并通过理论和实验证明块重建是实现跨层依赖和泛化误差之间良好平衡的最佳选择。
-
结合遗传算法和块内敏感度度量生成混合精度量化模型:通过遗传算法和块内敏感度度量,BRECQ能够生成具有延迟和大小保证的混合精度量化神经网络,适用于各种硬件平台。
-
广泛的实验验证:本文在多种手工设计和搜索得到的神经架构上进行了大量实验,证明了BRECQ在图像分类和目标检测任务中的有效性。
方法细节
二阶误差分析
量化可以视为对权重的一种特殊扰动。为了定量分析量化引起的损失退化,可以使用泰勒级数展开来近似:
E [ L ( w + Δ w ) ] − E [ L ( w ) ] ≈ Δ w T g ˉ ( w ) + 1 2 Δ w T H ˉ ( w ) Δ w E[L(w+\Delta w)] - E[L(w)] \approx \Delta w^T \bar{g}(w) + \frac{1}{2} \Delta w^T \bar{H}(w) \Delta w E[L(w+Δw)]−E[L(w)]≈ΔwTgˉ(w)+21ΔwTHˉ(w)Δw
其中, g ˉ ( w ) = E [ ∇ w L ] \bar{g}(w) = E[\nabla_w L] gˉ(w)=E[∇wL] 是梯度, H ˉ ( w ) = E [ ∇ w 2 L ] \bar{H}(w) = E[\nabla_w^2 L] Hˉ(w)=E[∇w2L] 是Hessian矩阵, Δ w \Delta w Δw 是权重扰动。
为了处理大规模Hessian矩阵的计算和存储问题,本文将其转化为输出Hessian矩阵,即:
arg min θ ^ Δ θ T H ˉ ( θ ) Δ θ ≈ arg min θ ^ E [ Δ z ( n ) T H ( z ( n ) ) Δ z ( n ) ] \arg \min_{\hat{\theta}} \Delta \theta^T \bar{H}(\theta) \Delta \theta \approx \arg \min_{\hat{\theta}} E[\Delta z^{(n)T} H(z^{(n)}) \Delta z^{(n)}] argθ^minΔθTHˉ(θ)Δθ≈argθ^minE[Δz(n)TH(z(n))Δz(n)]
块重建
网络输出重建虽然能准确估计二阶误差,但在实践中容易导致过拟合。本文提出块重建方法,即在每个块内进行输出重建,忽略块间依赖但考虑块内依赖。块重建的优点在于它能够在跨层依赖和泛化误差之间找到良好的平衡。
近似预激活Hessian
为了计算块内的二阶误差,需要用到预激活Hessian矩阵。本文使用对角Fisher信息矩阵(FIM)来近似预激活Hessian,优化目标变为:
min w ^ E [ Δ z ( i ) T diag ( ( ∂ L ∂ z 1 ( i ) ) 2 , . . . , ( ∂ L ∂ z a ( i ) ) 2 ) Δ z ( i ) ] \min_{\hat{w}} E[\Delta z^{(i)T} \text{diag}((\frac{\partial L}{\partial z^{(i)}_1})^2, ..., (\frac{\partial L}{\partial z^{(i)}_a})^2) \Delta z^{(i)}] w^minE[Δz(i)Tdiag((∂z1(i)∂L)2,...,(∂za(i)∂L)2)Δz(i)]
混合精度量化
为了进一步提升量化效果,BRECQ结合混合精度技术,通过遗传算法搜索最优的比特宽度配置。其优化目标为:
min c L ( w ^ , c ) , s.t. H ( c ) ≤ δ , c ∈ { 2 , 4 , 8 } n \min_c L(\hat{w}, c), \text{ s.t. } H(c) \leq \delta, c \in \{2, 4, 8\}^n cminL(w^,c), s.t. H(c)≤δ,c∈{2,4,8}n
其中, c c c 是比特宽度向量, H ( ⋅ ) H(\cdot) H(⋅) 是硬件性能度量函数, δ \delta δ 是性能阈值。
实验结果
图像分类任务
在ImageNet分类任务上,BRECQ在各种现代深度学习架构上均取得了优异的量化效果。特别地,在2比特权重量化下,BRECQ的精度损失控制在5%以内,远超过其他现有方法。
目标检测任务
在MS COCO目标检测任务上,BRECQ在4比特权重和8比特激活量化下,性能几乎无损。即使在2比特权重量化下,模型仍能保持接近原始的性能。
混合精度量化
通过遗传算法搜索最优的混合精度配置,BRECQ能够在相同延迟下显著提升量化模型的精度,并适应不同的硬件要求。
结论
BRECQ是一种基于块重建的后训练量化框架,通过二阶误差分析和混合精度技术,实现了INT2比特量化的极限突破。实验结果表明,BRECQ在多种任务和模型上均取得了优异的量化效果,为深度学习模型的压缩和加速提供了新的思路。
代码示例
由于篇幅限制,这里仅展示BRECQ框架中部分关键步骤的伪代码实现。完整的实现代码请参考论文附带的源代码。
# 伪代码:块重建优化算法
def block_reconstruction(model, calibration_data, iterations):for block in model.blocks:input_data, fp_output = prepare_input_output(block, calibration_data)for _ in range(iterations):quantized_output = quantize_block(block)delta_z = fp_output - quantized_outputupdate_block_weights(block, delta_z)return model# 伪代码:遗传算法搜索混合精度配置
def genetic_algorithm_search(population, mutation_prob, iterations, threshold):for _ in range(iterations):fitness = evaluate_fitness(population)sorted_population = sort_population(fitness)crossover_population = crossover(sorted_population)mutation_population = mutate(sorted_population, mutation_prob)population = combine_populations(crossover_population, mutation_population, threshold)return best_individual(population)
希望本文解析能帮助大家更好地理解和应用BRECQ框架。