大模型 - 知识蒸馏原理解析

知识蒸馏的详细过程和原理解析

知识蒸馏是一种通过将大型预训练模型(教师模型)的知识传递给较小模型(学生模型)的方法。这样可以在减少模型的复杂度和计算资源需求的同时,尽量保留模型的性能。以下是知识蒸馏的详细过程和每个步骤中用到的原理。

1. 输入数据

假设我们有一个图像分类任务,输入数据 x x x 是一张图像。这个图像同时馈送给教师模型和学生模型。

2. 教师模型

  • 教师模型是一个已经训练好的大模型,它对输入 x x x 进行预测。
  • 教师模型的输出经过一个带温度参数 T T T 的 softmax 函数,得到软标签(soft labels)。温度参数 T T T 用于平滑预测概率,使得输出概率分布更平缓。

具体来说,假设教师模型输出的 logits 为 [ 2.0 , 1.0 , 0.1 ] [2.0, 1.0, 0.1] [2.0,1.0,0.1],在温度 T = 2 T=2 T=2 下,softmax 计算如下:

softmax ( z i ; T = 2 ) = e z i / 2 ∑ j e z j / 2 \text{softmax}(z_i; T=2) = \frac{e^{z_i / 2}}{\sum_{j} e^{z_j / 2}} softmax(zi;T=2)=jezj/2ezi/2

计算得:
softmax ( 2.0 / 2 ) = e 1.0 e 1.0 + e 0.5 + e 0.05 = 0.504 \text{softmax}(2.0 / 2) = \frac{e^{1.0}}{e^{1.0} + e^{0.5} + e^{0.05}} = 0.504 softmax(2.0/2)=e1.0+e0.5+e0.05e1.0=0.504
softmax ( 1.0 / 2 ) = e 0.5 e 1.0 + e 0.5 + e 0.05 = 0.277 \text{softmax}(1.0 / 2) = \frac{e^{0.5}}{e^{1.0} + e^{0.5} + e^{0.05}} = 0.277 softmax(1.0/2)=e1.0+e0.5+e0.05e0.5=0.277
softmax ( 0.1 / 2 ) = e 0.05 e 1.0 + e 0.5 + e 0.05 = 0.219 \text{softmax}(0.1 / 2) = \frac{e^{0.05}}{e^{1.0} + e^{0.5} + e^{0.05}} = 0.219 softmax(0.1/2)=e1.0+e0.5+e0.05e0.05=0.219

软标签为 [ 0.504 , 0.277 , 0.219 ] [0.504, 0.277, 0.219] [0.504,0.277,0.219]

3. 学生模型

  • 学生模型是一个较小的模型,它也对输入 x x x 进行预测。
  • 学生模型的输出经过两个 softmax 函数处理,一个带温度 T T T 得到软预测(soft predictions),另一个带温度 T = 1 T=1 T=1 得到硬预测(hard predictions)。

假设学生模型输出的 logits 为 [ 1.8 , 0.9 , 0.4 ] [1.8, 0.9, 0.4] [1.8,0.9,0.4],在温度 T = 2 T=2 T=2 下,softmax 计算如下:

softmax ( 1.8 / 2 ) = e 0.9 e 0.9 + e 0.45 + e 0.2 = 0.474 \text{softmax}(1.8 / 2) = \frac{e^{0.9}}{e^{0.9} + e^{0.45} + e^{0.2}} = 0.474 softmax(1.8/2)=e0.9+e0.45+e0.2e0.9=0.474
softmax ( 0.9 / 2 ) = e 0.45 e 0.9 + e 0.45 + e 0.2 = 0.301 \text{softmax}(0.9 / 2) = \frac{e^{0.45}}{e^{0.9} + e^{0.45} + e^{0.2}} = 0.301 softmax(0.9/2)=e0.9+e0.45+e0.2e0.45=0.301
softmax ( 0.4 / 2 ) = e 0.2 e 0.9 + e 0.45 + e 0.2 = 0.225 \text{softmax}(0.4 / 2) = \frac{e^{0.2}}{e^{0.9} + e^{0.45} + e^{0.2}} = 0.225 softmax(0.4/2)=e0.9+e0.45+e0.2e0.2=0.225

软预测为 [ 0.474 , 0.301 , 0.225 ] [0.474, 0.301, 0.225] [0.474,0.301,0.225]

硬预测( T = 1 T=1 T=1)的 softmax 计算如下:
softmax ( 1.8 ) = e 1.8 e 1.8 + e 0.9 + e 0.4 = 0.659 \text{softmax}(1.8) = \frac{e^{1.8}}{e^{1.8} + e^{0.9} + e^{0.4}} = 0.659 softmax(1.8)=e1.8+e0.9+e0.4e1.8=0.659
softmax ( 0.9 ) = e 0.9 e 1.8 + e 0.9 + e 0.4 = 0.242 \text{softmax}(0.9) = \frac{e^{0.9}}{e^{1.8} + e^{0.9} + e^{0.4}} = 0.242 softmax(0.9)=e1.8+e0.9+e0.4e0.9=0.242
softmax ( 0.4 ) = e 0.4 e 1.8 + e 0.9 + e 0.4 = 0.099 \text{softmax}(0.4) = \frac{e^{0.4}}{e^{1.8} + e^{0.9} + e^{0.4}} = 0.099 softmax(0.4)=e1.8+e0.9+e0.4e0.4=0.099

硬预测为 [ 0.659 , 0.242 , 0.099 ] [0.659, 0.242, 0.099] [0.659,0.242,0.099]

4. 蒸馏损失(Distillation Loss)

  • 蒸馏损失是教师模型的软标签和学生模型的软预测之间的差异,通常使用 KL 散度(Kullback-Leibler Divergence)作为损失函数。

D K L ( P ∥ Q ) = ∑ x ∈ X P ( x ) log ⁡ ( P ( x ) Q ( x ) ) D_{KL}(P \parallel Q) = \sum_{x \in X} P(x) \log \left( \frac{P(x)}{Q(x)} \right) DKL(PQ)=xXP(x)log(Q(x)P(x))

假设软标签 P P P [ 0.504 , 0.277 , 0.219 ] [0.504, 0.277, 0.219] [0.504,0.277,0.219],软预测 Q Q Q [ 0.474 , 0.301 , 0.225 ] [0.474, 0.301, 0.225] [0.474,0.301,0.225]
D K L ( P ∥ Q ) = 0.504 log ⁡ ( 0.504 0.474 ) + 0.277 log ⁡ ( 0.277 0.301 ) + 0.219 log ⁡ ( 0.219 0.225 ) D_{KL}(P \parallel Q) = 0.504 \log \left( \frac{0.504}{0.474} \right) + 0.277 \log \left( \frac{0.277}{0.301} \right) + 0.219 \log \left( \frac{0.219}{0.225} \right) DKL(PQ)=0.504log(0.4740.504)+0.277log(0.3010.277)+0.219log(0.2250.219)
计算得:
D K L ( P ∥ Q ) = 0.504 ⋅ 0.0623 + 0.277 ⋅ − 0.0848 + 0.219 ⋅ − 0.0267 D_{KL}(P \parallel Q) = 0.504 \cdot 0.0623 + 0.277 \cdot -0.0848 + 0.219 \cdot -0.0267 DKL(PQ)=0.5040.0623+0.2770.0848+0.2190.0267
= 0.0314 − 0.0235 − 0.0058 = 0.0314 - 0.0235 - 0.0058 =0.03140.02350.0058
= 0.0021 = 0.0021 =0.0021

5. 学生损失(Student Loss)

  • 学生损失是学生模型的硬预测和真实标签(硬标签)之间的差异,通常使用交叉熵损失函数。

假设真实标签 y y y 为类别 1,则 one-hot 编码为 [ 1 , 0 , 0 ] [1, 0, 0] [1,0,0],硬预测为 [ 0.659 , 0.242 , 0.099 ] [0.659, 0.242, 0.099] [0.659,0.242,0.099],交叉熵损失为:

H ( y , y ^ ) = − ∑ i y i log ⁡ ( y ^ i ) H(y, \hat{y}) = - \sum_{i} y_i \log(\hat{y}_i) H(y,y^)=iyilog(y^i)
H ( y , y ^ ) = − ( 1 ⋅ log ⁡ ( 0.659 ) + 0 ⋅ log ⁡ ( 0.242 ) + 0 ⋅ log ⁡ ( 0.099 ) ) H(y, \hat{y}) = - (1 \cdot \log(0.659) + 0 \cdot \log(0.242) + 0 \cdot \log(0.099)) H(y,y^)=(1log(0.659)+0log(0.242)+0log(0.099))
= − log ⁡ ( 0.659 ) = 0.416 = - \log(0.659) = 0.416 =log(0.659)=0.416

6. 总损失(Total Loss)

  • 总损失是蒸馏损失和学生损失的加权和:
    Total Loss = α × Student Loss + β × Distillation Loss \text{Total Loss} = \alpha \times \text{Student Loss} + \beta \times \text{Distillation Loss} Total Loss=α×Student Loss+β×Distillation Loss

假设 α = 1 \alpha = 1 α=1 β = 0.5 \beta = 0.5 β=0.5,则总损失为:
Total Loss = 1 × 0.416 + 0.5 × 0.0021 = 0.417 \text{Total Loss} = 1 \times 0.416 + 0.5 \times 0.0021 = 0.417 Total Loss=1×0.416+0.5×0.0021=0.417

代码示例

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F# 定义教师模型和学生模型
class TeacherModel(nn.Module):def __init__(self):super(TeacherModel, self).__init__()self.fc = nn.Linear(784, 10)def forward(self, x):return self.fc(x)class StudentModel(nn.Module):def __init__(self):super(StudentModel, self).__init__()self.fc = nn.Linear(784, 10)def forward(self, x):return self.fc(x)# 定义蒸馏损失函数
def distillation_loss(soft_labels, soft_predictions, T):soft_labels = F.softmax(soft_labels / T, dim=1)soft_predictions = F.log_softmax(soft_predictions / T, dim=1)loss = F.kl_div(soft_predictions, soft_labels, reduction='batchmean') * (T ** 2)return loss# 定义学生损失函数
def student_loss(hard_labels, hard_predictions):return F.cross_entropy(hard_predictions, hard_labels)# 超参数
alpha = 1.0
beta = 0.5
temperature = 2.0
learning_rate = 0.001
num_epochs = 10# 数据加载器(使用MNIST数据集作为示例)
from torchvision import datasets, transforms
train_loader = torch.utils.data.DataLoader(datasets.MNIST('.', train=True, download=True, transform=transforms.ToTensor()),batch_size=64, shuffle=True)# 初始化模型、优化器
teacher_model = TeacherModel()
student_model = StudentModel()
optimizer = optim.Adam(student_model.parameters(), lr=learning_rate)# 假设教师模型已经预训练好,这里直接加载预训练权重
# teacher_model.load_state_dict(torch.load('teacher_model.pth'))# 训练过程
teacher_model.eval()  # 教师模型设为评估模式,不进行训练
student_model.train()  # 学生模型设为训练模式for epoch in range(num_epochs):total_loss = 0for data, target in train_loader:data = data.view(data.size(0), -1)  # 展开图像数据# 教师模型预测with torch.no_grad():teacher_output = teacher_model(data)# 学生模型预测student_output = student_model(data)soft_predictions = student_output / temperaturehard_predictions = student_output# 计算蒸馏损失和学生损失dist_loss = distillation_loss(teacher_output, student_output, temperature)stud_loss = student_loss(target, hard_predictions)# 计算总损失loss = alpha * stud_loss + beta * dist_loss# 优化optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader)}')# 保存学生模型
torch.save(student_model.state_dict(), 'student_model.pth')

代码解释

  1. 模型定义:定义了一个简单的全连接层的教师模型和学生模型。

  2. 蒸馏损失和学生损失函数

    • distillation_loss 计算KL散度作为蒸馏损失。
    • student_loss 计算交叉熵损失作为学生损失。
  3. 超参数

    • alphabeta 分别是学生损失和蒸馏损失的权重。
    • temperature 是温度参数,用于平滑教师模型的输出。
  4. 数据加载:使用MNIST数据集作为示例。

  5. 模型初始化:初始化教师模型和学生模型,并定义优化器。

  6. 训练过程

    • 教师模型设为评估模式,学生模型设为训练模式。
    • 在每个训练周期中,对每个批次数据进行预测,计算损失,并进行优化。
  7. 保存模型:在训练结束后保存学生模型的权重。

该代码示例展示了如何通过PyTorch实现模型蒸馏的训练过程。如果有其他需求或需要进一步解释的地方,请告诉我。

总结

知识蒸馏通过教师模型提供的软标签引导学生模型,使得学生模型不仅关注硬标签的分类准确性,还能从软标签中学习更丰富的类别间关系,从而在模型压缩的同时尽量保留性能。这种方法特别适用于在资源受限的环境中部署高效的深度学习模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/42243.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python视觉轨迹几何惯性单元超维计算结构算法

🎯要点 🎯视觉轨迹几何惯性单元超维计算结构算法 | 🎯超维计算结构视觉场景理解 | 🎯超维计算结构算法解瑞文矩阵 | 🎯超维矢量计算递归神经算法 🍪语言内容分比 🍇Python蒙特卡罗惯性导航 蒙…

“来来来,借一步说话”,让前端抓狂的可视化大屏界面。

可视化大屏的前端开发难度要远远高于普通前端,尤其是当设计师搞出一些花哨的效果,很容易让UI和前端陷入口水大战中。 可视化大屏的前端开发相比普通前端开发的难度要高,主要是因为以下几个方面: 1. 数据量大: 可视化…

基于STM32的通用红外遥控器设计: 解码、学习与发射(代码示例)

摘要: 本文将带你使用STM32打造一款功能强大的万能红外遥控器,它可以学习和复制多种红外信号,并通过OLED屏幕和按键实现便捷操作。我们将深入探讨红外通信原理、STM32编程、OLED显示和EEPROM数据存储等关键技术,并提供完整的代码示…

ulimit设置:生成core文件

ulimit -a命令查看使用情况 1. ulimit -c unlimited 可以生成core文件 2.设置core文件名称带进程id(PID),修改"/proc/sys/kernel/core_uses_pid"文件,可以将进程的id作为作为扩展名,文件内容为1表示使用扩…

pyqt5实时调用摄像头并生成图片到缓存然后使用图像识别功能

pyqt5实时调用摄像头并生成图片到缓存然后使用图像识别功能 1、流程 1、进入循环,打开摄像头 2、读取图片 3、通过QImage显示图片 4、将 QImage 转换为 PIL 图像,并保存到缓存 5、从缓存中获取图像数据并进行识别 6、输出识别结果2、导入库 pip install opencv-python需要找…

阶段三:项目开发---搭建项目前后端系统基础架构:任务10:SpringBoot框架的原理和使用

任务描述 1、熟悉SpringBoot框架的原理及使用 2、使用IDEA创建基于SpringBoot、MyBatis、MySQL的Java项目 3、当前任务请在client节点上进行 任务指导 1、SpringBoot框架的选择和原理 2、MyBatis-Plus的选择和原理 3、使用IDEA创建基于SpringBootMyBatis-PlusMySQL的Jav…

使用Spring Security实现细粒度的权限控制

使用Spring Security实现细粒度的权限控制 大家好,我是微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿! Spring Security是Spring框架的一个强大和高度可定制的认证和访问控制框架。它用于保护Spring应用程序的部…

前端面试题19(vue性能优化)

Vue.js应用的性能优化是一个多方面的过程,涉及初始化加载、运行时渲染以及用户交互等多个环节。以下是一些关键的Vue性能优化策略,包括详细的说明和示例代码: 1. 懒加载组件 对于大型应用,可以使用懒加载来减少初始加载时间。Vu…

7.6 做题笔记

推荐在 cnblogs 上阅读。 7.6 做题笔记 笔记、梳理、题解合三为一的产物。 P2569 [SCOI2010] 股票交易 考虑 DP,数据允许开到平方级别。 设 f i , j f_{i,j} fi,j​ 表示第 i i i 天持有 j j j 张股票的最大钱。 四种转移: 凭空买入&#xff0c…

vite+vue3整合less教程

1、安装依赖 pnpm install -D less less-loader2、定义全局css变量文件 src/assets/css/global.less :root {--public_background_font_Color: red;--publicHouver_background_Color: #fff;--header_background_Color: #fff;--menu_background: #fff; }3、引入less src/main.…

官网首屏:激发你的小宇宙和第六感,为了漂亮,干就完了。

官网的首屏是指用户打开网站后首先看到的页面,通常是整个网站最重要的一部分。首屏的设计和内容对于吸引用户的注意力、传达品牌形象和价值、促使用户继续浏览和进行交互非常关键。以下是官网首屏的重要性的几个方面: 1. 第一印象: 首屏是用…

微信小程序毕业设计-医院挂号预约系统项目开发实战(附源码+论文)

大家好!我是程序猿老A,感谢您阅读本文,欢迎一键三连哦。 💞当前专栏:微信小程序毕业设计 精彩专栏推荐👇🏻👇🏻👇🏻 🎀 Python毕业设计…

用C#调用Windows API向指定窗口发送按键消息详解与示例

文章目录 1. 按键消息的定义及功能2. 引入所需的命名空间3. 定义Windows API函数4. 定义发送消息的方法5. 获取窗口句柄6. 调用API发送按键消息7. 使用示例注意事项总结 在C#中调用Windows API向指定窗口发送按键消息是一种常见的操作,这通常用于自动化脚本、游戏辅…

批量文本编辑管理神器:一键修改多处内容,轻松转换编码,助力工作效率飞跃提升!

在信息爆炸的时代,文本处理已成为我们日常工作中不可或缺的一部分。无论是处理文档、整理数据还是编辑资料,都需要对大量的文本进行管理和修改。然而,传统的文本编辑方式往往效率低下,容易出错,难以满足现代工作的高效…

[Day 26] 區塊鏈與人工智能的聯動應用:理論、技術與實踐

數據科學與AI的整合應用 數據科學(Data Science)和人工智能(AI)在現代技術世界中扮演著至關重要的角色。兩者的整合應用能夠為企業和研究人員提供強大的工具,以更好地理解、預測和解決各種複雜的問題。本文將深入探討…

JimuReport 积木报表 v1.7.7 版本发布,一款免费的报表工具

项目介绍 一款免费的数据可视化报表工具,含报表和大屏设计,像搭建积木一样在线设计报表!功能涵盖,数据报表、打印设计、图表报表、大屏设计等! Web 版报表设计器,类似于excel操作风格,通过拖拽完…

二刷算法训练营Day53 | 动态规划(14/17)

目录 详细布置: 1. 392. 判断子序列 2. 115. 不同的子序列 详细布置: 1. 392. 判断子序列 给定字符串 s 和 t ,判断 s 是否为 t 的子序列。 字符串的一个子序列是原始字符串删除一些(也可以不删除)字符而不改变剩余…

【昇思25天学习打卡营打卡指南-第十八天】基于MobileNetv2的垃圾分类

基于MobileNetv2的垃圾分类 MobileNetv2模型原理介绍 MobileNet网络是由Google团队于2017年提出的专注于移动端、嵌入式或IoT设备的轻量级CNN网络,相比于传统的卷积神经网络,MobileNet网络使用深度可分离卷积(Depthwise Separable Convolut…

jQuery UI 主题

jQuery UI 主题 jQuery UI 是一个建立在 jQuery JavaScript 库之上的用户界面交互、特效、小部件和主题框架。它提供了一系列的预构建组件,如拖放、排序、折叠等,以及一个强大的主题系统,允许开发者轻松地自定义和控制用户界面的外观和感觉。 主题概述 jQuery UI 主题是一…

【手写数据库内核组件】01 解析树的结构,不同类型的数据结构组多层的链表树,抽象类型统一引用格式

不同类型的链表 ​专栏内容: postgresql使用入门基础手写数据库toadb并发编程 个人主页:我的主页 管理社区:开源数据库 座右铭:天行健,君子以自强不息;地势坤,君子以厚德载物. 文章目录 不同类型…